|
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import unittest
-
- import torch
-
- from modelscope.outputs import TextClassificationModelOutput
- from modelscope.utils.test_utils import test_level
-
-
class TestModelOutput(unittest.TestCase):
    """Checks the access patterns offered by ``TextClassificationModelOutput``.

    The default ``setUp``/``tearDown`` hooks inherited from
    ``unittest.TestCase`` are used as-is; this test case needs no fixtures.
    """

    def _assert_tensor_equal(self, actual, expected):
        """Assert that two tensors have the same size and elements.

        ``assertEqual`` is avoided on purpose: it evaluates the truthiness of
        ``actual == expected``, which is an elementwise tensor and is only
        convertible to ``bool`` when it holds exactly one element.

        Args:
            actual: Tensor produced by the code under test.
            expected: Tensor holding the expected values.
        """
        self.assertTrue(
            torch.equal(actual, expected),
            msg=f'tensors differ: {actual} != {expected}')

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_model_outputs(self):
        """Exercise key, index and attribute access, then tuple unpacking."""
        expected_logits = torch.Tensor([1])
        outputs = TextClassificationModelOutput(logits=torch.Tensor([1]))

        # The same field must be reachable by key, by position and by name.
        self._assert_tensor_equal(outputs['logits'], expected_logits)
        self._assert_tensor_equal(outputs[0], expected_logits)
        self._assert_tensor_equal(outputs.logits, expected_logits)

        # After assigning ``loss``, the output is unpacked into two values.
        outputs.loss = torch.Tensor([2])
        logits, loss = outputs
        self._assert_tensor_equal(logits, expected_logits)
        self.assertIsNotNone(loss)
-
-
if __name__ == '__main__':
    # Executed as a script: discover and run the test cases in this module.
    unittest.main()
|