You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_text_classification_metrics.py 1.0 kB

1234567891011121314151617181920212223242526272829303132
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. import numpy as np
  4. from modelscope.metrics.sequence_classification_metric import \
  5. SequenceClassificationMetric
  6. from modelscope.utils.test_utils import test_level
  class TestTextClsMetrics(unittest.TestCase):
      """Unit tests for ``SequenceClassificationMetric``."""

      @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
      def test_value(self):
          """Check accuracy and F1 on a small fixed 3-class batch.

          The row-wise argmax of ``logits`` is [0, 1, 2, 2, 2, 0, 0, 0], which
          agrees with ``labels`` on 4 of the 8 samples, so the expected
          accuracy is 0.5 (assuming the metric predicts via argmax over the
          last axis -- consistent with the expected value, but confirm).
          """
          metric = SequenceClassificationMetric()
          outputs = {
              'logits':
              np.array([[2.0, 1.0, 0.5], [1.0, 1.5, 1.0], [2.0, 1.0, 3.0],
                        [2.4, 1.5, 4.0], [2.0, 1.0, 3.0], [2.4, 1.5, 1.7],
                        [2.0, 1.0, 0.5], [2.4, 1.5, 0.5]])
          }
          inputs = {'labels': np.array([0, 1, 2, 2, 0, 1, 2, 2])}
          metric.add(outputs, inputs)
          ret = metric.evaluate()
          # NOTE(review): an expected F1 of 0.5 matches micro-averaged F1, which
          # equals accuracy for single-label multiclass data; macro averaging
          # on this batch would give ~0.524. Presumably the metric uses micro
          # averaging -- confirm against SequenceClassificationMetric.
          # assertAlmostEqual reports both values on failure, unlike
          # assertTrue(np.isclose(...)) which only reports "False is not true".
          self.assertAlmostEqual(ret['f1'], 0.5, places=5)
          self.assertAlmostEqual(ret['accuracy'], 0.5, places=5)
  # Allow running this test module directly: `python test_text_classification_metrics.py`.
  if __name__ == '__main__':
      unittest.main()