|
- #! /usr/bin/python
- # -*- coding: utf-8 -*-
-
- import mindspore.nn as nn
- from mindspore.nn.metrics._evaluation import EvaluationBase
- from mindspore.nn.metrics.metric import Metric
# Names exported by a star-import of this module.
__all__ = ["Accuracy", "Auc", "Precision", "Recall"]
-
-
class Accuracy(object):
    """Top-k categorical accuracy metric.

    Thin adapter that delegates to ``nn.TopKCategoricalAccuracy`` and
    exposes it through ``update`` / ``result`` / ``reset``.

    Args:
        topk: Forwarded to the wrapped metric as its ``k`` argument.
            Defaults to 1.
    """

    def __init__(self, topk=1):
        backend = nn.TopKCategoricalAccuracy(k=topk)
        # The wrapped metric is kept on ``self.accuracy``; every other
        # method simply forwards to it.
        self.accuracy = backend

    def update(self, y_pred, y_true):
        """Forward one ``(y_pred, y_true)`` pair to the wrapped metric."""
        backend = self.accuracy
        backend.update(y_pred, y_true)

    def result(self):
        """Return the value produced by the wrapped metric's ``eval()``."""
        value = self.accuracy.eval()
        return value

    def reset(self):
        """Call ``clear()`` on the wrapped metric."""
        backend = self.accuracy
        backend.clear()
-
-
class Auc(object):
    """Placeholder for an AUC metric that has not been implemented yet.

    The class exists so that the name ``Auc`` can be imported alongside the
    other metrics. Constructing it succeeds, but feeding it data does not.
    """

    def __init__(self):
        # Nothing to set up: there is no backing metric yet.
        pass

    def update(self, y_pred, y_true):
        """Always fail, because the metric is not implemented.

        Args:
            y_pred: Ignored.
            y_true: Ignored.

        Raises:
            NotImplementedError: Always. ``NotImplementedError`` is a
                subclass of ``Exception``, so callers that catch the
                broader type keep working.
        """
        raise NotImplementedError('Auc metric function not implemented')

    def result(self):
        """Return ``None``; there is no value to report."""
        return None

    def reset(self):
        """Do nothing; there is no state to clear."""
        return None
-
-
class Precision(object):
    """Precision metric.

    Thin adapter that delegates to ``nn.Precision`` constructed with
    ``eval_type="classification"`` and exposes it through
    ``update`` / ``result`` / ``reset``.
    """

    def __init__(self):
        backend = nn.Precision(eval_type="classification")
        # The wrapped metric is kept on ``self.precision``; every other
        # method simply forwards to it.
        self.precision = backend

    def update(self, y_pred, y_true):
        """Forward one ``(y_pred, y_true)`` pair to the wrapped metric."""
        backend = self.precision
        backend.update(y_pred, y_true)

    def result(self):
        """Return the value produced by the wrapped metric's ``eval()``."""
        value = self.precision.eval()
        return value

    def reset(self):
        """Call ``clear()`` on the wrapped metric."""
        backend = self.precision
        backend.clear()
-
-
class Recall(object):
    """Recall metric.

    Thin adapter that delegates to ``nn.Recall`` constructed with
    ``eval_type="classification"`` and exposes it through
    ``update`` / ``result`` / ``reset``.
    """

    def __init__(self):
        backend = nn.Recall(eval_type="classification")
        # The wrapped metric is kept on ``self.recall``; every other
        # method simply forwards to it.
        self.recall = backend

    def update(self, y_pred, y_true):
        """Forward one ``(y_pred, y_true)`` pair to the wrapped metric."""
        backend = self.recall
        backend.update(y_pred, y_true)

    def result(self):
        """Return the value produced by the wrapped metric's ``eval()``."""
        value = self.recall.eval()
        return value

    def reset(self):
        """Call ``clear()`` on the wrapped metric."""
        backend = self.recall
        backend.clear()
|