You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_gpt3_text_generation.py 2.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from modelscope.hub.snapshot_download import snapshot_download
  4. from modelscope.pipelines import pipeline
  5. from modelscope.utils.constant import Tasks
  6. from modelscope.utils.test_utils import test_level
  7. class TextGPT3GenerationTest(unittest.TestCase):
  8. def setUp(self) -> None:
  9. # please make sure this local path exists.
  10. self.model_id_1_3B = 'damo/nlp_gpt3_text-generation_1.3B'
  11. self.model_id_2_7B = 'damo/nlp_gpt3_text-generation_2.7B'
  12. self.model_id_13B = 'damo/nlp_gpt3_text-generation_13B'
  13. self.model_dir_13B = snapshot_download(self.model_id_13B)
  14. self.input = '好的'
  15. @unittest.skip('distributed gpt3 1.3B, skipped')
  16. def test_gpt3_1_3B(self):
  17. pipe = pipeline(Tasks.text_generation, model=self.model_id_1_3B)
  18. print(pipe(self.input))
  19. @unittest.skip('distributed gpt3 2.7B, skipped')
  20. def test_gpt3_2_7B(self):
  21. pipe = pipeline(Tasks.text_generation, model=self.model_id_2_7B)
  22. print(pipe(self.input))
  23. @unittest.skip('distributed gpt3 13B, skipped')
  24. def test_gpt3_13B(self):
  25. """ The model can be downloaded from the link on
  26. TODO: add gpt3 checkpoint link
  27. After downloading, you should have a gpt3 model structure like this:
  28. nlp_gpt3_text-generation_13B
  29. |_ config.json
  30. |_ configuration.json
  31. |_ tokenizer.json
  32. |_ model <-- an empty directory
  33. Model binaries shall be downloaded separately to populate the model directory, so that
  34. the model directory would contain the following binaries:
  35. |_ model
  36. |_ mp_rank_00_model_states.pt
  37. |_ mp_rank_01_model_states.pt
  38. |_ mp_rank_02_model_states.pt
  39. |_ mp_rank_03_model_states.pt
  40. |_ mp_rank_04_model_states.pt
  41. |_ mp_rank_05_model_states.pt
  42. |_ mp_rank_06_model_states.pt
  43. |_ mp_rank_07_model_states.pt
  44. """
  45. pipe = pipeline(Tasks.text_generation, model=self.model_dir_13B)
  46. print(pipe(self.input))
  47. if __name__ == '__main__':
  48. unittest.main()