| @@ -0,0 +1,92 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """ | |||||
| Fault injection example. | |||||
| Download checkpoint from: https://www.mindspore.cn/resources/hub or just trained your own checkpoint. | |||||
| Download dataset from: http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz. | |||||
| File structure: | |||||
| --cifar10-batches-bin | |||||
| --train | |||||
| --data_batch_1.bin | |||||
| --data_batch_2.bin | |||||
| --data_batch_3.bin | |||||
| --data_batch_4.bin | |||||
| --data_batch_5.bin | |||||
| --test | |||||
| --test_batch.bin | |||||
| Please extract and restructure the file as shown above. | |||||
| """ | |||||
| import argparse | |||||
| from mindspore import Model, context | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector | |||||
| from examples.common.networks.lenet5.lenet5_net import LeNet5 | |||||
| from examples.common.networks.vgg.vgg import vgg16 | |||||
| from examples.common.networks.resnet.resnet import resnet50 | |||||
| from examples.common.dataset.data_processing import create_dataset_cifar, generate_mnist_dataset | |||||
| parser = argparse.ArgumentParser(description='layer_states') | |||||
| parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU']) | |||||
| parser.add_argument('--model', type=str, default='lenet', choices=['lenet', 'resnet50', 'vgg16']) | |||||
| parser.add_argument('--device_id', type=int, default=0) | |||||
| args = parser.parse_args() | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) | |||||
| test_flag = args.model | |||||
| if test_flag == 'lenet': | |||||
| # load data | |||||
| DATA_FILE = '../common/dataset/MNIST_Data/test' | |||||
| ckpt_path = '../common/networks/lenet5/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||||
| ds_eval = generate_mnist_dataset(DATA_FILE, batch_size=64) | |||||
| net = LeNet5() | |||||
| elif test_flag == 'vgg16': | |||||
| from examples.common.networks.vgg.config import cifar_cfg as cfg | |||||
| DATA_FILE = '../common/dataset/cifar10-batches-bin' | |||||
| ckpt_path = '../common/networks/vgg16_ascend_v111_cifar10_offical_cv_bs64_acc93.ckpt' | |||||
| ds_eval = create_dataset_cifar(DATA_FILE, 224, 224, training=False) | |||||
| net = vgg16(10, cfg, 'test') | |||||
| elif test_flag == 'resnet50': | |||||
| DATA_FILE = '../common/dataset/cifar10-batches-bin' | |||||
| ckpt_path = '../common/networks/resnet50_ascend_v111_cifar10_offical_cv_bs32_acc92.ckpt' | |||||
| ds_eval = create_dataset_cifar(DATA_FILE, 224, 224, training=False) | |||||
| net = resnet50(10) | |||||
| else: | |||||
| exit() | |||||
| param_dict = load_checkpoint(ckpt_path) | |||||
| load_param_into_net(net, param_dict) | |||||
| model = Model(net) | |||||
| # Initialization | |||||
| fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', | |||||
| 'nan', 'inf', 'anti_activation', 'precision_loss'] | |||||
| fi_mode = ['single_layer', 'all_layer'] | |||||
| fi_size = [1, 2, 3] | |||||
| # Fault injection | |||||
| fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) | |||||
| results = fi.kick_off() | |||||
| result_summary = fi.metrics() | |||||
| # print result | |||||
| for result in results: | |||||
| print(result) | |||||
| for result in result_summary: | |||||
| print(result) | |||||
| @@ -0,0 +1,20 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """ | |||||
| Reliability methods of MindArmour | |||||
| """ | |||||
| from .model_fault_injection.fault_injection import FaultInjector | |||||
| __all__ = ['FaultInjector'] | |||||
| @@ -0,0 +1,169 @@ | |||||
| # Demos of model fault injection | |||||
| ## Introduction | |||||
| This is a demo of fault injection for Mindspore applications written in Python. | |||||
| ## Preparation | |||||
| For the demo, we should prepare both datasets and pre-train models | |||||
| ### Dateset | |||||
| For example: | |||||
| `MINST`:Download MNIST dataset from: http://yann.lecun.com/exdb/mnist/ and extract as follows | |||||
| ```test | |||||
| File structure: | |||||
| - data_path | |||||
| - train | |||||
| - train-images-idx3-ubyte | |||||
| - train-labels-idx1-ubyte | |||||
| - test | |||||
| - t10k-images-idx3-ubyte | |||||
| - t10k-labels-idx1-ubyte | |||||
| ``` | |||||
| `CIFAR10`:Download CIFAR10 dataset from: http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz and extract as follows | |||||
| ```test | |||||
| File structure: | |||||
| - data_path | |||||
| - train | |||||
| - data_batch_1.bin | |||||
| - data_batch_2.bin | |||||
| - data_batch_3.bin | |||||
| - data_batch_4.bin | |||||
| - data_batch_5.bin | |||||
| - test | |||||
| - test_batch.bin | |||||
| ``` | |||||
| ### CheckPoint file | |||||
| Download checkpoint from: https://www.mindspore.cn/resources/hub or just trained your own checkpoint | |||||
| ## Configuration | |||||
| There are five parameters need to set up. | |||||
| ```python | |||||
| DATA_FILE = '../common/dataset/MNIST_Data/test' | |||||
| ckpt_path = '../common/networks/checkpoint_lenet_1-10_1875.ckpt' | |||||
| ... | |||||
| fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', 'nan', 'inf', 'anti_activation', 'precision_loss'] | |||||
| fi_mode = ['single_layer', 'all_layer'] | |||||
| fi_size = [1, 2, 3] | |||||
| ``` | |||||
| `DATA_FILE` is the directory where you store the data. | |||||
| `ckpt_path` is the directory where you store the checkpoint file. | |||||
| `fi_type` : | |||||
| Eight types of faults can be injected. These are `bitflips_random`, `bitflips_designated`, `random`, `zeros`, `nan`, `inf`, `anti_activation` and `precision_loss` | |||||
| bitflips_random: Bits are flipped randomly in the chosen value. | |||||
| bitflips_designated: Specified bit is flipped in the chosen value. | |||||
| random: The chosen value are replaced with random value in the range [-1, 1] | |||||
| zeros: The chosen value are replaced with zero. | |||||
| nan: The chosen value are replaced with NaN. | |||||
| inf: The chosen value are replaced with Inf. | |||||
| anti_activation: Changing the sign of the chosen value. | |||||
| precision_loss: Round the chosen value to 1 decimal place | |||||
| `fi_mode` : | |||||
| There are twe kinds of injection modes can be specified, `single_layer` or `all_layer`. | |||||
| `fi_size` is usually the exact number of values to be injected with the specified fault. For `zeros`, `anti_activation` and `precision_loss` fault, `fi_size` is the percentage of total tensor values and varies from 0% to 100% | |||||
| ### Example configuration | |||||
| Sample 1: | |||||
| ```python | |||||
| fi_type = ['bitflips_random', 'random', 'zeros', 'inf'] | |||||
| fi_mode = ['single_layer'] | |||||
| fi_size = [1] | |||||
| ``` | |||||
| Sample 2: | |||||
| ```python | |||||
| fi_type = ['bitflips_designated', 'random', 'inf', 'anti_activation', 'precision_loss'] | |||||
| fi_mode = ['single_layer', 'all_layer'] | |||||
| fi_size = [1, 2] | |||||
| ``` | |||||
| ## Usage | |||||
| Run the test to observe the fault injection. For example: | |||||
| ```bash | |||||
| #!/bin/bash | |||||
| cd examples/reliability/ | |||||
| python model_fault_injection.py --device_target GPU --device_id 2 --model lenet | |||||
| ``` | |||||
| `device_target` | |||||
| `model` is the target model need to be evaluation, choose from `lenet`, `vgg16` and `resnet`, or implement your own model. | |||||
| ## Result | |||||
| Finally, there are three kinds of result will be return. | |||||
| Sample: | |||||
| ```test | |||||
| original_acc:0.979768 | |||||
| type:bitflips_random mode:single_layer size:1 acc:0.968950 SDC:0.010817 | |||||
| type:bitflips_random mode:single_layer size:2 acc:0.948017 SDC:0.031751 | |||||
| ... | |||||
| type:precision_loss mode:all_layer size:2 acc:0.978966 SDC:0.000801 | |||||
| type:precision_loss mode:all_layer size:3 acc:0.979167 SDC:0.000601 | |||||
| single_layer_acc_mean:0.819732 single_layer_acc_max:0.980068 single_layer_acc_min:0.192107 | |||||
| single_layer_SDC_mean:0.160035 single_layer_SDC_max:0.787660 single_layer_SDC_min:-0.000300 | |||||
| all_layer_acc_mean:0.697049 all_layer_acc_max:0.979167 all_layer_acc_min:0.089443 | |||||
| all_layer_acc_mean:0.282719 all_layer_acc_max:0.890325 all_layer_acc_min:0.000601 | |||||
| ``` | |||||
| ### Original_acc | |||||
| The original accuracy of model: | |||||
| ```test | |||||
| original_acc:0.979768 | |||||
| ``` | |||||
| ### Specific result of each input parameter | |||||
| Each result including `type`, `mode`, `size`, `acc` and `SDC`. `type`, `mode` and `size` match along with `fi_type`, `fi_mode` and `fi_size`. | |||||
| ```test | |||||
| type:bitflips_random mode:single_layer size:1 acc:0.968950 SDC:0.010817 | |||||
| type:bitflips_random mode:single_layer size:2 acc:0.948017 SDC:0.031751 | |||||
| ... | |||||
| type:precision_loss mode:all_layer size:2 acc:0.978966 SDC:0.000801 | |||||
| type:precision_loss mode:all_layer size:3 acc:0.979167 SDC:0.000601 | |||||
| ``` | |||||
| ### Summary of mode | |||||
| Summary of `single_layer` or `all_layer`. | |||||
| ```test | |||||
| single_layer_acc_mean:0.819732 single_layer_acc_max:0.980068 single_layer_acc_min:0.192107 | |||||
| single_layer_SDC_mean:0.160035 single_layer_SDC_max:0.787660 single_layer_SDC_min:-0.000300 | |||||
| all_layer_acc_mean:0.697049 all_layer_acc_max:0.979167 all_layer_acc_min:0.089443 | |||||
| all_layer_SDC_mean:0.282719 all_layer_SDC_max:0.890325 all_layer_SDC_min:0.000601 | |||||
| ``` | |||||
| @@ -0,0 +1,18 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| This module provides model fault injection to evaluate the reliability of given model. | |||||
| """ | |||||
| from .fault_injection import FaultInjector | |||||
| from .fault_type import FaultType | |||||
| __all__ = ['FaultInjector', 'FaultType'] | |||||
| @@ -0,0 +1,224 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Fault injection module | |||||
| """ | |||||
| import random | |||||
| import numpy as np | |||||
| import mindspore | |||||
| from mindspore import ops, Tensor | |||||
| from mindarmour.reliability.model_fault_injection.fault_type import FaultType | |||||
| from mindarmour.utils.logger import LogUtil | |||||
| from mindarmour.utils._check_param import check_int_positive | |||||
| LOGGER = LogUtil.get_instance() | |||||
| TAG = 'FaultInjector' | |||||
| class FaultInjector: | |||||
| """ | |||||
| Fault injection for deep neural networks and evaluate performance. | |||||
| Args: | |||||
| model (Model): The model need to be evaluated. | |||||
| data (Dataset): The data for testing. The evaluation is base on this data. | |||||
| fi_type (list): The type of the fault injection which include bitflips_random(flip randomly), | |||||
| bitflips_designated(flip the key bit), random, zeros, nan, inf, anti_activation precision_loss etc. | |||||
| fi_mode (list): The mode of fault injection. Fault inject on just single layer or all layers. | |||||
| fi_size (list): The number of fault injection.It mean that how many values need to be injected. | |||||
| Examples: | |||||
| >>> net = Net() | |||||
| >>> model = Model(net) | |||||
| >>> ds_eval = create_dataloader() | |||||
| >>> fi_type = ['bitflips_random', 'zeros'] | |||||
| >>> fi_mode = ['single_layer', 'all_layer'] | |||||
| >>> fi_size = [1, 2, 3] | |||||
| >>> fi = FaultInjector(model, ds_eval, fi_type=fi_type, fi_mode=fi_mode, fi_size=fi_size) | |||||
| >>> fi.kick_off() | |||||
| """ | |||||
| def __init__(self, model, data, fi_type=None, fi_mode=None, fi_size=None): | |||||
| """FaultInjector initiated.""" | |||||
| self.running_list = [] | |||||
| self._init_running_list(fi_type, fi_mode, fi_size) | |||||
| self.model = model | |||||
| self.data = data | |||||
| self._fault_type = FaultType() | |||||
| self._check_param() | |||||
| self.result_list = [] | |||||
| self.original_acc = 0 | |||||
| self.original_parameter = {} | |||||
| self.argmax = ops.Argmax() | |||||
| self._reducesum = ops.ReduceSum(keep_dims=False) | |||||
| self._frozen() | |||||
| def _check_param(self): | |||||
| """Check input parameters.""" | |||||
| attr = self._fault_type.__dir__() | |||||
| if not isinstance(self.data, mindspore.dataset.Dataset): | |||||
| msg = "'Input data should be Mindspore Dataset', got {}.".format(type(self.data)) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise TypeError(msg) | |||||
| _ = check_int_positive('dataset_size', self.data.get_dataset_size()) | |||||
| if not isinstance(self.model, mindspore.Model): | |||||
| msg = "'Input model should be Mindspore Model', got {}.".format(type(self.model)) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise TypeError(msg) | |||||
| for param in self.running_list: | |||||
| if param['fi_type'] not in attr: | |||||
| msg = "'Undefined fault type', got {}.".format(param['fi_type']) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise AttributeError(msg) | |||||
| if param['fi_mode'] not in ['single_layer', 'all_layer']: | |||||
| msg = "'fault mode should be single_layer or all_layer', but got {}.".format(param['fi_mode']) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise ValueError(msg) | |||||
| _ = check_int_positive('fi_size', param['fi_size']) | |||||
| def _init_running_list(self, type_, mode_, size_): | |||||
| """Initiate fault injection parameters of this evaluation.""" | |||||
| if type_ is None: | |||||
| type_ = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', 'nan', 'inf', | |||||
| 'anti_activation', 'precision_loss'] | |||||
| if mode_ is None: | |||||
| mode_ = ['single_layer', 'all_layer'] | |||||
| if size_ is None: | |||||
| size_ = list(range(1, 4)) | |||||
| for i in type_: | |||||
| i = i if i.startswith('_') else '_' + i | |||||
| for j in mode_: | |||||
| for k in size_: | |||||
| dict_ = {'fi_type': i, 'fi_mode': j, 'fi_size': k} | |||||
| self.running_list.append(dict_) | |||||
| def _frozen(self): | |||||
| """Store original parameters of model.""" | |||||
| trainable_param = self.model.predict_network.trainable_params() | |||||
| for param in trainable_param: | |||||
| np_param = param.asnumpy().copy() | |||||
| bytes_ = np_param.tobytes() | |||||
| self.original_parameter[param.name] = {} | |||||
| self.original_parameter[param.name]['datatype'] = np_param.dtype | |||||
| self.original_parameter[param.name]['shape'] = np_param.shape | |||||
| self.original_parameter[param.name]['data'] = bytes_.hex() | |||||
| def _reset_model(self): | |||||
| """Reset model with original parameters.""" | |||||
| for weight in self.model.predict_network.trainable_params(): | |||||
| name = weight.name | |||||
| if name in self.original_parameter.keys(): | |||||
| bytes_w = bytes.fromhex(self.original_parameter[name]['data']) | |||||
| datatype_w = self.original_parameter[name]['datatype'] | |||||
| shape_w = self.original_parameter[name]['shape'] | |||||
| np_w = np.frombuffer(bytes_w, dtype=datatype_w).reshape(shape_w) | |||||
| weight.assign_value(Tensor.from_numpy(np_w)) | |||||
| else: | |||||
| msg = "Layer name not matched, got {}.".format(name) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise KeyError(msg) | |||||
| def kick_off(self): | |||||
| """ | |||||
| Startup and return final results. | |||||
| Returns: | |||||
| list, the result of fault injection. | |||||
| """ | |||||
| result_list = [] | |||||
| for i in range(-1, len(self.running_list)): | |||||
| arg = self.running_list[i] | |||||
| total = 0 | |||||
| correct = 0 | |||||
| for data in self.data.create_dict_iterator(): | |||||
| batch = data['image'] | |||||
| label = data['label'] | |||||
| if i != -1: | |||||
| self._reset_model() | |||||
| self._layer_states(arg['fi_type'], arg['fi_mode'], arg['fi_size']) | |||||
| output = self.model.predict(batch) | |||||
| predict = self.argmax(output) | |||||
| mask = predict == label | |||||
| total += predict.size | |||||
| correct += self._reducesum(mask.astype(mindspore.float32)).asnumpy() | |||||
| acc = correct / total | |||||
| if i == -1: | |||||
| self.original_acc = acc | |||||
| result_list.append({'original_acc': self.original_acc}) | |||||
| else: | |||||
| result_list.append({'type': arg['fi_type'], 'mode': arg['fi_mode'], 'size': arg['fi_size'], | |||||
| 'acc': acc, 'SDC': self.original_acc - acc}) | |||||
| self.data.reset() | |||||
| self._reset_model() | |||||
| self.result_list = result_list | |||||
| return result_list | |||||
| def metrics(self): | |||||
| """metrics of final result.""" | |||||
| result_summary = [] | |||||
| single_layer_acc = [] | |||||
| single_layer_sdc = [] | |||||
| all_layer_acc = [] | |||||
| all_layer_sdc = [] | |||||
| for result in self.result_list: | |||||
| if 'mode' in result.keys(): | |||||
| if result['mode'] == 'single_layer': | |||||
| single_layer_acc.append(float(result['acc'])) | |||||
| single_layer_sdc.append(float(result['SDC'])) | |||||
| else: | |||||
| all_layer_acc.append(float(result['acc'])) | |||||
| all_layer_sdc.append(float(result['SDC'])) | |||||
| s_acc = np.array(single_layer_acc) | |||||
| s_sdc = np.array(single_layer_sdc) | |||||
| a_acc = np.array(all_layer_acc) | |||||
| a_sdc = np.array(all_layer_sdc) | |||||
| if single_layer_acc: | |||||
| result_summary.append('single_layer_acc_mean:%f single_layer_acc_max:%f single_layer_acc_min:%f' | |||||
| % (np.mean(s_acc), np.max(s_acc), np.min(s_acc))) | |||||
| result_summary.append('single_layer_SDC_mean:%f single_layer_SDC_max:%f single_layer_SDC_min:%f' | |||||
| % (np.mean(s_sdc), np.max(s_sdc), np.min(s_sdc))) | |||||
| if all_layer_acc: | |||||
| result_summary.append('all_layer_acc_mean:%f all_layer_acc_max:%f all_layer_acc_min:%f' | |||||
| % (np.mean(a_acc), np.max(a_acc), np.min(a_acc))) | |||||
| result_summary.append('all_layer_SDC_mean:%f all_layer_SDC_max:%f all_layer_SDC_min:%f' | |||||
| % (np.mean(a_sdc), np.max(a_sdc), np.min(a_sdc))) | |||||
| return result_summary | |||||
| def _layer_states(self, fi_type, fi_mode, fi_size): | |||||
| """FI in layer states.""" | |||||
| # Choose a random layer for injection | |||||
| if fi_mode == "single_layer": | |||||
| # Single layer fault injection mode | |||||
| random_num = [random.randint(0, len(self.model.predict_network.trainable_params()) - 1)] | |||||
| elif fi_mode == "all_layer": | |||||
| # Multiple layer fault injection mode | |||||
| random_num = list(range(len(self.model.predict_network.trainable_params()) - 1)) | |||||
| else: | |||||
| msg = 'undefined fi_mode {}'.format(fi_mode) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise ValueError(msg) | |||||
| for n in random_num: | |||||
| # Get layer states info | |||||
| w = self.model.predict_network.trainable_params()[n] | |||||
| w_np = w.asnumpy().copy() | |||||
| elem_shape = w_np.shape | |||||
| w_np = w_np.reshape(-1) | |||||
| # fault inject | |||||
| new_w_np = self._fault_type._fault_inject(w_np, fi_type, fi_size) | |||||
| # Reshape into original dimensions and store the faulty tensor | |||||
| new_w_np = np.reshape(new_w_np, elem_shape) | |||||
| w.set_data(Tensor.from_numpy(new_w_np)) | |||||
| @@ -0,0 +1,219 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Fault type module | |||||
| """ | |||||
| import math | |||||
| import random | |||||
| from struct import pack, unpack | |||||
| import numpy as np | |||||
| from mindarmour.utils.logger import LogUtil | |||||
| LOGGER = LogUtil.get_instance() | |||||
| TAG = 'FaultType' | |||||
| class FaultType: | |||||
| """Implementation of specified fault type.""" | |||||
| @staticmethod | |||||
| def _bitflip(value, pos): | |||||
| """ | |||||
| Implement of bitflip. | |||||
| Args: | |||||
| value (numpy.ndarray): Input data. | |||||
| pos (list): The index of flip position. | |||||
| Returns: | |||||
| numpy.ndarray, bitflip data. | |||||
| """ | |||||
| bits = str(value.dtype)[-2:] if str(value.dtype)[-2].isdigit() else str(value.dtype)[-1] | |||||
| value_format = 'B' * int(int(bits) / 8) | |||||
| value_bytes = value.tobytes() | |||||
| bytes_ = list(unpack(value_format, value_bytes)) | |||||
| for p in pos: | |||||
| [q, r] = divmod(p, 8) | |||||
| bytes_[q] ^= 1 << r | |||||
| new_value_bytes = pack(value_format, *bytes_) | |||||
| new_value = np.frombuffer(new_value_bytes, value.dtype) | |||||
| return new_value[0] | |||||
| def _fault_inject(self, value, fi_type, fi_size): | |||||
| """ | |||||
| Inject the specified fault into the randomly chosen values. | |||||
| For zeros, anti_activation and precision_loss, fi_size is the percentage of | |||||
| total number. And the others fault, fi_size is the exact number of values to | |||||
| be injected. | |||||
| Args: | |||||
| value (numpy.ndarray): Input data. | |||||
| fi_type (str): Fault type. | |||||
| fi_size (int): The number of fault injection. | |||||
| Returns: | |||||
| numpy.ndarray, data after fault injection. | |||||
| """ | |||||
| num = value.size | |||||
| if fi_type in ['zeros', 'anti_activation', 'precision_loss']: | |||||
| change_size = (fi_size * num) / 100 | |||||
| change_size = math.floor(change_size) | |||||
| else: | |||||
| change_size = fi_size | |||||
| if change_size > num: | |||||
| change_size = num | |||||
| # Choose the indices for FI | |||||
| ind = random.sample(range(num), change_size) | |||||
| # got specified fault type | |||||
| try: | |||||
| func = getattr(self, fi_type) | |||||
| value = func(value, ind) | |||||
| return value | |||||
| except AttributeError: | |||||
| msg = "'Undefined fault type', got {}.".format(fi_type) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise AttributeError(msg) | |||||
| def _bitflips_random(self, value, fi_indices): | |||||
| """ | |||||
| Flip bit randomly for specified value. | |||||
| Args: | |||||
| value (numpy.ndarray): Input data. | |||||
| fi_indices (list): The index of injected data. | |||||
| Returns: | |||||
| numpy.ndarray, data after fault injection. | |||||
| """ | |||||
| for item in fi_indices: | |||||
| val = value[item] | |||||
| pos = random.sample(range(int(str(val.dtype)[-2:])), | |||||
| 1 if np.random.random() < 0.618 else 2) | |||||
| val_new = self._bitflip(val, pos) | |||||
| value[item] = val_new | |||||
| return value | |||||
| def _bitflips_designated(self, value, fi_indices): | |||||
| """ | |||||
| Flip the key bit for specified value. | |||||
| Args: | |||||
| value (numpy.ndarray): Input data. | |||||
| fi_indices (list): The index of injected data. | |||||
| Returns: | |||||
| numpy.ndarray, data after fault injection. | |||||
| """ | |||||
| for item in fi_indices: | |||||
| val = value[item] | |||||
| # uint8 uint16 uint32 uint64 | |||||
| bits = str(value.dtype)[-2:] if str(value.dtype)[-2].isdigit() else str(value.dtype)[-1] | |||||
| if 'uint' in str(val.dtype): | |||||
| pos = int(bits) - 1 | |||||
| # int8 int16 int32 int64 float16 float32 float64 | |||||
| else: | |||||
| pos = int(bits) - 2 | |||||
| val_new = self._bitflip(val, [pos]) | |||||
| value[item] = val_new | |||||
| return value | |||||
| @staticmethod | |||||
| def _random(value, fi_indices): | |||||
| """ | |||||
| Reset specified value randomly, range from -1 to 1. | |||||
| Args: | |||||
| value (numpy.ndarray): Input data. | |||||
| fi_indices (list): The index of injected data. | |||||
| Returns: | |||||
| numpy.ndarray, data after fault injection. | |||||
| """ | |||||
| for item in fi_indices: | |||||
| value[item] = np.random.random() * 2 - 1 | |||||
| return value | |||||
| @staticmethod | |||||
| def _zeros(value, fi_indices): | |||||
| """ | |||||
| Set specified value into zeros. | |||||
| Args: | |||||
| value (numpy.ndarray): Input data. | |||||
| fi_indices (list): The index of injected data. | |||||
| Returns: | |||||
| numpy.ndarray, data after fault injection. | |||||
| """ | |||||
| value[fi_indices] = 0. | |||||
| return value | |||||
| @staticmethod | |||||
| def _nan(value, fi_indices): | |||||
| """ | |||||
| Set specified value into nan. | |||||
| Args: | |||||
| value (numpy.ndarray): Input data. | |||||
| fi_indices (list): The index of injected data. | |||||
| Returns: | |||||
| numpy.ndarray, data after fault injection. | |||||
| """ | |||||
| try: | |||||
| value[fi_indices] = np.nan | |||||
| return value | |||||
| except ValueError: | |||||
| return value | |||||
| @staticmethod | |||||
| def _inf(value, fi_indices): | |||||
| """ | |||||
| Set specified value into inf | |||||
| Args: | |||||
| value (numpy.ndarray): Input data. | |||||
| fi_indices (list): The index of injected data. | |||||
| Returns: | |||||
| numpy.ndarray, data after fault injection. | |||||
| """ | |||||
| try: | |||||
| value[fi_indices] = np.inf | |||||
| return value | |||||
| except OverflowError: | |||||
| return value | |||||
| @staticmethod | |||||
| def _anti_activation(value, fi_indices): | |||||
| """ | |||||
| Minus specified value. | |||||
| Args: | |||||
| value (numpy.ndarray): Input data. | |||||
| fi_indices (list): The index of injected data. | |||||
| Returns: | |||||
| numpy.ndarray, data after fault injection. | |||||
| """ | |||||
| value[fi_indices] = -value[fi_indices] | |||||
| return value | |||||
| @staticmethod | |||||
| def _precision_loss(value, fi_indices): | |||||
| """ | |||||
| Round specified value, round to 1 decimal place. | |||||
| Args: | |||||
| value (numpy.ndarray): Input data. | |||||
| fi_indices (list): The index of injected data. | |||||
| Returns: | |||||
| numpy.ndarray, data after fault injection. | |||||
| """ | |||||
| value[fi_indices] = np.around(value[fi_indices], decimals=1) | |||||
| return value | |||||
| @@ -0,0 +1,240 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """ | |||||
| Test for fault injection. | |||||
| """ | |||||
| import pytest | |||||
| import numpy as np | |||||
| from mindspore import Model | |||||
| import mindspore.dataset as ds | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from mindarmour.utils.logger import LogUtil | |||||
| from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector | |||||
| from tests.ut.python.utils.mock_net import Net | |||||
| LOGGER = LogUtil.get_instance() | |||||
| TAG = 'Fault injection test' | |||||
| LOGGER.set_level('INFO') | |||||
| def dataset_generator(): | |||||
| """mock training data.""" | |||||
| batch_size = 32 | |||||
| batches = 128 | |||||
| data = np.random.random((batches*batch_size, 1, 32, 32)).astype( | |||||
| np.float32) | |||||
| label = np.random.randint(0, 10, batches*batch_size).astype(np.int32) | |||||
| for i in range(batches): | |||||
| yield data[i*batch_size:(i + 1)*batch_size],\ | |||||
| label[i*batch_size:(i + 1)*batch_size] | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| @pytest.mark.component_mindarmour | |||||
| def test_fault_injector(): | |||||
| """ | |||||
| Feature: Fault injector | |||||
| Description: Test fault injector | |||||
| Expectation: Run kick_off and metrics successfully | |||||
| """ | |||||
| # load model | |||||
| ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||||
| net = Net() | |||||
| param_dict = load_checkpoint(ckpt_path) | |||||
| load_param_into_net(net, param_dict) | |||||
| model = Model(net) | |||||
| ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) | |||||
| fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', | |||||
| 'nan', 'inf', 'anti_activation', 'precision_loss'] | |||||
| fi_mode = ['single_layer', 'all_layer'] | |||||
| fi_size = [1] | |||||
| # Fault injection | |||||
| fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) | |||||
| _ = fi.kick_off() | |||||
| _ = fi.metrics() | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| @pytest.mark.component_mindarmour | |||||
| def test_wrong_model(): | |||||
| """ | |||||
| Feature: Fault injector | |||||
| Description: Test fault injector | |||||
| Expectation: Throw TypeError exception | |||||
| """ | |||||
| # load model | |||||
| ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||||
| net = Net() | |||||
| param_dict = load_checkpoint(ckpt_path) | |||||
| load_param_into_net(net, param_dict) | |||||
| ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) | |||||
| fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', | |||||
| 'nan', 'inf', 'anti_activation', 'precision_loss'] | |||||
| fi_mode = ['single_layer', 'all_layer'] | |||||
| fi_size = [1] | |||||
| # Fault injection | |||||
| with pytest.raises(TypeError) as exc_info: | |||||
| fi = FaultInjector(net, ds_eval, fi_type, fi_mode, fi_size) | |||||
| _ = fi.kick_off() | |||||
| _ = fi.metrics() | |||||
| assert exc_info.type is TypeError | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| @pytest.mark.component_mindarmour | |||||
| def test_wrong_data(): | |||||
| """ | |||||
| Feature: Fault injector | |||||
| Description: Test fault injector | |||||
| Expectation: Throw TypeError exception | |||||
| """ | |||||
| # load model | |||||
| ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||||
| net = Net() | |||||
| param_dict = load_checkpoint(ckpt_path) | |||||
| load_param_into_net(net, param_dict) | |||||
| model = Model(net) | |||||
| ds_eval = np.random.random((1000, 32, 32, 1)).astype(np.float32) | |||||
| fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', | |||||
| 'nan', 'inf', 'anti_activation', 'precision_loss'] | |||||
| fi_mode = ['single_layer', 'all_layer'] | |||||
| fi_size = [1] | |||||
| # Fault injection | |||||
| with pytest.raises(TypeError) as exc_info: | |||||
| fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) | |||||
| _ = fi.kick_off() | |||||
| _ = fi.metrics() | |||||
| assert exc_info.type is TypeError | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| @pytest.mark.component_mindarmour | |||||
| def test_wrong_fi_type(): | |||||
| """ | |||||
| Feature: Fault injector | |||||
| Description: Test fault injector | |||||
| Expectation: Throw AttributeError exception | |||||
| """ | |||||
| # load model | |||||
| ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||||
| net = Net() | |||||
| param_dict = load_checkpoint(ckpt_path) | |||||
| load_param_into_net(net, param_dict) | |||||
| model = Model(net) | |||||
| ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) | |||||
| fi_type = ['bitflips_random_haha', 'bitflips_designated', 'random', 'zeros', | |||||
| 'nan', 'inf', 'anti_activation', 'precision_loss'] | |||||
| fi_mode = ['single_layer', 'all_layer'] | |||||
| fi_size = [1] | |||||
| # Fault injection | |||||
| with pytest.raises(AttributeError) as exc_info: | |||||
| fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) | |||||
| _ = fi.kick_off() | |||||
| _ = fi.metrics() | |||||
| assert exc_info.type is AttributeError | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| @pytest.mark.component_mindarmour | |||||
| def test_wrong_fi_mode(): | |||||
| """ | |||||
| Feature: Fault injector | |||||
| Description: Test fault injector | |||||
| Expectation: Throw ValueError exception | |||||
| """ | |||||
| # load model | |||||
| ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||||
| net = Net() | |||||
| param_dict = load_checkpoint(ckpt_path) | |||||
| load_param_into_net(net, param_dict) | |||||
| model = Model(net) | |||||
| ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) | |||||
| fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', | |||||
| 'nan', 'inf', 'anti_activation', 'precision_loss'] | |||||
| fi_mode = ['single_layer_tail', 'all_layer'] | |||||
| fi_size = [1] | |||||
| # Fault injection | |||||
| with pytest.raises(ValueError) as exc_info: | |||||
| fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) | |||||
| _ = fi.kick_off() | |||||
| _ = fi.metrics() | |||||
| assert exc_info.type is ValueError | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| @pytest.mark.component_mindarmour | |||||
| def test_wrong_fi_size(): | |||||
| """ | |||||
| Feature: Fault injector | |||||
| Description: Test fault injector | |||||
| Expectation: Throw ValueError exception | |||||
| """ | |||||
| # load model | |||||
| ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||||
| net = Net() | |||||
| param_dict = load_checkpoint(ckpt_path) | |||||
| load_param_into_net(net, param_dict) | |||||
| model = Model(net) | |||||
| ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) | |||||
| fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', | |||||
| 'nan', 'inf', 'anti_activation', 'precision_loss'] | |||||
| fi_mode = ['single_layer', 'all_layer'] | |||||
| fi_size = [-1] | |||||
| # Fault injection | |||||
| with pytest.raises(ValueError) as exc_info: | |||||
| fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) | |||||
| _ = fi.kick_off() | |||||
| _ = fi.metrics() | |||||
| assert exc_info.type is ValueError | |||||