You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

paddle_metric.py 1.5 kB

4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. import paddle
  4. from paddle.metric.metrics import Metric
  5. __all__ = [
  6. 'Accuracy',
  7. 'Auc',
  8. 'Precision',
  9. 'Recall',
  10. ]
  11. class Accuracy(object):
  12. def __init__(
  13. self,
  14. topk=1,
  15. ):
  16. self.topk = topk
  17. self.accuracy = paddle.metric.Accuracy(topk=(self.topk, ))
  18. def update(self, y_pred, y_true):
  19. self.accuracy.update(self.accuracy.compute(y_pred, y_true))
  20. def result(self):
  21. return self.accuracy.accumulate()
  22. def reset(self):
  23. self.accuracy.reset()
  24. class Auc(object):
  25. def __init__(self, curve='ROC', num_thresholds=4095):
  26. self.auc = paddle.metric.Auc(curve=curve, num_thresholds=num_thresholds)
  27. def update(self, y_pred, y_true):
  28. self.auc.update(y_pred, y_true)
  29. def result(self):
  30. return self.auc.accumulate()
  31. def reset(self):
  32. self.auc.reset()
  33. class Precision(object):
  34. def __init__(self):
  35. self.precision = paddle.metric.Precision()
  36. def update(self, y_pred, y_true):
  37. self.precision.update(y_pred, y_true)
  38. def result(self):
  39. return self.precision.accumulate()
  40. def reset(self):
  41. self.precision.reset()
  42. class Recall(object):
  43. def __init__(self):
  44. self.recall = paddle.metric.Recall()
  45. def update(self, y_pred, y_true):
  46. self.recall.update(y_pred, y_true)
  47. def result(self):
  48. return self.recall.accumulate()
  49. def reset(self):
  50. self.recall.reset()

TensorLayer3.0 是一款可以多种深度学习框架作为计算后端的深度学习库，计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。