test_finetune_gpt3.py

# Copyright (c) Alibaba, Inc. and its affiliates.
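# Finetuning tests for the ModelScope GPT-3 text-generation trainer.
# Both tests are skipped by default, since each downloads and finetunes
# a 1.3B-parameter model.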

import os
import shutil
import tempfile
import unittest

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer


class TestFinetuneTextGeneration(unittest.TestCase):
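
    # Each test runs in its own scratch directory. The path returned by
    # tempfile.TemporaryDirectory() may already be cleaned up once the
    # object is garbage-collected, hence the explicit makedirs below;
    # tearDown removes the directory again.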
    def setUp(self):
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def tearDown(self):
        shutil.rmtree(self.tmp_dir)
        super().tearDown()

    @unittest.skip('finetuning test; skipped by default')
    def test_finetune_poetry(self):
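        # The poetry dataset stores its text in a single 'text1' column;
        # the trainer expects the source text under 'src_txt'.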
        dataset_dict = MsDataset.load('chinese-poetry-collection')
        train_dataset = dataset_dict['train'].to_hf_dataset().rename_columns(
            {'text1': 'src_txt'})
        eval_dataset = dataset_dict['test'].to_hf_dataset().rename_columns(
            {'text1': 'src_txt'})

        max_epochs = 10
        tmp_dir = './gpt3_poetry'
        num_warmup_steps = 100
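
        # Noam-style schedule: linear warmup over num_warmup_steps,
        # then inverse-square-root decay.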
        def noam_lambda(current_step: int):
            current_step += 1
            return min(current_step**(-0.5),
                       current_step * num_warmup_steps**(-1.5))
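
        # cfg_modify_fn overrides the model's default training
        # configuration before the trainer is built.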
        def cfg_modify_fn(cfg):
            cfg.train.lr_scheduler = {
                'type': 'LambdaLR',
                'lr_lambda': noam_lambda,
                'options': {
                    'by_epoch': False
                }
            }
            cfg.train.optimizer = {'type': 'AdamW', 'lr': 3e-4}
            cfg.train.dataloader = {
                'batch_size_per_gpu': 16,
                'workers_per_gpu': 1
            }
            return cfg
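
        # Everything below is forwarded to the trainer; cfg_modify_fn is
        # applied to the model configuration when the trainer is created.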
        kwargs = dict(
            model='damo/nlp_gpt3_text-generation_1.3B',
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            max_epochs=max_epochs,
            work_dir=tmp_dir,
            cfg_modify_fn=cfg_modify_fn)

        # Construct trainer and train
        trainer = build_trainer(
            name=Trainers.gpt3_trainer, default_args=kwargs)
        trainer.train()

    @unittest.skip('finetuning test; skipped by default')
    def test_finetune_dureader(self):
        # DuReader_robust-QG is an example dataset;
        # users can also use their own datasets for training.
        dataset_dict = MsDataset.load('DuReader_robust-QG')
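        # Rename the two text columns to the names the trainer expects and
        # replace the '[SEP]' marker in the source text with '<sep>',
        # appending a trailing newline.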
        train_dataset = dataset_dict['train'].to_hf_dataset() \
            .rename_columns({'text1': 'src_txt', 'text2': 'tgt_txt'}) \
            .map(lambda example: {'src_txt': example['src_txt'].replace('[SEP]', '<sep>') + '\n'})
        eval_dataset = dataset_dict['validation'].to_hf_dataset() \
            .rename_columns({'text1': 'src_txt', 'text2': 'tgt_txt'}) \
            .map(lambda example: {'src_txt': example['src_txt'].replace('[SEP]', '<sep>') + '\n'})

        max_epochs = 10
        tmp_dir = './gpt3_dureader'
        num_warmup_steps = 200
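
        # Same Noam schedule as above, with a longer warmup.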
        def noam_lambda(current_step: int):
            current_step += 1
            return min(current_step**(-0.5),
                       current_step * num_warmup_steps**(-1.5))

        def cfg_modify_fn(cfg):
            cfg.train.lr_scheduler = {
                'type': 'LambdaLR',
                'lr_lambda': noam_lambda,
                'options': {
                    'by_epoch': False
                }
            }
            cfg.train.optimizer = {'type': 'AdamW', 'lr': 3e-4}
            cfg.train.dataloader = {
                'batch_size_per_gpu': 16,
                'workers_per_gpu': 1
            }
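            # Evaluate on eval_dataset once per epoch.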
            cfg.train.hooks.append({
                'type': 'EvaluationHook',
                'by_epoch': True,
                'interval': 1
            })
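            # Limit inputs to 512 tokens; checkpoint_model_parallel_size
            # presumably matches the model-parallel layout the checkpoint
            # was saved with.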
            cfg.preprocessor.sequence_length = 512
            cfg.model.checkpoint_model_parallel_size = 1
            return cfg

        kwargs = dict(
            model='damo/nlp_gpt3_text-generation_1.3B',
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            max_epochs=max_epochs,
            work_dir=tmp_dir,
            cfg_modify_fn=cfg_modify_fn)

        # Construct trainer and train
        trainer = build_trainer(
            name=Trainers.gpt3_trainer, default_args=kwargs)
        trainer.train()


if __name__ == '__main__':
    unittest.main()