Browse Source

Add API Examples for Reliability Module

tags/v1.8.0
shu-kun-zhang 3 years ago
parent
commit
6ff9f986bb
3 changed files with 70 additions and 15 deletions
  1. +38
    -10
      mindarmour/reliability/concept_drift/concept_drift_check_images.py
  2. +2
    -0
      mindarmour/reliability/concept_drift/concept_drift_check_time_series.py
  3. +30
    -5
      mindarmour/reliability/model_fault_injection/fault_injection.py

+ 38
- 10
mindarmour/reliability/concept_drift/concept_drift_check_images.py View File

@@ -13,18 +13,16 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================


"""
Out-of-Distribution detection module for images.
"""


import heapq import heapq
import numpy as np import numpy as np
from mindspore import Tensor
from sklearn.cluster import KMeans from sklearn.cluster import KMeans
from mindarmour.utils._check_param import check_param_type, check_param_in_range
from mindspore import Tensor
from mindspore.train.summary.summary_record import _get_summary_tensor_data from mindspore.train.summary.summary_record import _get_summary_tensor_data


"""
Out-of-Distribution detection for images.
"""
from mindarmour.utils._check_param import check_param_type, check_param_in_range




class OodDetector: class OodDetector:
@@ -67,7 +65,7 @@ class OodDetector:
Returns: Returns:
- float, the optimal threshold. - float, the optimal threshold.
""" """
pass


def ood_predict(self, threshold, ds_test): def ood_predict(self, threshold, ds_test):
""" """
@@ -81,7 +79,6 @@ class OodDetector:
Returns: Returns:
- numpy.ndarray, the detection result. 0 means the data is not ood, 1 means the data is ood. - numpy.ndarray, the detection result. 0 means the data is not ood, 1 means the data is ood.
""" """
pass




class OodDetectorFeatureCluster(OodDetector): class OodDetectorFeatureCluster(OodDetector):
@@ -90,6 +87,8 @@ class OodDetectorFeatureCluster(OodDetector):
the testing data features and the clustering centers determines whether an image is an out-of-distribution(OOD) the testing data features and the clustering centers determines whether an image is an out-of-distribution(OOD)
image or not. image or not.


For details, please check `Tutorial <https://mindspore.cn/mindarmour/docs/zh-CN/master/concept_drift_images.html>`_

Args: Args:
model (Model):The training model. model (Model):The training model.
ds_train (numpy.ndarray): The training dataset. ds_train (numpy.ndarray): The training dataset.
@@ -100,9 +99,38 @@ class OodDetectorFeatureCluster(OodDetector):
layer (str): The name of the feature layer. layer (str) is represented by layer (str): The name of the feature layer. layer (str) is represented by
'name[:Tensor]', where 'name' is given by users when training the model. 'name[:Tensor]', where 'name' is given by users when training the model.
Please see more details about how to name the model layer in 'README.md'. Please see more details about how to name the model layer in 'README.md'.

Examples:
>>> from mindspore import Model
>>> from mindspore.ops import TensorSummary
>>> import mindspore.ops.operations as P
>>> from mindarmour.reliability.concept_drift.concept_drift_check_images import OodDetectorFeatureCluster
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._softmax = P.Softmax()
... self._Dense = nn.Dense(10,10)
... self._squeeze = P.Squeeze(1)
... self._summary = TensorSummary()
... def construct(self, inputs):
... out = self._softmax(inputs)
... out = self._Dense(out)
... self._summary('output', out)
... return self._squeeze(out)
>>> net = Net()
>>> model = Model(net)
>>> batch_size = 16
>>> batches = 1
>>> ds_train = np.random.randn(batches * batch_size, 1, 10).astype(np.float32)
>>> ds_eval = np.random.randn(batches * batch_size, 1, 10).astype(np.float32)
>>> detector = OodDetectorFeatureCluster(model, ds_train, n_cluster=10, layer='output[:Tensor]')
>>> num = int(len(ds_eval) / 2)
>>> ood_label = np.concatenate((np.zeros(num), np.ones(num)), axis=0)
>>> optimal_threshold = detector.get_optimal_threshold(ood_label, ds_eval)
""" """


