|
- #! /usr/bin/python
- # -*- coding: utf-8 -*-
-
- import tensorflow as tf
- from tensorflow.keras.metrics import Metric
-
- __all__ = [
- 'Accuracy',
- 'Auc',
- 'Precision',
- 'Recall',
- ]
-
-
class Accuracy(object):
    """Streaming classification accuracy backed by a Keras metric.

    With ``topk == 1`` the predicted class is the ``argmax`` of ``y_pred``
    over its last axis and is compared against ``y_true`` with
    ``tf.keras.metrics.Accuracy``. With ``topk > 1`` the Keras
    ``SparseTopKCategoricalAccuracy`` metric is used, which counts a sample
    as correct when its label is among the ``topk`` highest scores.

    NOTE(review): ``y_true`` is presumably integer class ids (one per sample)
    and ``y_pred`` per-class scores with classes on the last axis — confirm
    against callers.

    Args:
        topk: Number of top-scoring classes that count as a hit. Must be >= 1.

    Raises:
        ValueError: If ``topk`` is less than 1.
    """

    def __init__(self, topk=1):
        if topk < 1:
            raise ValueError('topk must be >= 1, got {}'.format(topk))
        self.topk = topk
        if topk == 1:
            self.accuracy = tf.keras.metrics.Accuracy()
        else:
            self.accuracy = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=topk)

    @property
    def accuary(self):
        """Deprecated misspelled alias of ``accuracy`` kept for backward compatibility."""
        return self.accuracy

    @accuary.setter
    def accuary(self, value):
        self.accuracy = value

    def update(self, y_pred, y_true, sample_weight=None):
        """Accumulate statistics for one batch.

        Args:
            y_pred: Per-class prediction scores, classes on the last axis.
            y_true: Ground-truth labels.
            sample_weight: Optional per-sample weights forwarded to the
                underlying Keras metric.
        """
        if self.topk == 1:
            # Reduce scores to predicted class ids. The last axis is what the
            # top-k metric ranks over, and equals axis 1 for 2-D input.
            y_pred = tf.argmax(y_pred, axis=-1)
        self.accuracy.update_state(y_true, y_pred, sample_weight=sample_weight)

    def result(self):
        """Return the accuracy accumulated since construction or the last reset."""
        return self.accuracy.result()

    def reset(self):
        """Clear all accumulated state."""
        # `reset_states` was deprecated in TF 2.5 in favour of `reset_state`
        # and does not exist in Keras 3; only fall back to it on old TF.
        reset_fn = getattr(self.accuracy, 'reset_state', None)
        if reset_fn is None:
            reset_fn = self.accuracy.reset_states
        reset_fn()
-
-
class Auc(object):
    """Streaming area under a curve backed by ``tf.keras.metrics.AUC``.

    Args:
        curve: Curve to integrate, forwarded to Keras (Keras accepts
            ``'ROC'`` or ``'PR'``).
        num_thresholds: Number of thresholds Keras uses to discretize the
            curve.
    """

    def __init__(
        self,
        curve='ROC',
        num_thresholds=200,
    ):
        self.auc = tf.keras.metrics.AUC(num_thresholds=num_thresholds, curve=curve)

    def update(self, y_pred, y_true, sample_weight=None):
        """Accumulate statistics for one batch.

        Args:
            y_pred: Predicted scores.
            y_true: Ground-truth labels.
            sample_weight: Optional per-sample weights forwarded to Keras.
        """
        self.auc.update_state(y_true, y_pred, sample_weight=sample_weight)

    def result(self):
        """Return the AUC accumulated since construction or the last reset."""
        return self.auc.result()

    def reset(self):
        """Clear all accumulated state."""
        # `reset_states` was deprecated in TF 2.5 in favour of `reset_state`
        # and does not exist in Keras 3; only fall back to it on old TF.
        reset_fn = getattr(self.auc, 'reset_state', None)
        if reset_fn is None:
            reset_fn = self.auc.reset_states
        reset_fn()
-
-
class Precision(object):
    """Streaming precision backed by ``tf.keras.metrics.Precision``.

    All constructor arguments are forwarded to Keras unchanged. The ``None``
    defaults match Keras's own defaults, so ``Precision()`` behaves exactly
    like the bare Keras metric.

    Args:
        thresholds: Optional decision threshold(s) applied to predictions.
        top_k: Optional; restrict the computation to the top-k predictions.
        class_id: Optional; restrict the computation to a single class index.
    """

    def __init__(self, thresholds=None, top_k=None, class_id=None):
        self.precision = tf.keras.metrics.Precision(
            thresholds=thresholds, top_k=top_k, class_id=class_id)

    def update(self, y_pred, y_true, sample_weight=None):
        """Accumulate statistics for one batch.

        Args:
            y_pred: Predicted scores.
            y_true: Ground-truth labels.
            sample_weight: Optional per-sample weights forwarded to Keras.
        """
        self.precision.update_state(y_true, y_pred, sample_weight=sample_weight)

    def result(self):
        """Return the precision accumulated since construction or the last reset."""
        return self.precision.result()

    def reset(self):
        """Clear all accumulated state."""
        # `reset_states` was deprecated in TF 2.5 in favour of `reset_state`
        # and does not exist in Keras 3; only fall back to it on old TF.
        reset_fn = getattr(self.precision, 'reset_state', None)
        if reset_fn is None:
            reset_fn = self.precision.reset_states
        reset_fn()
-
-
class Recall(object):
    """Streaming recall backed by ``tf.keras.metrics.Recall``.

    All constructor arguments are forwarded to Keras unchanged. The ``None``
    defaults match Keras's own defaults, so ``Recall()`` behaves exactly like
    the bare Keras metric.

    Args:
        thresholds: Optional decision threshold(s) applied to predictions.
        top_k: Optional; restrict the computation to the top-k predictions.
        class_id: Optional; restrict the computation to a single class index.
    """

    def __init__(self, thresholds=None, top_k=None, class_id=None):
        self.recall = tf.keras.metrics.Recall(
            thresholds=thresholds, top_k=top_k, class_id=class_id)

    def update(self, y_pred, y_true, sample_weight=None):
        """Accumulate statistics for one batch.

        Args:
            y_pred: Predicted scores.
            y_true: Ground-truth labels.
            sample_weight: Optional per-sample weights forwarded to Keras.
        """
        self.recall.update_state(y_true, y_pred, sample_weight=sample_weight)

    def result(self):
        """Return the recall accumulated since construction or the last reset."""
        return self.recall.result()

    def reset(self):
        """Clear all accumulated state."""
        # `reset_states` was deprecated in TF 2.5 in favour of `reset_state`
        # and does not exist in Keras 3; only fall back to it on old TF.
        reset_fn = getattr(self.recall, 'reset_state', None)
        if reset_fn is None:
            reset_fn = self.recall.reset_states
        reset_fn()
|