You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_text_generation.py 2.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from maas_hub.snapshot_download import snapshot_download
  4. from modelscope.models import Model
  5. from modelscope.models.nlp import PalmForTextGenerationModel
  6. from modelscope.pipelines import TextGenerationPipeline, pipeline
  7. from modelscope.preprocessors import TextGenerationPreprocessor
  8. from modelscope.utils.constant import Tasks
  9. from modelscope.utils.test_utils import test_level
  10. class TextGenerationTest(unittest.TestCase):
  11. model_id = 'damo/nlp_palm_text-generation_chinese'
  12. input1 = "今日天气类型='晴'&温度变化趋势='大幅上升'&最低气温='28℃'&最高气温='31℃'&体感='湿热'"
  13. input2 = "今日天气类型='多云'&体感='舒适'&最低气温='26℃'&最高气温='30℃'"
  14. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  15. def test_run(self):
  16. cache_path = snapshot_download(self.model_id)
  17. preprocessor = TextGenerationPreprocessor(
  18. cache_path, first_sequence='sentence', second_sequence=None)
  19. model = PalmForTextGenerationModel(
  20. cache_path, tokenizer=preprocessor.tokenizer)
  21. pipeline1 = TextGenerationPipeline(model, preprocessor)
  22. pipeline2 = pipeline(
  23. Tasks.text_generation, model=model, preprocessor=preprocessor)
  24. print(f'input: {self.input1}\npipeline1: {pipeline1(self.input1)}')
  25. print()
  26. print(f'input: {self.input2}\npipeline2: {pipeline2(self.input2)}')
  27. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  28. def test_run_with_model_from_modelhub(self):
  29. model = Model.from_pretrained(self.model_id)
  30. preprocessor = TextGenerationPreprocessor(
  31. model.model_dir, first_sequence='sentence', second_sequence=None)
  32. pipeline_ins = pipeline(
  33. task=Tasks.text_generation, model=model, preprocessor=preprocessor)
  34. print(pipeline_ins(self.input1))
  35. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  36. def test_run_with_model_name(self):
  37. pipeline_ins = pipeline(
  38. task=Tasks.text_generation, model=self.model_id)
  39. print(pipeline_ins(self.input2))
  40. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  41. def test_run_with_default_model(self):
  42. pipeline_ins = pipeline(task=Tasks.text_generation)
  43. print(pipeline_ins(self.input2))
  44. if __name__ == '__main__':
  45. unittest.main()

致力于通过开放的社区合作，开源AI模型以及相关创新技术，推动基于模型即服务的生态繁荣发展