You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_text_classification.py 1.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import os.path as osp
  4. import tempfile
  5. import unittest
  6. import zipfile
  7. from maas_lib.fileio import File
  8. from maas_lib.models.nlp import SequenceClassificationModel
  9. from maas_lib.pipelines import SequenceClassificationPipeline, pipeline
  10. from maas_lib.preprocessors import SequenceClassificationPreprocessor
  11. class SequenceClassificationTest(unittest.TestCase):
  12. def predict(self, pipeline: SequenceClassificationPipeline):
  13. from easynlp.appzoo import load_dataset
  14. set = load_dataset('glue', 'sst2')
  15. data = set['test']['sentence'][:3]
  16. results = pipeline(data[0])
  17. print(results)
  18. results = pipeline(data[1])
  19. print(results)
  20. print(data)
  21. def test_run(self):
  22. model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \
  23. '/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip'
  24. with tempfile.TemporaryDirectory() as tmp_dir:
  25. tmp_file = osp.join(tmp_dir, 'bert-base-sst2.zip')
  26. with open(tmp_file, 'wb') as ofile:
  27. ofile.write(File.read(model_url))
  28. with zipfile.ZipFile(tmp_file, 'r') as zipf:
  29. zipf.extractall(tmp_dir)
  30. path = osp.join(tmp_dir, 'bert-base-sst2')
  31. print(path)
  32. model = SequenceClassificationModel(path)
  33. preprocessor = SequenceClassificationPreprocessor(
  34. path, first_sequence='sentence', second_sequence=None)
  35. pipeline1 = SequenceClassificationPipeline(model, preprocessor)
  36. self.predict(pipeline1)
  37. pipeline2 = pipeline(
  38. 'text-classification', model=model, preprocessor=preprocessor)
  39. print(pipeline2('Hello world!'))
  40. if __name__ == '__main__':
  41. unittest.main()

致力于通过开放的社区合作,开源AI模型以及相关创新技术,推动基于模型即服务的生态繁荣发展