You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

mindspore_metric.py 1.4 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. import mindspore.nn as nn
  4. from mindspore.nn.metrics._evaluation import EvaluationBase
  5. from mindspore.nn.metrics.metric import Metric
  6. __all__ = [
  7. 'Accuracy',
  8. 'Auc',
  9. 'Precision',
  10. 'Recall',
  11. ]
  12. class Accuracy(object):
  13. def __init__(self, topk=1):
  14. self.accuracy = nn.TopKCategoricalAccuracy(k=topk)
  15. def update(self, y_pred, y_true):
  16. self.accuracy.update(y_pred, y_true)
  17. def result(self):
  18. return self.accuracy.eval()
  19. def reset(self):
  20. self.accuracy.clear()
  21. class Auc(object):
  22. def __init__(self):
  23. pass
  24. def update(self, y_pred, y_true):
  25. raise Exception('Auc metric function not implemented')
  26. def result(self):
  27. pass
  28. def reset(self):
  29. pass
  30. class Precision(object):
  31. def __init__(self):
  32. self.precision = nn.Precision(eval_type="classification")
  33. def update(self, y_pred, y_true):
  34. self.precision.update(y_pred, y_true)
  35. def result(self):
  36. return self.precision.eval()
  37. def reset(self):
  38. self.precision.clear()
  39. class Recall(object):
  40. def __init__(self):
  41. self.recall = nn.Recall(eval_type="classification")
  42. def update(self, y_pred, y_true):
  43. self.recall.update(y_pred, y_true)
  44. def result(self):
  45. return self.recall.eval()
  46. def reset(self):
  47. self.recall.clear()

TensorLayer3.0 是一款以多种深度学习框架作为计算后端的深度学习库，计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。