Browse Source

Pre Merge pull request !349 from ye12121/master

pull/349/MERGE
ye12121 Gitee 3 years ago
parent
commit
6330824ee5
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 230 additions and 109 deletions
  1. +104
    -103
      examples/reliability/model_fault_injection.py
  2. +42
    -0
      mindarmour/reliability/model_fault_injection/fault_injection.py
  3. +77
    -0
      mindarmour/reliability/model_fault_injection/metrics.py
  4. +7
    -6
      tests/ut/python/reliability/model_fault_injection/test_fault_injection.py

+ 104
- 103
examples/reliability/model_fault_injection.py View File

@@ -1,103 +1,104 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
Fault injection example.
Download checkpoint from: https://www.mindspore.cn/resources/hub or just trained your own checkpoint.
Download dataset from: http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz.
File structure:
--cifar10-batches-bin
--train
--data_batch_1.bin
--data_batch_2.bin
--data_batch_3.bin
--data_batch_4.bin
--data_batch_5.bin
--test
--test_batch.bin

Please extract and restructure the file as shown above.
"""
import argparse

import numpy as np
from mindspore import Model, context
from mindspore.train.serialization import load_checkpoint

from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector
from examples.common.networks.lenet5.lenet5_net import LeNet5
from examples.common.networks.vgg.vgg import vgg16
from examples.common.networks.resnet.resnet import resnet50
from examples.common.dataset.data_processing import create_dataset_cifar, generate_mnist_dataset


parser = argparse.ArgumentParser(description='layer_states')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'])
parser.add_argument('--model', type=str, default='lenet', choices=['lenet', 'resnet50', 'vgg16'])
parser.add_argument('--device_id', type=int, default=0)
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)


test_flag = args.model
if test_flag == 'lenet':
# load data
DATA_FILE = '../common/dataset/MNIST_Data/test'
ckpt_path = '../common/networks/lenet5/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
ds_eval = generate_mnist_dataset(DATA_FILE, batch_size=64)
net = LeNet5()
elif test_flag == 'vgg16':
from examples.common.networks.vgg.config import cifar_cfg as cfg
DATA_FILE = '../common/dataset/cifar10-batches-bin'
ckpt_path = '../common/networks/vgg16_ascend_v111_cifar10_offical_cv_bs64_acc93.ckpt'
ds_eval = create_dataset_cifar(DATA_FILE, 224, 224, training=False)
net = vgg16(10, cfg, 'test')
elif test_flag == 'resnet50':
DATA_FILE = '../common/dataset/cifar10-batches-bin'
ckpt_path = '../common/networks/resnet50_ascend_v111_cifar10_offical_cv_bs32_acc92.ckpt'
ds_eval = create_dataset_cifar(DATA_FILE, 224, 224, training=False)
net = resnet50(10)
else:
exit()

test_images = []
test_labels = []
for data in ds_eval.create_tuple_iterator(output_numpy=True):
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
ds_data = np.concatenate(test_images, axis=0)
ds_label = np.concatenate(test_labels, axis=0)

param_dict = load_checkpoint(ckpt_path, net=net)
model = Model(net)

# Initialization
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1, 2, 3]

# Fault injection
fi = FaultInjector(model, fi_type, fi_mode, fi_size)
results = fi.kick_off(ds_data, ds_label, iter_times=100)
result_summary = fi.metrics()

# print result
for result in results:
print(result)
for result in result_summary:
print(result)
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fault injection example.
Download checkpoint from: https://www.mindspore.cn/resources/hub or just trained your own checkpoint.
Download dataset from: <https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz>.
File structure:
--cifar10-batches-bin
--train
--data_batch_1.bin
--data_batch_2.bin
--data_batch_3.bin
--data_batch_4.bin
--data_batch_5.bin
--test
--test_batch.bin
Please extract and restructure the file as shown above.
"""
import argparse
import numpy as np
from mindspore import Model, context
from mindspore.train.serialization import load_checkpoint
from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector
from mindarmour.reliability.model_fault_injection.metrics import ClassifierMetrics
from examples.common.networks.lenet5.lenet5_net import LeNet5
from examples.common.networks.vgg.vgg import vgg16
from examples.common.networks.resnet.resnet import resnet50
from examples.common.dataset.data_processing import create_dataset_cifar, generate_mnist_dataset
parser = argparse.ArgumentParser(description='layer_states')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'])
parser.add_argument('--model', type=str, default='lenet', choices=['lenet', 'resnet50', 'vgg16'])
parser.add_argument('--device_id', type=int, default=0)
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
test_flag = args.model
if test_flag == 'lenet':
# load data
DATA_FILE = '../common/dataset/MNIST_Data/test'
ckpt_path = '../common/networks/lenet5/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
ds_eval = generate_mnist_dataset(DATA_FILE, batch_size=64)
net = LeNet5()
elif test_flag == 'vgg16':
from examples.common.networks.vgg.config import cifar_cfg as cfg
DATA_FILE = '../common/dataset/cifar10-batches-bin'
ckpt_path = '../common/networks/vgg16_ascend_v111_cifar10_offical_cv_bs64_acc93.ckpt'
ds_eval = create_dataset_cifar(DATA_FILE, 224, 224, training=False)
net = vgg16(10, cfg, 'test')
elif test_flag == 'resnet50':
DATA_FILE = '../common/dataset/cifar10-batches-bin'
ckpt_path = '../common/networks/resnet50_ascend_v111_cifar10_offical_cv_bs32_acc92.ckpt'
ds_eval = create_dataset_cifar(DATA_FILE, 224, 224, training=False)
net = resnet50(10)
else:
exit()
test_images = []
test_labels = []
for data in ds_eval.create_tuple_iterator(output_numpy=True):
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
ds_data = np.concatenate(test_images, axis=0)
ds_label = np.concatenate(test_labels, axis=0)
param_dict = load_checkpoint(ckpt_path, net=net)
model = Model(net)
# Initialization
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1, 2, 3]
# Fault injection
fi = FaultInjector(model, fi_type, fi_mode, fi_size)
results = fi.run(ds_data, ds_label, ClassifierMetrics, iter_times=100)
result_summary = fi.metrics()
# print result
for result in results:
print(result)
for result in result_summary:
print(result)

