You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_metrics.py 471 B

1234567891011121314151617
  1. import unittest
  2. class TestOptim(unittest.TestCase):
  3. def test_AccuracyMetric(self):
  4. from fastNLP.core.metrics import AccuracyMetric
  5. import torch
  6. import numpy as np
  7. # (1) only input, targets passed
  8. output_dict = {"input": torch.zeros(4, 3)}
  9. target_dict = {'target': torch.zeros(4)}
  10. metric = AccuracyMetric()
  11. metric(output_dict=output_dict, target_dict=target_dict)
  12. print(metric.get_metric())