|
-
- import oneflow
- from kamal.core.metrics.stream_metrics import Metric
-
- __all__=['Accuracy']
-
class Accuracy(Metric):
    """Top-1 classification accuracy accumulated over a stream of batches.

    ``update`` compares the arg-max of ``outputs`` along dim 1 with
    ``targets`` and accumulates the number of matches and the number of
    target elements; ``get_results`` returns their ratio.
    """

    def __init__(self, attach_to=None):
        # ``attach_to`` is forwarded unchanged to the ``Metric`` base class.
        super(Accuracy, self).__init__(attach_to=attach_to)
        self.reset()

    @oneflow.no_grad()
    def update(self, outputs, targets):
        """Accumulate correct predictions for one batch.

        Args:
            outputs: prediction scores; the predicted class is the arg-max
                over dim 1 (presumably shape (N, C, ...) — TODO confirm).
            targets: ground-truth class indices, one per prediction.
        """
        outputs, targets = self._attach(outputs, targets)
        preds = outputs.max(1)[1]  # index of the max along dim 1
        # reshape (not view) so non-contiguous target tensors are accepted.
        self._correct += (preds.reshape(-1) == targets.reshape(-1)).sum()
        self._cnt += oneflow.numel(targets)

    def get_results(self):
        """Return the accumulated accuracy as a detached CPU tensor.

        Raises:
            ZeroDivisionError: if no target elements have been accumulated
                since the last ``reset``.
        """
        if not self._cnt:
            raise ZeroDivisionError(
                "Accuracy.get_results() called before any samples were accumulated"
            )
        return (self._correct / self._cnt).detach().cpu()

    def reset(self):
        """Clear the running correct-prediction and sample counters."""
        self._correct = self._cnt = 0.0
|