+ 42
- 0
mindarmour/reliability/model_fault_injection/fault_injection.py View File

@@ -195,6 +195,48 @@ class FaultInjector:
_ = check_param_type('ds_label', ds_label, np.ndarray)
_ = _check_array_not_empty('ds_label', ds_label)

def run(self, ds_data, ds_label, metrics_class, iter_times=100, **kwargs):
"""
Startup and return final results after Fault Injection.

Args:
ds_data(np.ndarray): Input data for testing. The evaluation is based on this data.
ds_label(np.ndarray): The label of data, corresponding to the data.
metrics_class(BaseMetrics): The specific class of metrics.
iter_times(int): The number of evaluations, which will determine the batch size.

Returns:
- list, the result of fault injection.
"""
self._check_kick_off_param(ds_data, ds_label, iter_times)
num = ds_data.shape[0]
idx_list = self._calculate_batch_size(num, iter_times)
result_list = []
for i in range(-1, len(self.running_list)):
metrics_object = metrics_class(**kwargs)
arg = self.running_list[i]
for idx in range(len(idx_list) - 1):
a = ds_data[idx_list[idx]:idx_list[idx + 1], ...]
batch = Tensor.from_numpy(a)
label = Tensor.from_numpy(ds_label[idx_list[idx]:idx_list[idx + 1], ...])
if label.ndim == 2:
label = self.argmax(label)
if i != -1:
self._reset_model()
self._layer_states(arg['fi_type'], arg['fi_mode'], arg['fi_size'])
output = self.model.predict(batch)
metrics_object.collect(output.asnumpy(), label.asnumpy())
acc = metrics_object.metrics()
if i == -1:
self.original_acc = acc
result_list.append({'original_acc': self.original_acc})
else:
result_list.append({'type': arg['fi_type'][1:], 'mode': arg['fi_mode'], 'size': arg['fi_size'],
'acc': acc, 'SDC': self.original_acc - acc})
self._reset_model()
self.result_list = result_list
return result_list

