You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_text_generation.py 3.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from modelscope.hub.snapshot_download import snapshot_download
  4. from modelscope.models import Model
  5. from modelscope.models.nlp import PalmForTextGeneration
  6. from modelscope.pipelines import pipeline
  7. from modelscope.pipelines.nlp import TextGenerationPipeline
  8. from modelscope.preprocessors import TextGenerationPreprocessor
  9. from modelscope.utils.constant import Tasks
  10. from modelscope.utils.test_utils import test_level
  11. class TextGenerationTest(unittest.TestCase):
  12. model_id_zh = 'damo/nlp_palm2.0_text-generation_chinese-base'
  13. model_id_en = 'damo/nlp_palm2.0_text-generation_english-base'
  14. input_zh = """
  15. 本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方:
  16. 1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代
  17. """
  18. input_en = """
  19. The Director of Public Prosecutions who let off Lord Janner over alleged child sex abuse started
  20. her career at a legal chambers when the disgraced Labour peer was a top QC there . Alison Saunders ,
  21. 54 , sparked outrage last week when she decided the 86-year-old should not face astring of charges
  22. of paedophilia against nine children because he has dementia . Today , newly-released documents
  23. revealed damning evidence that abuse was covered up by police andsocial workers for more than 20 years .
  24. And now it has emerged Mrs Saunders ' law career got off to a flying start when she secured her
  25. pupillage -- a barrister 's training contract at 1 Garden Court Chambers in London in 1983 .
  26. """
  27. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  28. def test_run(self):
  29. for model_id, input in ((self.model_id_zh, self.input_zh),
  30. (self.model_id_en, self.input_en)):
  31. cache_path = snapshot_download(model_id)
  32. model = PalmForTextGeneration(cache_path)
  33. preprocessor = TextGenerationPreprocessor(
  34. cache_path,
  35. model.tokenizer,
  36. first_sequence='sentence',
  37. second_sequence=None)
  38. pipeline1 = TextGenerationPipeline(model, preprocessor)
  39. pipeline2 = pipeline(
  40. Tasks.text_generation, model=model, preprocessor=preprocessor)
  41. print(
  42. f'pipeline1: {pipeline1(input)}\npipeline2: {pipeline2(input)}'
  43. )
  44. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  45. def test_run_with_model_from_modelhub(self):
  46. for model_id, input in ((self.model_id_zh, self.input_zh),
  47. (self.model_id_en, self.input_en)):
  48. model = Model.from_pretrained(model_id)
  49. preprocessor = TextGenerationPreprocessor(
  50. model.model_dir,
  51. model.tokenizer,
  52. first_sequence='sentence',
  53. second_sequence=None)
  54. pipeline_ins = pipeline(
  55. task=Tasks.text_generation,
  56. model=model,
  57. preprocessor=preprocessor)
  58. print(pipeline_ins(input))
  59. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  60. def test_run_with_model_name(self):
  61. for model_id, input in ((self.model_id_zh, self.input_zh),
  62. (self.model_id_en, self.input_en)):
  63. pipeline_ins = pipeline(task=Tasks.text_generation, model=model_id)
  64. print(pipeline_ins(input))
  65. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  66. def test_run_with_default_model(self):
  67. pipeline_ins = pipeline(task=Tasks.text_generation)
  68. print(pipeline_ins(self.input_zh))
# Run this module's test cases when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()