Merge pull request !279 from ye12121/mastertags/v1.6.0
@@ -0,0 +1,92 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
Fault injection example. | |||||
Download checkpoint from: https://www.mindspore.cn/resources/hub or just trained your own checkpoint. | |||||
Download dataset from: http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz. | |||||
File structure: | |||||
--cifar10-batches-bin | |||||
--train | |||||
--data_batch_1.bin | |||||
--data_batch_2.bin | |||||
--data_batch_3.bin | |||||
--data_batch_4.bin | |||||
--data_batch_5.bin | |||||
--test | |||||
--test_batch.bin | |||||
Please extract and restructure the file as shown above. | |||||
""" | |||||
import argparse | |||||
from mindspore import Model, context | |||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector | |||||
from examples.common.networks.lenet5.lenet5_net import LeNet5 | |||||
from examples.common.networks.vgg.vgg import vgg16 | |||||
from examples.common.networks.resnet.resnet import resnet50 | |||||
from examples.common.dataset.data_processing import create_dataset_cifar, generate_mnist_dataset | |||||
parser = argparse.ArgumentParser(description='layer_states') | |||||
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU']) | |||||
parser.add_argument('--model', type=str, default='lenet', choices=['lenet', 'resnet50', 'vgg16']) | |||||
parser.add_argument('--device_id', type=int, default=0) | |||||
args = parser.parse_args() | |||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) | |||||
test_flag = args.model | |||||
if test_flag == 'lenet': | |||||
# load data | |||||
DATA_FILE = '../common/dataset/MNIST_Data/test' | |||||
ckpt_path = '../common/networks/lenet5/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||||
ds_eval = generate_mnist_dataset(DATA_FILE, batch_size=64) | |||||
net = LeNet5() | |||||
elif test_flag == 'vgg16': | |||||
from examples.common.networks.vgg.config import cifar_cfg as cfg | |||||
DATA_FILE = '../common/dataset/cifar10-batches-bin' | |||||
ckpt_path = '../common/networks/vgg16_ascend_v111_cifar10_offical_cv_bs64_acc93.ckpt' | |||||
ds_eval = create_dataset_cifar(DATA_FILE, 224, 224, training=False) | |||||
net = vgg16(10, cfg, 'test') | |||||
elif test_flag == 'resnet50': | |||||
DATA_FILE = '../common/dataset/cifar10-batches-bin' | |||||
ckpt_path = '../common/networks/resnet50_ascend_v111_cifar10_offical_cv_bs32_acc92.ckpt' | |||||
ds_eval = create_dataset_cifar(DATA_FILE, 224, 224, training=False) | |||||
net = resnet50(10) | |||||
else: | |||||
exit() | |||||
param_dict = load_checkpoint(ckpt_path) | |||||
load_param_into_net(net, param_dict) | |||||
model = Model(net) | |||||
# Initialization | |||||
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', | |||||
'nan', 'inf', 'anti_activation', 'precision_loss'] | |||||
fi_mode = ['single_layer', 'all_layer'] | |||||
fi_size = [1, 2, 3] | |||||
# Fault injection | |||||
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) | |||||
results = fi.kick_off() | |||||
result_summary = fi.metrics() | |||||
# print result | |||||
for result in results: | |||||
print(result) | |||||
for result in result_summary: | |||||
print(result) |
@@ -0,0 +1,20 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
Reliability methods of MindArmour | |||||
""" | |||||
from .model_fault_injection.fault_injection import FaultInjector | |||||
__all__ = ['FaultInjector'] |
@@ -0,0 +1,169 @@ | |||||
# Demos of model fault injection | |||||
## Introduction | |||||
This is a demo of fault injection for Mindspore applications written in Python. | |||||
## Preparation | |||||
For the demo, we should prepare both datasets and pre-train models | |||||
### Dateset | |||||
For example: | |||||
`MINST`:Download MNIST dataset from: http://yann.lecun.com/exdb/mnist/ and extract as follows | |||||
```test | |||||
File structure: | |||||
- data_path | |||||
- train | |||||
- train-images-idx3-ubyte | |||||
- train-labels-idx1-ubyte | |||||
- test | |||||
- t10k-images-idx3-ubyte | |||||
- t10k-labels-idx1-ubyte | |||||
``` | |||||
`CIFAR10`:Download CIFAR10 dataset from: http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz and extract as follows | |||||
```test | |||||
File structure: | |||||
- data_path | |||||
- train | |||||
- data_batch_1.bin | |||||
- data_batch_2.bin | |||||
- data_batch_3.bin | |||||
- data_batch_4.bin | |||||
- data_batch_5.bin | |||||
- test | |||||
- test_batch.bin | |||||
``` | |||||
### CheckPoint file | |||||
Download checkpoint from: https://www.mindspore.cn/resources/hub or just trained your own checkpoint | |||||
## Configuration | |||||
There are five parameters need to set up. | |||||
```python | |||||
DATA_FILE = '../common/dataset/MNIST_Data/test' | |||||
ckpt_path = '../common/networks/checkpoint_lenet_1-10_1875.ckpt' | |||||
... | |||||
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', 'nan', 'inf', 'anti_activation', 'precision_loss'] | |||||
fi_mode = ['single_layer', 'all_layer'] | |||||
fi_size = [1, 2, 3] | |||||
``` | |||||
`DATA_FILE` is the directory where you store the data. | |||||
`ckpt_path` is the directory where you store the checkpoint file. | |||||
`fi_type` : | |||||
Eight types of faults can be injected. These are `bitflips_random`, `bitflips_designated`, `random`, `zeros`, `nan`, `inf`, `anti_activation` and `precision_loss` | |||||
bitflips_random: Bits are flipped randomly in the chosen value. | |||||
bitflips_designated: Specified bit is flipped in the chosen value. | |||||
random: The chosen value are replaced with random value in the range [-1, 1] | |||||
zeros: The chosen value are replaced with zero. | |||||
nan: The chosen value are replaced with NaN. | |||||
inf: The chosen value are replaced with Inf. | |||||
anti_activation: Changing the sign of the chosen value. | |||||
precision_loss: Round the chosen value to 1 decimal place | |||||
`fi_mode` : | |||||
There are twe kinds of injection modes can be specified, `single_layer` or `all_layer`. | |||||
`fi_size` is usually the exact number of values to be injected with the specified fault. For `zeros`, `anti_activation` and `precision_loss` fault, `fi_size` is the percentage of total tensor values and varies from 0% to 100% | |||||
### Example configuration | |||||
Sample 1: | |||||
```python | |||||
fi_type = ['bitflips_random', 'random', 'zeros', 'inf'] | |||||
fi_mode = ['single_layer'] | |||||
fi_size = [1] | |||||
``` | |||||
Sample 2: | |||||
```python | |||||
fi_type = ['bitflips_designated', 'random', 'inf', 'anti_activation', 'precision_loss'] | |||||
fi_mode = ['single_layer', 'all_layer'] | |||||
fi_size = [1, 2] | |||||
``` | |||||
## Usage | |||||
Run the test to observe the fault injection. For example: | |||||
```bash | |||||
#!/bin/bash | |||||
cd examples/reliability/ | |||||
python model_fault_injection.py --device_target GPU --device_id 2 --model lenet | |||||
``` | |||||
`device_target` | |||||
`model` is the target model need to be evaluation, choose from `lenet`, `vgg16` and `resnet`, or implement your own model. | |||||
## Result | |||||
Finally, there are three kinds of result will be return. | |||||
Sample: | |||||
```test | |||||
original_acc:0.979768 | |||||
type:bitflips_random mode:single_layer size:1 acc:0.968950 SDC:0.010817 | |||||
type:bitflips_random mode:single_layer size:2 acc:0.948017 SDC:0.031751 | |||||
... | |||||
type:precision_loss mode:all_layer size:2 acc:0.978966 SDC:0.000801 | |||||
type:precision_loss mode:all_layer size:3 acc:0.979167 SDC:0.000601 | |||||
single_layer_acc_mean:0.819732 single_layer_acc_max:0.980068 single_layer_acc_min:0.192107 | |||||
single_layer_SDC_mean:0.160035 single_layer_SDC_max:0.787660 single_layer_SDC_min:-0.000300 | |||||
all_layer_acc_mean:0.697049 all_layer_acc_max:0.979167 all_layer_acc_min:0.089443 | |||||
all_layer_acc_mean:0.282719 all_layer_acc_max:0.890325 all_layer_acc_min:0.000601 | |||||
``` | |||||
### Original_acc | |||||
The original accuracy of model: | |||||
```test | |||||
original_acc:0.979768 | |||||
``` | |||||
### Specific result of each input parameter | |||||
Each result including `type`, `mode`, `size`, `acc` and `SDC`. `type`, `mode` and `size` match along with `fi_type`, `fi_mode` and `fi_size`. | |||||
```test | |||||
type:bitflips_random mode:single_layer size:1 acc:0.968950 SDC:0.010817 | |||||
type:bitflips_random mode:single_layer size:2 acc:0.948017 SDC:0.031751 | |||||
... | |||||
type:precision_loss mode:all_layer size:2 acc:0.978966 SDC:0.000801 | |||||
type:precision_loss mode:all_layer size:3 acc:0.979167 SDC:0.000601 | |||||
``` | |||||
### Summary of mode | |||||
Summary of `single_layer` or `all_layer`. | |||||
```test | |||||
single_layer_acc_mean:0.819732 single_layer_acc_max:0.980068 single_layer_acc_min:0.192107 | |||||
single_layer_SDC_mean:0.160035 single_layer_SDC_max:0.787660 single_layer_SDC_min:-0.000300 | |||||
all_layer_acc_mean:0.697049 all_layer_acc_max:0.979167 all_layer_acc_min:0.089443 | |||||
all_layer_SDC_mean:0.282719 all_layer_SDC_max:0.890325 all_layer_SDC_min:0.000601 | |||||
``` |
@@ -0,0 +1,18 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
""" | |||||
This module provides model fault injection to evaluate the reliability of given model. | |||||
""" | |||||
from .fault_injection import FaultInjector | |||||
from .fault_type import FaultType | |||||
__all__ = ['FaultInjector', 'FaultType'] |
@@ -0,0 +1,224 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
""" | |||||
Fault injection module | |||||
""" | |||||
import random | |||||
import numpy as np | |||||
import mindspore | |||||
from mindspore import ops, Tensor | |||||
from mindarmour.reliability.model_fault_injection.fault_type import FaultType | |||||
from mindarmour.utils.logger import LogUtil | |||||
from mindarmour.utils._check_param import check_int_positive | |||||
LOGGER = LogUtil.get_instance() | |||||
TAG = 'FaultInjector' | |||||
class FaultInjector: | |||||
""" | |||||
Fault injection for deep neural networks and evaluate performance. | |||||
Args: | |||||
model (Model): The model need to be evaluated. | |||||
data (Dataset): The data for testing. The evaluation is base on this data. | |||||
fi_type (list): The type of the fault injection which include bitflips_random(flip randomly), | |||||
bitflips_designated(flip the key bit), random, zeros, nan, inf, anti_activation precision_loss etc. | |||||
fi_mode (list): The mode of fault injection. Fault inject on just single layer or all layers. | |||||
fi_size (list): The number of fault injection.It mean that how many values need to be injected. | |||||
Examples: | |||||
>>> net = Net() | |||||
>>> model = Model(net) | |||||
>>> ds_eval = create_dataloader() | |||||
>>> fi_type = ['bitflips_random', 'zeros'] | |||||
>>> fi_mode = ['single_layer', 'all_layer'] | |||||
>>> fi_size = [1, 2, 3] | |||||
>>> fi = FaultInjector(model, ds_eval, fi_type=fi_type, fi_mode=fi_mode, fi_size=fi_size) | |||||
>>> fi.kick_off() | |||||
""" | |||||
def __init__(self, model, data, fi_type=None, fi_mode=None, fi_size=None): | |||||
"""FaultInjector initiated.""" | |||||
self.running_list = [] | |||||
self._init_running_list(fi_type, fi_mode, fi_size) | |||||
self.model = model | |||||
self.data = data | |||||
self._fault_type = FaultType() | |||||
self._check_param() | |||||
self.result_list = [] | |||||
self.original_acc = 0 | |||||
self.original_parameter = {} | |||||
self.argmax = ops.Argmax() | |||||
self._reducesum = ops.ReduceSum(keep_dims=False) | |||||
self._frozen() | |||||
def _check_param(self): | |||||
"""Check input parameters.""" | |||||
attr = self._fault_type.__dir__() | |||||
if not isinstance(self.data, mindspore.dataset.Dataset): | |||||
msg = "'Input data should be Mindspore Dataset', got {}.".format(type(self.data)) | |||||
LOGGER.error(TAG, msg) | |||||
raise TypeError(msg) | |||||
_ = check_int_positive('dataset_size', self.data.get_dataset_size()) | |||||
if not isinstance(self.model, mindspore.Model): | |||||
msg = "'Input model should be Mindspore Model', got {}.".format(type(self.model)) | |||||
LOGGER.error(TAG, msg) | |||||
raise TypeError(msg) | |||||
for param in self.running_list: | |||||
if param['fi_type'] not in attr: | |||||
msg = "'Undefined fault type', got {}.".format(param['fi_type']) | |||||
LOGGER.error(TAG, msg) | |||||
raise AttributeError(msg) | |||||
if param['fi_mode'] not in ['single_layer', 'all_layer']: | |||||
msg = "'fault mode should be single_layer or all_layer', but got {}.".format(param['fi_mode']) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
_ = check_int_positive('fi_size', param['fi_size']) | |||||
def _init_running_list(self, type_, mode_, size_): | |||||
"""Initiate fault injection parameters of this evaluation.""" | |||||
if type_ is None: | |||||
type_ = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', 'nan', 'inf', | |||||
'anti_activation', 'precision_loss'] | |||||
if mode_ is None: | |||||
mode_ = ['single_layer', 'all_layer'] | |||||
if size_ is None: | |||||
size_ = list(range(1, 4)) | |||||
for i in type_: | |||||
i = i if i.startswith('_') else '_' + i | |||||
for j in mode_: | |||||
for k in size_: | |||||
dict_ = {'fi_type': i, 'fi_mode': j, 'fi_size': k} | |||||
self.running_list.append(dict_) | |||||
def _frozen(self): | |||||
"""Store original parameters of model.""" | |||||
trainable_param = self.model.predict_network.trainable_params() | |||||
for param in trainable_param: | |||||
np_param = param.asnumpy().copy() | |||||
bytes_ = np_param.tobytes() | |||||
self.original_parameter[param.name] = {} | |||||
self.original_parameter[param.name]['datatype'] = np_param.dtype | |||||
self.original_parameter[param.name]['shape'] = np_param.shape | |||||
self.original_parameter[param.name]['data'] = bytes_.hex() | |||||
def _reset_model(self): | |||||
"""Reset model with original parameters.""" | |||||
for weight in self.model.predict_network.trainable_params(): | |||||
name = weight.name | |||||
if name in self.original_parameter.keys(): | |||||
bytes_w = bytes.fromhex(self.original_parameter[name]['data']) | |||||
datatype_w = self.original_parameter[name]['datatype'] | |||||
shape_w = self.original_parameter[name]['shape'] | |||||
np_w = np.frombuffer(bytes_w, dtype=datatype_w).reshape(shape_w) | |||||
weight.assign_value(Tensor.from_numpy(np_w)) | |||||
else: | |||||
msg = "Layer name not matched, got {}.".format(name) | |||||
LOGGER.error(TAG, msg) | |||||
raise KeyError(msg) | |||||
def kick_off(self): | |||||
""" | |||||
Startup and return final results. | |||||
Returns: | |||||
list, the result of fault injection. | |||||
""" | |||||
result_list = [] | |||||
for i in range(-1, len(self.running_list)): | |||||
arg = self.running_list[i] | |||||
total = 0 | |||||
correct = 0 | |||||
for data in self.data.create_dict_iterator(): | |||||
batch = data['image'] | |||||
label = data['label'] | |||||
if i != -1: | |||||
self._reset_model() | |||||
self._layer_states(arg['fi_type'], arg['fi_mode'], arg['fi_size']) | |||||
output = self.model.predict(batch) | |||||
predict = self.argmax(output) | |||||
mask = predict == label | |||||
total += predict.size | |||||
correct += self._reducesum(mask.astype(mindspore.float32)).asnumpy() | |||||
acc = correct / total | |||||
if i == -1: | |||||
self.original_acc = acc | |||||
result_list.append({'original_acc': self.original_acc}) | |||||
else: | |||||
result_list.append({'type': arg['fi_type'], 'mode': arg['fi_mode'], 'size': arg['fi_size'], | |||||
'acc': acc, 'SDC': self.original_acc - acc}) | |||||
self.data.reset() | |||||
self._reset_model() | |||||
self.result_list = result_list | |||||
return result_list | |||||
def metrics(self): | |||||
"""metrics of final result.""" | |||||
result_summary = [] | |||||
single_layer_acc = [] | |||||
single_layer_sdc = [] | |||||
all_layer_acc = [] | |||||
all_layer_sdc = [] | |||||
for result in self.result_list: | |||||
if 'mode' in result.keys(): | |||||
if result['mode'] == 'single_layer': | |||||
single_layer_acc.append(float(result['acc'])) | |||||
single_layer_sdc.append(float(result['SDC'])) | |||||
else: | |||||
all_layer_acc.append(float(result['acc'])) | |||||
all_layer_sdc.append(float(result['SDC'])) | |||||
s_acc = np.array(single_layer_acc) | |||||
s_sdc = np.array(single_layer_sdc) | |||||
a_acc = np.array(all_layer_acc) | |||||
a_sdc = np.array(all_layer_sdc) | |||||
if single_layer_acc: | |||||
result_summary.append('single_layer_acc_mean:%f single_layer_acc_max:%f single_layer_acc_min:%f' | |||||
% (np.mean(s_acc), np.max(s_acc), np.min(s_acc))) | |||||
result_summary.append('single_layer_SDC_mean:%f single_layer_SDC_max:%f single_layer_SDC_min:%f' | |||||
% (np.mean(s_sdc), np.max(s_sdc), np.min(s_sdc))) | |||||
if all_layer_acc: | |||||
result_summary.append('all_layer_acc_mean:%f all_layer_acc_max:%f all_layer_acc_min:%f' | |||||
% (np.mean(a_acc), np.max(a_acc), np.min(a_acc))) | |||||
result_summary.append('all_layer_SDC_mean:%f all_layer_SDC_max:%f all_layer_SDC_min:%f' | |||||
% (np.mean(a_sdc), np.max(a_sdc), np.min(a_sdc))) | |||||
return result_summary | |||||
def _layer_states(self, fi_type, fi_mode, fi_size): | |||||
"""FI in layer states.""" | |||||
# Choose a random layer for injection | |||||
if fi_mode == "single_layer": | |||||
# Single layer fault injection mode | |||||
random_num = [random.randint(0, len(self.model.predict_network.trainable_params()) - 1)] | |||||
elif fi_mode == "all_layer": | |||||
# Multiple layer fault injection mode | |||||
random_num = list(range(len(self.model.predict_network.trainable_params()) - 1)) | |||||
else: | |||||
msg = 'undefined fi_mode {}'.format(fi_mode) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
for n in random_num: | |||||
# Get layer states info | |||||
w = self.model.predict_network.trainable_params()[n] | |||||
w_np = w.asnumpy().copy() | |||||
elem_shape = w_np.shape | |||||
w_np = w_np.reshape(-1) | |||||
# fault inject | |||||
new_w_np = self._fault_type._fault_inject(w_np, fi_type, fi_size) | |||||
# Reshape into original dimensions and store the faulty tensor | |||||
new_w_np = np.reshape(new_w_np, elem_shape) | |||||
w.set_data(Tensor.from_numpy(new_w_np)) |
@@ -0,0 +1,219 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
""" | |||||
Fault type module | |||||
""" | |||||
import math | |||||
import random | |||||
from struct import pack, unpack | |||||
import numpy as np | |||||
from mindarmour.utils.logger import LogUtil | |||||
LOGGER = LogUtil.get_instance() | |||||
TAG = 'FaultType' | |||||
class FaultType: | |||||
"""Implementation of specified fault type.""" | |||||
@staticmethod | |||||
def _bitflip(value, pos): | |||||
""" | |||||
Implement of bitflip. | |||||
Args: | |||||
value (numpy.ndarray): Input data. | |||||
pos (list): The index of flip position. | |||||
Returns: | |||||
numpy.ndarray, bitflip data. | |||||
""" | |||||
bits = str(value.dtype)[-2:] if str(value.dtype)[-2].isdigit() else str(value.dtype)[-1] | |||||
value_format = 'B' * int(int(bits) / 8) | |||||
value_bytes = value.tobytes() | |||||
bytes_ = list(unpack(value_format, value_bytes)) | |||||
for p in pos: | |||||
[q, r] = divmod(p, 8) | |||||
bytes_[q] ^= 1 << r | |||||
new_value_bytes = pack(value_format, *bytes_) | |||||
new_value = np.frombuffer(new_value_bytes, value.dtype) | |||||
return new_value[0] | |||||
def _fault_inject(self, value, fi_type, fi_size): | |||||
""" | |||||
Inject the specified fault into the randomly chosen values. | |||||
For zeros, anti_activation and precision_loss, fi_size is the percentage of | |||||
total number. And the others fault, fi_size is the exact number of values to | |||||
be injected. | |||||
Args: | |||||
value (numpy.ndarray): Input data. | |||||
fi_type (str): Fault type. | |||||
fi_size (int): The number of fault injection. | |||||
Returns: | |||||
numpy.ndarray, data after fault injection. | |||||
""" | |||||
num = value.size | |||||
if fi_type in ['zeros', 'anti_activation', 'precision_loss']: | |||||
change_size = (fi_size * num) / 100 | |||||
change_size = math.floor(change_size) | |||||
else: | |||||
change_size = fi_size | |||||
if change_size > num: | |||||
change_size = num | |||||
# Choose the indices for FI | |||||
ind = random.sample(range(num), change_size) | |||||
# got specified fault type | |||||
try: | |||||
func = getattr(self, fi_type) | |||||
value = func(value, ind) | |||||
return value | |||||
except AttributeError: | |||||
msg = "'Undefined fault type', got {}.".format(fi_type) | |||||
LOGGER.error(TAG, msg) | |||||
raise AttributeError(msg) | |||||
def _bitflips_random(self, value, fi_indices): | |||||
""" | |||||
Flip bit randomly for specified value. | |||||
Args: | |||||
value (numpy.ndarray): Input data. | |||||
fi_indices (list): The index of injected data. | |||||
Returns: | |||||
numpy.ndarray, data after fault injection. | |||||
""" | |||||
for item in fi_indices: | |||||
val = value[item] | |||||
pos = random.sample(range(int(str(val.dtype)[-2:])), | |||||
1 if np.random.random() < 0.618 else 2) | |||||
val_new = self._bitflip(val, pos) | |||||
value[item] = val_new | |||||
return value | |||||
def _bitflips_designated(self, value, fi_indices): | |||||
""" | |||||
Flip the key bit for specified value. | |||||
Args: | |||||
value (numpy.ndarray): Input data. | |||||
fi_indices (list): The index of injected data. | |||||
Returns: | |||||
numpy.ndarray, data after fault injection. | |||||
""" | |||||
for item in fi_indices: | |||||
val = value[item] | |||||
# uint8 uint16 uint32 uint64 | |||||
bits = str(value.dtype)[-2:] if str(value.dtype)[-2].isdigit() else str(value.dtype)[-1] | |||||
if 'uint' in str(val.dtype): | |||||
pos = int(bits) - 1 | |||||
# int8 int16 int32 int64 float16 float32 float64 | |||||
else: | |||||
pos = int(bits) - 2 | |||||
val_new = self._bitflip(val, [pos]) | |||||
value[item] = val_new | |||||
return value | |||||
@staticmethod | |||||
def _random(value, fi_indices): | |||||
""" | |||||
Reset specified value randomly, range from -1 to 1. | |||||
Args: | |||||
value (numpy.ndarray): Input data. | |||||
fi_indices (list): The index of injected data. | |||||
Returns: | |||||
numpy.ndarray, data after fault injection. | |||||
""" | |||||
for item in fi_indices: | |||||
value[item] = np.random.random() * 2 - 1 | |||||
return value | |||||
@staticmethod | |||||
def _zeros(value, fi_indices): | |||||
""" | |||||
Set specified value into zeros. | |||||
Args: | |||||
value (numpy.ndarray): Input data. | |||||
fi_indices (list): The index of injected data. | |||||
Returns: | |||||
numpy.ndarray, data after fault injection. | |||||
""" | |||||
value[fi_indices] = 0. | |||||
return value | |||||
@staticmethod | |||||
def _nan(value, fi_indices): | |||||
""" | |||||
Set specified value into nan. | |||||
Args: | |||||
value (numpy.ndarray): Input data. | |||||
fi_indices (list): The index of injected data. | |||||
Returns: | |||||
numpy.ndarray, data after fault injection. | |||||
""" | |||||
try: | |||||
value[fi_indices] = np.nan | |||||
return value | |||||
except ValueError: | |||||
return value | |||||
@staticmethod | |||||
def _inf(value, fi_indices): | |||||
""" | |||||
Set specified value into inf | |||||
Args: | |||||
value (numpy.ndarray): Input data. | |||||
fi_indices (list): The index of injected data. | |||||
Returns: | |||||
numpy.ndarray, data after fault injection. | |||||
""" | |||||
try: | |||||
value[fi_indices] = np.inf | |||||
return value | |||||
except OverflowError: | |||||
return value | |||||
@staticmethod | |||||
def _anti_activation(value, fi_indices): | |||||
""" | |||||
Minus specified value. | |||||
Args: | |||||
value (numpy.ndarray): Input data. | |||||
fi_indices (list): The index of injected data. | |||||
Returns: | |||||
numpy.ndarray, data after fault injection. | |||||
""" | |||||
value[fi_indices] = -value[fi_indices] | |||||
return value | |||||
@staticmethod | |||||
def _precision_loss(value, fi_indices): | |||||
""" | |||||
Round specified value, round to 1 decimal place. | |||||
Args: | |||||
value (numpy.ndarray): Input data. | |||||
fi_indices (list): The index of injected data. | |||||
Returns: | |||||
numpy.ndarray, data after fault injection. | |||||
""" | |||||
value[fi_indices] = np.around(value[fi_indices], decimals=1) | |||||
return value |
@@ -0,0 +1,240 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
Test for fault injection. | |||||
""" | |||||
import pytest | |||||
import numpy as np | |||||
from mindspore import Model | |||||
import mindspore.dataset as ds | |||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
from mindarmour.utils.logger import LogUtil | |||||
from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector | |||||
from tests.ut.python.utils.mock_net import Net | |||||
LOGGER = LogUtil.get_instance() | |||||
TAG = 'Fault injection test' | |||||
LOGGER.set_level('INFO') | |||||
def dataset_generator(): | |||||
"""mock training data.""" | |||||
batch_size = 32 | |||||
batches = 128 | |||||
data = np.random.random((batches*batch_size, 1, 32, 32)).astype( | |||||
np.float32) | |||||
label = np.random.randint(0, 10, batches*batch_size).astype(np.int32) | |||||
for i in range(batches): | |||||
yield data[i*batch_size:(i + 1)*batch_size],\ | |||||
label[i*batch_size:(i + 1)*batch_size] | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_fault_injector(): | |||||
""" | |||||
Feature: Fault injector | |||||
Description: Test fault injector | |||||
Expectation: Run kick_off and metrics successfully | |||||
""" | |||||
# load model | |||||
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||||
net = Net() | |||||
param_dict = load_checkpoint(ckpt_path) | |||||
load_param_into_net(net, param_dict) | |||||
model = Model(net) | |||||
ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) | |||||
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', | |||||
'nan', 'inf', 'anti_activation', 'precision_loss'] | |||||
fi_mode = ['single_layer', 'all_layer'] | |||||
fi_size = [1] | |||||
# Fault injection | |||||
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) | |||||
_ = fi.kick_off() | |||||
_ = fi.metrics() | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_wrong_model(): | |||||
""" | |||||
Feature: Fault injector | |||||
Description: Test fault injector | |||||
Expectation: Throw TypeError exception | |||||
""" | |||||
# load model | |||||
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||||
net = Net() | |||||
param_dict = load_checkpoint(ckpt_path) | |||||
load_param_into_net(net, param_dict) | |||||
ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) | |||||
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', | |||||
'nan', 'inf', 'anti_activation', 'precision_loss'] | |||||
fi_mode = ['single_layer', 'all_layer'] | |||||
fi_size = [1] | |||||
# Fault injection | |||||
with pytest.raises(TypeError) as exc_info: | |||||
fi = FaultInjector(net, ds_eval, fi_type, fi_mode, fi_size) | |||||
_ = fi.kick_off() | |||||
_ = fi.metrics() | |||||
assert exc_info.type is TypeError | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_wrong_data(): | |||||
""" | |||||
Feature: Fault injector | |||||
Description: Test fault injector | |||||
Expectation: Throw TypeError exception | |||||
""" | |||||
# load model | |||||
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||||
net = Net() | |||||
param_dict = load_checkpoint(ckpt_path) | |||||
load_param_into_net(net, param_dict) | |||||
model = Model(net) | |||||
ds_eval = np.random.random((1000, 32, 32, 1)).astype(np.float32) | |||||
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', | |||||
'nan', 'inf', 'anti_activation', 'precision_loss'] | |||||
fi_mode = ['single_layer', 'all_layer'] | |||||
fi_size = [1] | |||||
# Fault injection | |||||
with pytest.raises(TypeError) as exc_info: | |||||
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) | |||||
_ = fi.kick_off() | |||||
_ = fi.metrics() | |||||
assert exc_info.type is TypeError | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_wrong_fi_type(): | |||||
""" | |||||
Feature: Fault injector | |||||
Description: Test fault injector | |||||
Expectation: Throw AttributeError exception | |||||
""" | |||||
# load model | |||||
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||||
net = Net() | |||||
param_dict = load_checkpoint(ckpt_path) | |||||
load_param_into_net(net, param_dict) | |||||
model = Model(net) | |||||
ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) | |||||
fi_type = ['bitflips_random_haha', 'bitflips_designated', 'random', 'zeros', | |||||
'nan', 'inf', 'anti_activation', 'precision_loss'] | |||||
fi_mode = ['single_layer', 'all_layer'] | |||||
fi_size = [1] | |||||
# Fault injection | |||||
with pytest.raises(AttributeError) as exc_info: | |||||
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) | |||||
_ = fi.kick_off() | |||||
_ = fi.metrics() | |||||
assert exc_info.type is AttributeError | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_wrong_fi_mode(): | |||||
""" | |||||
Feature: Fault injector | |||||
Description: Test fault injector | |||||
Expectation: Throw ValueError exception | |||||
""" | |||||
# load model | |||||
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||||
net = Net() | |||||
param_dict = load_checkpoint(ckpt_path) | |||||
load_param_into_net(net, param_dict) | |||||
model = Model(net) | |||||
ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) | |||||
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', | |||||
'nan', 'inf', 'anti_activation', 'precision_loss'] | |||||
fi_mode = ['single_layer_tail', 'all_layer'] | |||||
fi_size = [1] | |||||
# Fault injection | |||||
with pytest.raises(ValueError) as exc_info: | |||||
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) | |||||
_ = fi.kick_off() | |||||
_ = fi.metrics() | |||||
assert exc_info.type is ValueError | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_wrong_fi_size(): | |||||
""" | |||||
Feature: Fault injector | |||||
Description: Test fault injector | |||||
Expectation: Throw ValueError exception | |||||
""" | |||||
# load model | |||||
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||||
net = Net() | |||||
param_dict = load_checkpoint(ckpt_path) | |||||
load_param_into_net(net, param_dict) | |||||
model = Model(net) | |||||
ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) | |||||
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', | |||||
'nan', 'inf', 'anti_activation', 'precision_loss'] | |||||
fi_mode = ['single_layer', 'all_layer'] | |||||
fi_size = [-1] | |||||
# Fault injection | |||||
with pytest.raises(ValueError) as exc_info: | |||||
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) | |||||
_ = fi.kick_off() | |||||
_ = fi.metrics() | |||||
assert exc_info.type is ValueError |