|
- from fastNLP.core.metrics.metric import Metric
-
- from collections import defaultdict
- from functools import partial
-
- import unittest
-
-
class MyMetric(Metric):
    """Test-only metric whose ``tp`` mapping lazily registers elements.

    ``tp`` is a ``defaultdict``: the first access to a missing key calls
    ``self.register_element(aggregate_method='sum')`` and stores the result
    under that key.
    """

    def __init__(self, backend='auto',
                 aggregate_when_get_metric: bool = False):
        super().__init__(backend=backend,
                         aggregate_when_get_metric=aggregate_when_get_metric)
        # Factory invoked (with no arguments) by defaultdict for unseen keys.
        # NOTE(review): register_element receives no positional name here;
        # confirm the base-class signature permits that.
        make_sum_element = partial(self.register_element, aggregate_method='sum')
        self.tp = defaultdict(make_sum_element)

    def update(self, item):
        """In-place add ``item`` to the element stored under key ``'1'``."""
        slot = '1'
        self.tp[slot] += item
-
-
class TestMetric(unittest.TestCase):
    """Smoke tests for ``MyMetric``'s lazily registered ``tp`` elements."""

    def test_va1(self):
        """A single ``update`` should populate exactly the ``'1'`` entry of ``tp``."""
        my = MyMetric()
        my.update(1)
        # Previously this test only printed the element, so it could never
        # fail on a wrong result; assert on the observable state instead.
        self.assertIn('1', my.tp)
        # update() touches only key '1', so no other elements should exist.
        self.assertEqual(len(my.tp), 1)
|