|
- import unittest
- from fastNLP.core.metrics.utils import func_post_proc
-
-
class Metric:
    """Minimal stand-in metric used to exercise ``func_post_proc``.

    Every method echoes its two arguments back as an ``(x, y)`` tuple, so a
    test can check exactly how a post-processing function reshapes the raw
    return value.
    """

    def accumulate(self, x, y):
        """Return ``(x, y)`` unchanged."""
        return x, y

    def compute(self, x, y):
        """Return ``(x, y)`` unchanged."""
        return x, y

    def update(self, x, y):
        """Return ``(x, y)`` unchanged.

        Tests in this module wrap ``update`` via
        ``func_post_proc(..., method_name='update')`` and then call
        ``metric.update(x=..., y=...)``, so the stub must provide it with the
        same echoing behaviour as ``accumulate``.
        """
        return x, y
-
-
class TestMetricUtil(unittest.TestCase):
    """Checks for ``func_post_proc`` from ``fastNLP.core.metrics.utils``."""

    def test_func_post_proc(self):
        """Stacked post-processors are applied on top of one another.

        Each ``func_post_proc`` call wraps whatever the previous call left in
        place, so the statements below must keep this exact order.
        """
        def tuple_to_named(out):
            # The Metric stub returns the raw tuple ``(x, y)``.
            return {'x': out[0], 'y': out[1]}

        def named_to_numbered(out):
            # Consumes the dict produced by an earlier ``tuple_to_named`` wrap.
            return {'1': out['x'], '2': out['y']}

        m = func_post_proc(Metric(), tuple_to_named, method_name='accumulate')
        self.assertDictEqual({'x': 1, 'y': 2}, m.accumulate(x=1, y=2))

        # Second wrap on the same method; the return value is deliberately unused.
        func_post_proc(m, named_to_numbered, method_name='accumulate')
        self.assertDictEqual({'1': 1, '2': 2}, m.accumulate(x=1, y=2))

        # NOTE(review): these checks need the Metric stub to expose ``update``;
        # the stub's visible methods may be only accumulate/compute — confirm.
        m = func_post_proc(m, tuple_to_named, method_name='update')
        self.assertDictEqual({'x': 1, 'y': 2}, m.update(x=1, y=2))

        func_post_proc(m, named_to_numbered, method_name='update')
        self.assertDictEqual({'1': 1, '2': 2}, m.update(x=1, y=2))

    def test_check_accumulate_post_special_local_variable(self):
        """A fresh metric lacks ``__wrapped_by_fn__``; wrapping one is expected to add it.

        Despite the test name, the method wrapped here is ``update``.
        """
        m = Metric()
        self.assertFalse(hasattr(m, '__wrapped_by_fn__'))
        m = func_post_proc(m, lambda res: {'x': res[0], 'y': res[1]}, method_name='update')
        # NOTE(review): assumes func_post_proc marks the metric object itself
        # (not only the wrapper function) with ``__wrapped_by_fn__`` — confirm.
        self.assertTrue(hasattr(m, '__wrapped_by_fn__'))
|