You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_text_generation.py 3.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from maas_hub.snapshot_download import snapshot_download
  4. from modelscope.models import Model
  5. from modelscope.models.nlp import PalmForTextGeneration
  6. from modelscope.pipelines import TextGenerationPipeline, pipeline
  7. from modelscope.preprocessors import TextGenerationPreprocessor
  8. from modelscope.utils.constant import Tasks
  9. from modelscope.utils.test_utils import test_level
  10. class TextGenerationTest(unittest.TestCase):
  11. model_id_zh = 'damo/nlp_palm2.0_text-generation_chinese-base'
  12. model_id_en = 'damo/nlp_palm2.0_text-generation_english-base'
  13. input_zh = """
  14. 本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方:
  15. 1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代
  16. """
  17. input_en = """
  18. The Director of Public Prosecutions who let off Lord Janner over alleged child sex abuse started
  19. her career at a legal chambers when the disgraced Labour peer was a top QC there . Alison Saunders ,
  20. 54 , sparked outrage last week when she decided the 86-year-old should not face astring of charges
  21. of paedophilia against nine children because he has dementia . Today , newly-released documents
  22. revealed damning evidence that abuse was covered up by police andsocial workers for more than 20 years .
  23. And now it has emerged Mrs Saunders ' law career got off to a flying start when she secured her
  24. pupillage -- a barrister 's training contract at 1 Garden Court Chambers in London in 1983 .
  25. """
  26. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  27. def test_run(self):
  28. for model_id, input in ((self.model_id_zh, self.input_zh),
  29. (self.model_id_en, self.input_en)):
  30. cache_path = snapshot_download(model_id)
  31. model = PalmForTextGeneration(cache_path)
  32. preprocessor = TextGenerationPreprocessor(
  33. cache_path,
  34. model.tokenizer,
  35. first_sequence='sentence',
  36. second_sequence=None)
  37. pipeline1 = TextGenerationPipeline(model, preprocessor)
  38. pipeline2 = pipeline(
  39. Tasks.text_generation, model=model, preprocessor=preprocessor)
  40. print(
  41. f'pipeline1: {pipeline1(input)}\npipeline2: {pipeline2(input)}'
  42. )
  43. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  44. def test_run_with_model_from_modelhub(self):
  45. for model_id, input in ((self.model_id_zh, self.input_zh),
  46. (self.model_id_en, self.input_en)):
  47. model = Model.from_pretrained(model_id)
  48. preprocessor = TextGenerationPreprocessor(
  49. model.model_dir,
  50. model.tokenizer,
  51. first_sequence='sentence',
  52. second_sequence=None)
  53. pipeline_ins = pipeline(
  54. task=Tasks.text_generation,
  55. model=model,
  56. preprocessor=preprocessor)
  57. print(pipeline_ins(input))
  58. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  59. def test_run_with_model_name(self):
  60. for model_id, input in ((self.model_id_zh, self.input_zh),
  61. (self.model_id_en, self.input_en)):
  62. pipeline_ins = pipeline(task=Tasks.text_generation, model=model_id)
  63. print(pipeline_ins(input))
  64. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  65. def test_run_with_default_model(self):
  66. pipeline_ins = pipeline(task=Tasks.text_generation)
  67. print(pipeline_ins(self.input_zh))
  68. if __name__ == '__main__':
  69. unittest.main()