def __init__(self, model, ds_train, n_cluster, layer): def __init__(self, model, ds_train, n_cluster, layer):
super(OodDetectorFeatureCluster, self).__init__(model, ds_train)
self.model = model self.model = model
self.ds_train = check_param_type('ds_train', ds_train, np.ndarray) self.ds_train = check_param_type('ds_train', ds_train, np.ndarray)
self.n_cluster = check_param_type('n_cluster', n_cluster, int) self.n_cluster = check_param_type('n_cluster', n_cluster, int)
@@ -173,7 +201,7 @@ class OodDetectorFeatureCluster(OodDetector):
threshold.append(threshold_change) threshold.append(threshold_change)
acc = np.array(acc) acc = np.array(acc)
threshold = np.array(threshold) threshold = np.array(threshold)
optimal_threshold = threshold[np.where(acc==np.max(acc))[0]][0]
optimal_threshold = threshold[np.where(acc == np.max(acc))[0]][0]
return optimal_threshold return optimal_threshold


def ood_predict(self, threshold, ds_test): def ood_predict(self, threshold, ds_test):


+ 2
- 0
mindarmour/reliability/concept_drift/concept_drift_check_time_series.py View File

@@ -25,6 +25,7 @@ from mindarmour.utils._check_param import check_param_type, check_param_in_range
class ConceptDriftCheckTimeSeries: class ConceptDriftCheckTimeSeries:
r""" r"""
ConceptDriftCheckTimeSeries is used for example series distribution change detection. ConceptDriftCheckTimeSeries is used for example series distribution change detection.
For details, please check `Tutorial <https://mindspore.cn/mindarmour/docs/zh-CN/master/concept_drift_time_series.html>`_


Args: Args:
window_size(int): Size of a concept window, no less than 10. If given the input data, window_size(int): Size of a concept window, no less than 10. If given the input data,
@@ -38,6 +39,7 @@ class ConceptDriftCheckTimeSeries:
Default: False. Default: False.


Examples: Examples:
>>> from mindarmour import ConceptDriftCheckTimeSeries
>>> concept = ConceptDriftCheckTimeSeries(window_size=100, rolling_window=10, >>> concept = ConceptDriftCheckTimeSeries(window_size=100, rolling_window=10,
... step=10, threshold_index=1.5, need_label=False) ... step=10, threshold_index=1.5, need_label=False)
>>> data_example = 5*np.random.rand(1000) >>> data_example = 5*np.random.rand(1000)


+ 30
- 5
mindarmour/reliability/model_fault_injection/fault_injection.py View File

@@ -31,6 +31,7 @@ TAG = 'FaultInjector'
class FaultInjector: class FaultInjector:
""" """
Fault injection for deep neural networks and evaluate performance. Fault injection for deep neural networks and evaluate performance.
For details, please check `Tutorial <https://mindspore.cn/mindarmour/docs/zh-CN/master/fault_injection.html>`_


Args: Args:
model (Model): The model need to be evaluated. model (Model): The model need to be evaluated.
@@ -40,14 +41,38 @@ class FaultInjector:
fi_size (list): The number of fault injection.It mean that how many values need to be injected. fi_size (list): The number of fault injection.It mean that how many values need to be injected.


Examples: Examples:
>>> from mindspore import Model
>>> import mindspore.dataset as ds
>>> import mindspore.ops.operations as P
>>> from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self._softmax = P.Softmax()
... self._Dense = nn.Dense(10,10)
... self._squeeze = P.Squeeze(1)
... def construct(self, inputs):
... out = self._softmax(inputs)
... out = self._Dense(out)
... return self._squeeze(out)
>>> def dataset_generator():
... batch_size = 16
... batches = 1
... data = np.random.randn(batches * batch_size,1,10).astype(np.float32)
... label = np.random.randint(0,10, batches * batch_size).astype(np.int32)
... for i in range(batches):
... yield data[i*batch_size:(i+1)*batch_size],\
... label[i*batch_size:(i+1)*batch_size]
>>> net = Net() >>> net = Net()
>>> model = Model(net) >>> model = Model(net)
>>> ds_data, ds_label = create_data()
>>> fi_type = ['bitflips_random', 'zeros']
>>> ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
>>> fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
>>> fi_mode = ['single_layer', 'all_layer'] >>> fi_mode = ['single_layer', 'all_layer']
>>> fi_size = [1, 2]
>>> fi = FaultInjector(model, fi_type=fi_type, fi_mode=fi_mode, fi_size=fi_size)
>>> fi.kick_off(ds_data, ds_label)
>>> fi_size = [1]
>>> fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size)
>>> fi.kick_off()
>>> fi.metrics()
""" """


def __init__(self, model, fi_type=None, fi_mode=None, fi_size=None): def __init__(self, model, fi_type=None, fi_mode=None, fi_size=None):


Loading…
Cancel
Save