You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensorflow_metric.py 1.8 kB

4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. import tensorflow as tf
  4. from tensorflow.keras.metrics import Metric
  5. __all__ = [
  6. 'Accuracy',
  7. 'Auc',
  8. 'Precision',
  9. 'Recall',
  10. ]
  11. class Accuracy(object):
  12. def __init__(self, topk=1):
  13. self.topk = topk
  14. if topk == 1:
  15. self.accuary = tf.keras.metrics.Accuracy()
  16. else:
  17. self.accuary = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=topk)
  18. def update(self, y_pred, y_true):
  19. if self.topk == 1:
  20. y_pred = tf.argmax(y_pred, axis=1)
  21. self.accuary.update_state(y_true, y_pred)
  22. else:
  23. self.accuary.update_state(y_true, y_pred)
  24. def result(self):
  25. return self.accuary.result()
  26. def reset(self):
  27. self.accuary.reset_states()
  28. class Auc(object):
  29. def __init__(
  30. self,
  31. curve='ROC',
  32. num_thresholds=200,
  33. ):
  34. self.auc = tf.keras.metrics.AUC(num_thresholds=num_thresholds, curve=curve)
  35. def update(self, y_pred, y_true):
  36. self.auc.update_state(y_true, y_pred)
  37. def result(self):
  38. return self.auc.result()
  39. def reset(self):
  40. self.auc.reset_states()
  41. class Precision(object):
  42. def __init__(self):
  43. self.precision = tf.keras.metrics.Precision()
  44. def update(self, y_pred, y_true):
  45. self.precision.update_state(y_true, y_pred)
  46. def result(self):
  47. return self.precision.result()
  48. def reset(self):
  49. self.precision.reset_states()
  50. class Recall(object):
  51. def __init__(self):
  52. self.recall = tf.keras.metrics.Recall()
  53. def update(self, y_pred, y_true):
  54. self.recall.update_state(y_true, y_pred)
  55. def result(self):
  56. return self.recall.result()
  57. def reset(self):
  58. self.recall.reset_states()

TensorLayer3.0 是一款以多种深度学习框架作为计算后端的深度学习库，计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。