def kick_off(self, ds_data, ds_label, iter_times=100):
"""
Startup and return final results after Fault Injection.


+ 77
- 0
mindarmour/reliability/model_fault_injection/metrics.py View File

@@ -0,0 +1,77 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Metrics module
"""
from abc import abstractmethod
import numpy as np
class BaseMetrics:
"""The abstract base class for metrics classes"""
@abstractmethod
def collect(self, outputs, labels):
"""
Collect and process the outputs of the model each inference.
Args:
outputs (numpy.ndarray): The outputs of models.
labels (numpy.ndarray): The labels of data.
Raises:
NotImplementedError: It is an abstract method.
"""
msg = "The function collect() is an abstract method in class 'BaseMetrics', " \
"and should be implemented in child class"
raise NotImplementedError(msg)
@abstractmethod
def metrics(self):
"""
Metrics the final result.
Raises:
NotImplementedError: It is an abstract method.
"""
msg = "The function metrics() is an abstract method in class 'BaseMetrics', " \
"and should be implemented in child class"
raise NotImplementedError(msg)
class ClassifierMetrics(BaseMetrics):
"""Implementation of classifier metrics."""
def __init__(self):
"""Initiated."""
self.total_num = 0
self.correct_num = 0
self.acc = 0
def collect(self, outputs, labels):
"""
Collect and process the outputs of the model each inference.
Args:
outputs (numpy.ndarray): The outputs of models.
labels (numpy.ndarray): The labels of data.
"""
predict = np.argmax(outputs, axis=1)
mask = np.equal(predict, labels)
self.total_num += mask.shape[0]
self.correct_num += np.sum(mask)
def metrics(self):
"""
Metrics the final result.
Returns:
float : Accuracy.
"""
self.acc = self.correct_num / self.total_num
return self.acc

+ 7
- 6
tests/ut/python/reliability/model_fault_injection/test_fault_injection.py View File

@@ -26,6 +26,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindarmour.utils.logger import LogUtil
from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector
from mindarmour.reliability.model_fault_injection.metrics import ClassifierMetrics
from tests.ut.python.utils.mock_net import Net
@@ -84,7 +85,7 @@ def test_fault_injector():
# Fault injection
fi = FaultInjector(model, fi_type, fi_mode, fi_size)
_ = fi.kick_off(ds_data, ds_label, iter_times=100)
_ = fi.run(ds_data, ds_label, ClassifierMetrics, iter_times=100)
_ = fi.metrics()
@@ -126,7 +127,7 @@ def test_wrong_model():
# Fault injection
with pytest.raises(TypeError) as exc_info:
fi = FaultInjector(net, fi_type, fi_mode, fi_size)
_ = fi.kick_off(ds_data, ds_label, iter_times=100)
_ = fi.run(ds_data, ds_label, ClassifierMetrics, iter_times=100)
_ = fi.metrics()
assert exc_info.type is TypeError
@@ -162,7 +163,7 @@ def test_wrong_data():
# Fault injection
with pytest.raises(TypeError) as exc_info:
fi = FaultInjector(model, fi_type, fi_mode, fi_size)
_ = fi.kick_off(ds_data, ds_label, iter_times=100)
_ = fi.run(ds_data, ds_label, ClassifierMetrics, iter_times=100)
_ = fi.metrics()
assert exc_info.type is TypeError
@@ -206,7 +207,7 @@ def test_wrong_fi_type():
# Fault injection
with pytest.raises(ValueError) as exc_info:
fi = FaultInjector(model, fi_type, fi_mode, fi_size)
_ = fi.kick_off(ds_data, ds_label, iter_times=100)
_ = fi.run(ds_data, ds_label, ClassifierMetrics, iter_times=100)
_ = fi.metrics()
assert exc_info.type is ValueError
@@ -250,7 +251,7 @@ def test_wrong_fi_mode():
# Fault injection
with pytest.raises(ValueError) as exc_info:
fi = FaultInjector(model, fi_type, fi_mode, fi_size)
_ = fi.kick_off(ds_data, ds_label, iter_times=100)
_ = fi.run(ds_data, ds_label, ClassifierMetrics, iter_times=100)
_ = fi.metrics()
assert exc_info.type is ValueError
@@ -295,6 +296,6 @@ def test_wrong_fi_size():
# Fault injection
with pytest.raises(ValueError) as exc_info:
fi = FaultInjector(model, fi_type, fi_mode, fi_size)
_ = fi.kick_off(ds_data, ds_label, iter_times=100)
_ = fi.run(ds_data, ds_label, ClassifierMetrics, iter_times=100)
_ = fi.metrics()
assert exc_info.type is ValueError

Loading…
Cancel
Save