import oneflow

from kamal.core.metrics.stream_metrics import Metric

__all__ = ['Accuracy']


class Accuracy(Metric):
    """Streaming accuracy: running ratio of matching predictions to targets.

    Two running totals are kept between calls to ``reset``: the number of
    positions where the prediction equals the target (``_correct``) and the
    number of target elements seen (``_cnt``).
    """

    def __init__(self, attach_to=None):
        """Create the metric with zeroed counters.

        Args:
            attach_to: Forwarded unchanged to ``Metric.__init__``; its
                meaning is defined by the base class (not visible here).
        """
        super().__init__(attach_to=attach_to)
        self.reset()

    def reset(self):
        """Set both running totals back to the float ``0.0``."""
        self._correct = 0.0
        self._cnt = 0.0

    @oneflow.no_grad()
    def update(self, outputs, targets):
        """Fold one batch into the running totals.

        Args:
            outputs: Model outputs; after ``self._attach`` they are reduced
                with ``max(1)``, so at least two dimensions are required.
                Presumably per-class scores of shape (batch, classes, ...)
                -- confirm against callers.
            targets: Ground-truth labels; flattened and compared
                element-wise with the flattened predictions.
        """
        # `_attach` is provided by the Metric base class; presumably it picks
        # the relevant entries out of `outputs`/`targets` -- verify there.
        outputs, targets = self._attach(outputs, targets)

        # Element [1] of `max(1)` holds the indices of the maxima along dim 1.
        predicted = outputs.max(1)[1]

        matches = predicted.view(-1) == targets.view(-1)
        self._correct += matches.sum()
        self._cnt += oneflow.numel(targets)

    def get_results(self):
        """Return ``_correct / _cnt`` detached from the graph and moved to CPU.

        Raises:
            ZeroDivisionError: If called while both totals are still the
                plain floats set by ``reset`` (i.e. before any ``update``).
        """
        accuracy = self._correct / self._cnt
        return accuracy.detach().cpu()