You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_export_sbert_sequence_classification.py 1.2 kB

12345678910111213141516171819202122232425262728293031323334353637
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import shutil
  4. import tempfile
  5. import unittest
  6. from modelscope.exporters import Exporter, TorchModelExporter
  7. from modelscope.models.base import Model
  8. from modelscope.utils.test_utils import test_level
  9. class TestExportSbertSequenceClassification(unittest.TestCase):
  10. def setUp(self):
  11. print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
  12. self.tmp_dir = tempfile.TemporaryDirectory().name
  13. if not os.path.exists(self.tmp_dir):
  14. os.makedirs(self.tmp_dir)
  15. self.model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
  16. def tearDown(self):
  17. shutil.rmtree(self.tmp_dir)
  18. super().tearDown()
  19. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  20. def test_export_sbert_sequence_classification(self):
  21. model = Model.from_pretrained(self.model_id)
  22. print(
  23. Exporter.from_model(model).export_onnx(
  24. shape=(2, 256), outputs=self.tmp_dir))
  25. print(
  26. TorchModelExporter.from_model(model).export_torch_script(
  27. shape=(2, 256), outputs=self.tmp_dir))
  28. if __name__ == '__main__':
  29. unittest.main()