@@ -0,0 +1,92 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """ | |||
| Fault injection example. | |||
| Download checkpoint from: https://www.mindspore.cn/resources/hub or just trained your own checkpoint. | |||
| Download dataset from: http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz. | |||
| File structure: | |||
| --cifar10-batches-bin | |||
| --train | |||
| --data_batch_1.bin | |||
| --data_batch_2.bin | |||
| --data_batch_3.bin | |||
| --data_batch_4.bin | |||
| --data_batch_5.bin | |||
| --test | |||
| --test_batch.bin | |||
| Please extract and restructure the file as shown above. | |||
| """ | |||
import argparse

from mindspore import Model, context
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector

from examples.common.networks.lenet5.lenet5_net import LeNet5
from examples.common.networks.vgg.vgg import vgg16
from examples.common.networks.resnet.resnet import resnet50
from examples.common.dataset.data_processing import create_dataset_cifar, generate_mnist_dataset

parser = argparse.ArgumentParser(description='layer_states')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'])
parser.add_argument('--model', type=str, default='lenet', choices=['lenet', 'resnet50', 'vgg16'])
parser.add_argument('--device_id', type=int, default=0)
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)

test_flag = args.model
if test_flag == 'lenet':
    # load data
    DATA_FILE = '../common/dataset/MNIST_Data/test'
    ckpt_path = '../common/networks/lenet5/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
    ds_eval = generate_mnist_dataset(DATA_FILE, batch_size=64)
    net = LeNet5()
elif test_flag == 'vgg16':
    from examples.common.networks.vgg.config import cifar_cfg as cfg
    DATA_FILE = '../common/dataset/cifar10-batches-bin'
    ckpt_path = '../common/networks/vgg16_ascend_v111_cifar10_offical_cv_bs64_acc93.ckpt'
    ds_eval = create_dataset_cifar(DATA_FILE, 224, 224, training=False)
    net = vgg16(10, cfg, 'test')
elif test_flag == 'resnet50':
    DATA_FILE = '../common/dataset/cifar10-batches-bin'
    ckpt_path = '../common/networks/resnet50_ascend_v111_cifar10_offical_cv_bs32_acc92.ckpt'
    ds_eval = create_dataset_cifar(DATA_FILE, 224, 224, training=False)
    net = resnet50(10)
else:
    exit()

param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)

# Initialization
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
           'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1, 2, 3]

# Fault injection
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size)
results = fi.kick_off()
result_summary = fi.metrics()

# print result
for result in results:
    print(result)
for result in result_summary:
    print(result)

@@ -0,0 +1,20 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Reliability methods of MindArmour.
"""
from .model_fault_injection.fault_injection import FaultInjector

__all__ = ['FaultInjector']

@@ -0,0 +1,169 @@
# Demos of model fault injection

## Introduction

This is a demo of fault injection for MindSpore applications written in Python.

## Preparation

For the demo, we should prepare both the datasets and the pre-trained models.

### Dataset

For example:

`MNIST`: Download the MNIST dataset from http://yann.lecun.com/exdb/mnist/ and extract it as follows:

```text
File structure:
    - data_path
        - train
            - train-images-idx3-ubyte
            - train-labels-idx1-ubyte
        - test
            - t10k-images-idx3-ubyte
            - t10k-labels-idx1-ubyte
```

`CIFAR10`: Download the CIFAR10 dataset from http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz and extract it as follows:

```text
File structure:
    - data_path
        - train
            - data_batch_1.bin
            - data_batch_2.bin
            - data_batch_3.bin
            - data_batch_4.bin
            - data_batch_5.bin
        - test
            - test_batch.bin
```

### Checkpoint file

Download a checkpoint from https://www.mindspore.cn/resources/hub or train your own checkpoint.
## Configuration

There are five parameters that need to be set up.

```python
DATA_FILE = '../common/dataset/MNIST_Data/test'
ckpt_path = '../common/networks/checkpoint_lenet_1-10_1875.ckpt'
...
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', 'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1, 2, 3]
```

`DATA_FILE` is the directory where you store the data.

`ckpt_path` is the path of the checkpoint file.

`fi_type`:

Eight types of faults can be injected: `bitflips_random`, `bitflips_designated`, `random`, `zeros`, `nan`, `inf`, `anti_activation` and `precision_loss`.

- `bitflips_random`: bits are flipped randomly in the chosen values.
- `bitflips_designated`: the specified bit is flipped in the chosen values.
- `random`: the chosen values are replaced with random values in the range [-1, 1].
- `zeros`: the chosen values are replaced with zero.
- `nan`: the chosen values are replaced with NaN.
- `inf`: the chosen values are replaced with Inf.
- `anti_activation`: the sign of the chosen values is flipped.
- `precision_loss`: the chosen values are rounded to 1 decimal place.

`fi_mode`:

Two injection modes can be specified, `single_layer` or `all_layer`.

`fi_size` is usually the exact number of values to be injected with the specified fault. For the `zeros`, `anti_activation` and `precision_loss` faults, `fi_size` is the percentage of total tensor values and varies from 0% to 100%. The sketch below illustrates how `fi_size` is interpreted.
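
For illustration only, here is a minimal sketch of that rule (the helper `injected_count` is made up for this example and is not part of the MindArmour API):

```python
import math

def injected_count(fi_type, fi_size, tensor_size):
    """Roughly how many values in one tensor receive the fault."""
    if fi_type in ('zeros', 'anti_activation', 'precision_loss'):
        # fi_size is read as a percentage of the tensor's values
        count = math.floor(fi_size * tensor_size / 100)
    else:
        # fi_size is read as an exact number of values
        count = fi_size
    return min(count, tensor_size)

print(injected_count('zeros', 2, 1000))            # 20 values (2% of 1000)
print(injected_count('bitflips_random', 2, 1000))  # 2 values
```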
### Example configuration

Sample 1:

```python
fi_type = ['bitflips_random', 'random', 'zeros', 'inf']
fi_mode = ['single_layer']
fi_size = [1]
```

Sample 2:

```python
fi_type = ['bitflips_designated', 'random', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1, 2]
```

Every combination of `fi_type`, `fi_mode` and `fi_size` is evaluated, so Sample 2 runs 5 × 2 × 2 = 20 fault injection experiments.

## Usage

Run the test to observe the fault injection. For example:

```bash
#!/bin/bash
cd examples/reliability/
python model_fault_injection.py --device_target GPU --device_id 2 --model lenet
```

`device_target` is the device to run on, chosen from `Ascend`, `GPU` and `CPU`.

`model` is the target model to be evaluated, chosen from `lenet`, `vgg16` and `resnet50`, or implement your own model as sketched below.
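
For your own model, the wiring is the same as in `model_fault_injection.py`. A minimal sketch, in which `MyNet`, `create_my_eval_dataset` and the checkpoint path are placeholders for your own network, evaluation dataset (with `image` and `label` columns) and checkpoint file:

```python
from mindspore import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector

net = MyNet()                                       # placeholder: your mindspore.nn.Cell
load_param_into_net(net, load_checkpoint('path/to/my_model.ckpt'))  # placeholder path
ds_eval = create_my_eval_dataset()                  # placeholder: any MindSpore dataset

fi = FaultInjector(Model(net), ds_eval,
                   fi_type=['bitflips_random', 'zeros'],
                   fi_mode=['single_layer'],
                   fi_size=[1])
results = fi.kick_off()        # per-configuration accuracy and SDC
result_summary = fi.metrics()  # per-mode summary strings
```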
## Result

Finally, three kinds of results will be returned.

Sample:

```text
original_acc:0.979768
type:bitflips_random mode:single_layer size:1 acc:0.968950 SDC:0.010817
type:bitflips_random mode:single_layer size:2 acc:0.948017 SDC:0.031751
...
type:precision_loss mode:all_layer size:2 acc:0.978966 SDC:0.000801
type:precision_loss mode:all_layer size:3 acc:0.979167 SDC:0.000601
single_layer_acc_mean:0.819732 single_layer_acc_max:0.980068 single_layer_acc_min:0.192107
single_layer_SDC_mean:0.160035 single_layer_SDC_max:0.787660 single_layer_SDC_min:-0.000300
all_layer_acc_mean:0.697049 all_layer_acc_max:0.979167 all_layer_acc_min:0.089443
all_layer_SDC_mean:0.282719 all_layer_SDC_max:0.890325 all_layer_SDC_min:0.000601
```

### Original_acc

The original accuracy of the model:

```text
original_acc:0.979768
```

### Specific result of each input parameter

Each result includes `type`, `mode`, `size`, `acc` and `SDC`. `type`, `mode` and `size` correspond to `fi_type`, `fi_mode` and `fi_size`. `acc` is the accuracy of the model after the fault is injected, and `SDC` is the resulting accuracy drop, i.e. `original_acc - acc`.

```text
type:bitflips_random mode:single_layer size:1 acc:0.968950 SDC:0.010817
type:bitflips_random mode:single_layer size:2 acc:0.948017 SDC:0.031751
...
type:precision_loss mode:all_layer size:2 acc:0.978966 SDC:0.000801
type:precision_loss mode:all_layer size:3 acc:0.979167 SDC:0.000601
```
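
Because `kick_off()` returns these results as plain dictionaries before they are printed, they can also be post-processed directly. A small sketch, assuming the key names listed above (`results` is the list returned by `fi.kick_off()`):

```python
# The first entry holds 'original_acc'; the remaining entries hold
# 'type', 'mode', 'size', 'acc' and 'SDC'.
fault_results = [r for r in results if 'SDC' in r]
worst = max(fault_results, key=lambda r: r['SDC'])
print('worst case: type=%s mode=%s size=%s acc=%f SDC=%f'
      % (worst['type'], worst['mode'], worst['size'], worst['acc'], worst['SDC']))
```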
### Summary of mode

Summary of `single_layer` or `all_layer`.

```text
single_layer_acc_mean:0.819732 single_layer_acc_max:0.980068 single_layer_acc_min:0.192107
single_layer_SDC_mean:0.160035 single_layer_SDC_max:0.787660 single_layer_SDC_min:-0.000300
all_layer_acc_mean:0.697049 all_layer_acc_max:0.979167 all_layer_acc_min:0.089443
all_layer_SDC_mean:0.282719 all_layer_SDC_max:0.890325 all_layer_SDC_min:0.000601
```

@@ -0,0 +1,18 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
This module provides model fault injection to evaluate the reliability of a given model.
"""
from .fault_injection import FaultInjector
from .fault_type import FaultType

__all__ = ['FaultInjector', 'FaultType']

@@ -0,0 +1,224 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Fault injection module
"""
import random

import numpy as np
import mindspore
from mindspore import ops, Tensor

from mindarmour.reliability.model_fault_injection.fault_type import FaultType
from mindarmour.utils.logger import LogUtil
from mindarmour.utils._check_param import check_int_positive

LOGGER = LogUtil.get_instance()
TAG = 'FaultInjector'


class FaultInjector:
    """
    Fault injection for deep neural networks and performance evaluation.

    Args:
        model (Model): The model to be evaluated.
        data (Dataset): The data for testing. The evaluation is based on this data.
        fi_type (list): The types of fault injection, which include bitflips_random (flip randomly),
            bitflips_designated (flip the key bit), random, zeros, nan, inf, anti_activation and
            precision_loss.
        fi_mode (list): The mode of fault injection, i.e. inject faults on just a single layer
            or on all layers.
        fi_size (list): The number of fault injections. It means how many values need to be injected.

    Examples:
        >>> net = Net()
        >>> model = Model(net)
        >>> ds_eval = create_dataloader()
        >>> fi_type = ['bitflips_random', 'zeros']
        >>> fi_mode = ['single_layer', 'all_layer']
        >>> fi_size = [1, 2, 3]
        >>> fi = FaultInjector(model, ds_eval, fi_type=fi_type, fi_mode=fi_mode, fi_size=fi_size)
        >>> fi.kick_off()
    """

    def __init__(self, model, data, fi_type=None, fi_mode=None, fi_size=None):
        """FaultInjector initiated."""
        self.running_list = []
        self._init_running_list(fi_type, fi_mode, fi_size)
        self.model = model
        self.data = data
        self._fault_type = FaultType()
        self._check_param()
        self.result_list = []
        self.original_acc = 0
        self.original_parameter = {}
        self.argmax = ops.Argmax()
        self._reducesum = ops.ReduceSum(keep_dims=False)
        self._frozen()

    def _check_param(self):
        """Check input parameters."""
        attr = self._fault_type.__dir__()
        if not isinstance(self.data, mindspore.dataset.Dataset):
            msg = "'Input data should be Mindspore Dataset', got {}.".format(type(self.data))
            LOGGER.error(TAG, msg)
            raise TypeError(msg)
        _ = check_int_positive('dataset_size', self.data.get_dataset_size())
        if not isinstance(self.model, mindspore.Model):
            msg = "'Input model should be Mindspore Model', got {}.".format(type(self.model))
            LOGGER.error(TAG, msg)
            raise TypeError(msg)
        for param in self.running_list:
            if param['fi_type'] not in attr:
                msg = "'Undefined fault type', got {}.".format(param['fi_type'])
                LOGGER.error(TAG, msg)
                raise AttributeError(msg)
            if param['fi_mode'] not in ['single_layer', 'all_layer']:
                msg = "'fault mode should be single_layer or all_layer', but got {}.".format(param['fi_mode'])
                LOGGER.error(TAG, msg)
                raise ValueError(msg)
            _ = check_int_positive('fi_size', param['fi_size'])

    def _init_running_list(self, type_, mode_, size_):
        """Initiate fault injection parameters of this evaluation."""
        if type_ is None:
            type_ = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', 'nan', 'inf',
                     'anti_activation', 'precision_loss']
        if mode_ is None:
            mode_ = ['single_layer', 'all_layer']
        if size_ is None:
            size_ = list(range(1, 4))
        for i in type_:
            i = i if i.startswith('_') else '_' + i
            for j in mode_:
                for k in size_:
                    dict_ = {'fi_type': i, 'fi_mode': j, 'fi_size': k}
                    self.running_list.append(dict_)

    def _frozen(self):
        """Store original parameters of model."""
        trainable_param = self.model.predict_network.trainable_params()
        for param in trainable_param:
            np_param = param.asnumpy().copy()
            bytes_ = np_param.tobytes()
            self.original_parameter[param.name] = {}
            self.original_parameter[param.name]['datatype'] = np_param.dtype
            self.original_parameter[param.name]['shape'] = np_param.shape
            self.original_parameter[param.name]['data'] = bytes_.hex()

    def _reset_model(self):
        """Reset model with original parameters."""
        for weight in self.model.predict_network.trainable_params():
            name = weight.name
            if name in self.original_parameter.keys():
                bytes_w = bytes.fromhex(self.original_parameter[name]['data'])
                datatype_w = self.original_parameter[name]['datatype']
                shape_w = self.original_parameter[name]['shape']
                np_w = np.frombuffer(bytes_w, dtype=datatype_w).reshape(shape_w)
                weight.assign_value(Tensor.from_numpy(np_w))
            else:
                msg = "Layer name not matched, got {}.".format(name)
                LOGGER.error(TAG, msg)
                raise KeyError(msg)

    def kick_off(self):
        """
        Startup and return final results.

        Returns:
            list, the result of fault injection.
        """
        result_list = []
        for i in range(-1, len(self.running_list)):
            arg = self.running_list[i]
            total = 0
            correct = 0
            for data in self.data.create_dict_iterator():
                batch = data['image']
                label = data['label']
                if i != -1:
                    self._reset_model()
                    self._layer_states(arg['fi_type'], arg['fi_mode'], arg['fi_size'])
                output = self.model.predict(batch)
                predict = self.argmax(output)
                mask = predict == label
                total += predict.size
                correct += self._reducesum(mask.astype(mindspore.float32)).asnumpy()
            acc = correct / total
            if i == -1:
                self.original_acc = acc
                result_list.append({'original_acc': self.original_acc})
            else:
                result_list.append({'type': arg['fi_type'], 'mode': arg['fi_mode'], 'size': arg['fi_size'],
                                    'acc': acc, 'SDC': self.original_acc - acc})
            self.data.reset()
        self._reset_model()
        self.result_list = result_list
        return result_list

    def metrics(self):
        """Metrics of final result."""
        result_summary = []
        single_layer_acc = []
        single_layer_sdc = []
        all_layer_acc = []
        all_layer_sdc = []
        for result in self.result_list:
            if 'mode' in result.keys():
                if result['mode'] == 'single_layer':
                    single_layer_acc.append(float(result['acc']))
                    single_layer_sdc.append(float(result['SDC']))
                else:
                    all_layer_acc.append(float(result['acc']))
                    all_layer_sdc.append(float(result['SDC']))
        s_acc = np.array(single_layer_acc)
        s_sdc = np.array(single_layer_sdc)
        a_acc = np.array(all_layer_acc)
        a_sdc = np.array(all_layer_sdc)
        if single_layer_acc:
            result_summary.append('single_layer_acc_mean:%f single_layer_acc_max:%f single_layer_acc_min:%f'
                                  % (np.mean(s_acc), np.max(s_acc), np.min(s_acc)))
            result_summary.append('single_layer_SDC_mean:%f single_layer_SDC_max:%f single_layer_SDC_min:%f'
                                  % (np.mean(s_sdc), np.max(s_sdc), np.min(s_sdc)))
        if all_layer_acc:
            result_summary.append('all_layer_acc_mean:%f all_layer_acc_max:%f all_layer_acc_min:%f'
                                  % (np.mean(a_acc), np.max(a_acc), np.min(a_acc)))
            result_summary.append('all_layer_SDC_mean:%f all_layer_SDC_max:%f all_layer_SDC_min:%f'
                                  % (np.mean(a_sdc), np.max(a_sdc), np.min(a_sdc)))
        return result_summary

    def _layer_states(self, fi_type, fi_mode, fi_size):
        """FI in layer states."""
        # Choose a random layer for injection
        if fi_mode == "single_layer":
            # Single layer fault injection mode
            random_num = [random.randint(0, len(self.model.predict_network.trainable_params()) - 1)]
        elif fi_mode == "all_layer":
            # Multiple layer fault injection mode
            random_num = list(range(len(self.model.predict_network.trainable_params()) - 1))
        else:
            msg = 'undefined fi_mode {}'.format(fi_mode)
            LOGGER.error(TAG, msg)
            raise ValueError(msg)
        for n in random_num:
            # Get layer states info
            w = self.model.predict_network.trainable_params()[n]
            w_np = w.asnumpy().copy()
            elem_shape = w_np.shape
            w_np = w_np.reshape(-1)
            # fault inject
            new_w_np = self._fault_type._fault_inject(w_np, fi_type, fi_size)
            # Reshape into original dimensions and store the faulty tensor
            new_w_np = np.reshape(new_w_np, elem_shape)
            w.set_data(Tensor.from_numpy(new_w_np))

@@ -0,0 +1,219 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Fault type module
"""
import math
import random
from struct import pack, unpack

import numpy as np

from mindarmour.utils.logger import LogUtil

LOGGER = LogUtil.get_instance()
TAG = 'FaultType'


class FaultType:
    """Implementation of specified fault type."""

    @staticmethod
    def _bitflip(value, pos):
        """
        Implementation of bit flip.

        Args:
            value (numpy.ndarray): Input data.
            pos (list): The index of the flip position.

        Returns:
            numpy.ndarray, data after the bit flip.
        """
        bits = str(value.dtype)[-2:] if str(value.dtype)[-2].isdigit() else str(value.dtype)[-1]
        value_format = 'B' * int(int(bits) / 8)
        value_bytes = value.tobytes()
        bytes_ = list(unpack(value_format, value_bytes))
        for p in pos:
            [q, r] = divmod(p, 8)
            bytes_[q] ^= 1 << r
        new_value_bytes = pack(value_format, *bytes_)
        new_value = np.frombuffer(new_value_bytes, value.dtype)
        return new_value[0]

    def _fault_inject(self, value, fi_type, fi_size):
        """
        Inject the specified fault into the randomly chosen values.

        For zeros, anti_activation and precision_loss, fi_size is the percentage of
        the total number of values. For the other faults, fi_size is the exact number
        of values to be injected.

        Args:
            value (numpy.ndarray): Input data.
            fi_type (str): Fault type.
            fi_size (int): The number of fault injections.

        Returns:
            numpy.ndarray, data after fault injection.
        """
        num = value.size
        if fi_type in ['zeros', 'anti_activation', 'precision_loss']:
            change_size = (fi_size * num) / 100
            change_size = math.floor(change_size)
        else:
            change_size = fi_size
            if change_size > num:
                change_size = num
        # Choose the indices for FI
        ind = random.sample(range(num), change_size)
        # got specified fault type
        try:
            func = getattr(self, fi_type)
            value = func(value, ind)
            return value
        except AttributeError:
            msg = "'Undefined fault type', got {}.".format(fi_type)
            LOGGER.error(TAG, msg)
            raise AttributeError(msg)

    def _bitflips_random(self, value, fi_indices):
        """
        Flip bit randomly for specified value.

        Args:
            value (numpy.ndarray): Input data.
            fi_indices (list): The index of injected data.

        Returns:
            numpy.ndarray, data after fault injection.
        """
        for item in fi_indices:
            val = value[item]
            pos = random.sample(range(int(str(val.dtype)[-2:])),
                                1 if np.random.random() < 0.618 else 2)
            val_new = self._bitflip(val, pos)
            value[item] = val_new
        return value

    def _bitflips_designated(self, value, fi_indices):
        """
        Flip the key bit for specified value.

        Args:
            value (numpy.ndarray): Input data.
            fi_indices (list): The index of injected data.

        Returns:
            numpy.ndarray, data after fault injection.
        """
        for item in fi_indices:
            val = value[item]
            # uint8 uint16 uint32 uint64
            bits = str(value.dtype)[-2:] if str(value.dtype)[-2].isdigit() else str(value.dtype)[-1]
            if 'uint' in str(val.dtype):
                pos = int(bits) - 1
            # int8 int16 int32 int64 float16 float32 float64
            else:
                pos = int(bits) - 2
            val_new = self._bitflip(val, [pos])
            value[item] = val_new
        return value

    @staticmethod
    def _random(value, fi_indices):
        """
        Reset the specified values randomly in the range [-1, 1].

        Args:
            value (numpy.ndarray): Input data.
            fi_indices (list): The index of injected data.

        Returns:
            numpy.ndarray, data after fault injection.
        """
        for item in fi_indices:
            value[item] = np.random.random() * 2 - 1
        return value

    @staticmethod
    def _zeros(value, fi_indices):
        """
        Set the specified values to zero.

        Args:
            value (numpy.ndarray): Input data.
            fi_indices (list): The index of injected data.

        Returns:
            numpy.ndarray, data after fault injection.
        """
        value[fi_indices] = 0.
        return value

    @staticmethod
    def _nan(value, fi_indices):
        """
        Set the specified values to NaN.

        Args:
            value (numpy.ndarray): Input data.
            fi_indices (list): The index of injected data.

        Returns:
            numpy.ndarray, data after fault injection.
        """
        try:
            value[fi_indices] = np.nan
            return value
        except ValueError:
            return value

    @staticmethod
    def _inf(value, fi_indices):
        """
        Set the specified values to Inf.

        Args:
            value (numpy.ndarray): Input data.
            fi_indices (list): The index of injected data.

        Returns:
            numpy.ndarray, data after fault injection.
        """
        try:
            value[fi_indices] = np.inf
            return value
        except OverflowError:
            return value

    @staticmethod
    def _anti_activation(value, fi_indices):
        """
        Flip the sign of the specified values.

        Args:
            value (numpy.ndarray): Input data.
            fi_indices (list): The index of injected data.

        Returns:
            numpy.ndarray, data after fault injection.
        """
        value[fi_indices] = -value[fi_indices]
        return value

    @staticmethod
    def _precision_loss(value, fi_indices):
        """
        Round the specified values to 1 decimal place.

        Args:
            value (numpy.ndarray): Input data.
            fi_indices (list): The index of injected data.

        Returns:
            numpy.ndarray, data after fault injection.
        """
        value[fi_indices] = np.around(value[fi_indices], decimals=1)
        return value

@@ -0,0 +1,240 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test for fault injection.
"""
import pytest
import numpy as np

from mindspore import Model
import mindspore.dataset as ds
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from mindarmour.utils.logger import LogUtil
from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector

from tests.ut.python.utils.mock_net import Net

LOGGER = LogUtil.get_instance()
TAG = 'Fault injection test'
LOGGER.set_level('INFO')


def dataset_generator():
    """mock training data."""
    batch_size = 32
    batches = 128
    data = np.random.random((batches*batch_size, 1, 32, 32)).astype(
        np.float32)
    label = np.random.randint(0, 10, batches*batch_size).astype(np.int32)
    for i in range(batches):
        yield data[i*batch_size:(i + 1)*batch_size],\
              label[i*batch_size:(i + 1)*batch_size]


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_fault_injector():
    """
    Feature: Fault injector
    Description: Test fault injector
    Expectation: Run kick_off and metrics successfully
    """
    # load model
    ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
    net = Net()
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(net, param_dict)
    model = Model(net)
    ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
    fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
               'nan', 'inf', 'anti_activation', 'precision_loss']
    fi_mode = ['single_layer', 'all_layer']
    fi_size = [1]
    # Fault injection
    fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size)
    _ = fi.kick_off()
    _ = fi.metrics()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_model():
    """
    Feature: Fault injector
    Description: Test fault injector
    Expectation: Throw TypeError exception
    """
    # load model
    ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
    net = Net()
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(net, param_dict)
    ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
    fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
               'nan', 'inf', 'anti_activation', 'precision_loss']
    fi_mode = ['single_layer', 'all_layer']
    fi_size = [1]
    # Fault injection
    with pytest.raises(TypeError) as exc_info:
        fi = FaultInjector(net, ds_eval, fi_type, fi_mode, fi_size)
        _ = fi.kick_off()
        _ = fi.metrics()
    assert exc_info.type is TypeError


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_data():
    """
    Feature: Fault injector
    Description: Test fault injector
    Expectation: Throw TypeError exception
    """
    # load model
    ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
    net = Net()
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(net, param_dict)
    model = Model(net)
    ds_eval = np.random.random((1000, 32, 32, 1)).astype(np.float32)
    fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
               'nan', 'inf', 'anti_activation', 'precision_loss']
    fi_mode = ['single_layer', 'all_layer']
    fi_size = [1]
    # Fault injection
    with pytest.raises(TypeError) as exc_info:
        fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size)
        _ = fi.kick_off()
        _ = fi.metrics()
    assert exc_info.type is TypeError


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_fi_type():
    """
    Feature: Fault injector
    Description: Test fault injector
    Expectation: Throw AttributeError exception
    """
    # load model
    ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
    net = Net()
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(net, param_dict)
    model = Model(net)
    ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
    fi_type = ['bitflips_random_haha', 'bitflips_designated', 'random', 'zeros',
               'nan', 'inf', 'anti_activation', 'precision_loss']
    fi_mode = ['single_layer', 'all_layer']
    fi_size = [1]
    # Fault injection
    with pytest.raises(AttributeError) as exc_info:
        fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size)
        _ = fi.kick_off()
        _ = fi.metrics()
    assert exc_info.type is AttributeError


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_fi_mode():
    """
    Feature: Fault injector
    Description: Test fault injector
    Expectation: Throw ValueError exception
    """
    # load model
    ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
    net = Net()
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(net, param_dict)
    model = Model(net)
    ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
    fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
               'nan', 'inf', 'anti_activation', 'precision_loss']
    fi_mode = ['single_layer_tail', 'all_layer']
    fi_size = [1]
    # Fault injection
    with pytest.raises(ValueError) as exc_info:
        fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size)
        _ = fi.kick_off()
        _ = fi.metrics()
    assert exc_info.type is ValueError


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_fi_size():
    """
    Feature: Fault injector
    Description: Test fault injector
    Expectation: Throw ValueError exception
    """
    # load model
    ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
    net = Net()
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(net, param_dict)
    model = Model(net)
    ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
    fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
               'nan', 'inf', 'anti_activation', 'precision_loss']
    fi_mode = ['single_layer', 'all_layer']
    fi_size = [-1]
    # Fault injection
    with pytest.raises(ValueError) as exc_info:
        fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size)
        _ = fi.kick_off()
        _ = fi.metrics()
    assert exc_info.type is ValueError