You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_finetune_text_ranking.py 7.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import shutil
  4. import tempfile
  5. import unittest
  6. from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
  7. import torch
  8. from transformers.tokenization_utils_base import PreTrainedTokenizerBase
  9. from modelscope.metainfo import Trainers
  10. from modelscope.models import Model
  11. from modelscope.msdatasets import MsDataset
  12. from modelscope.pipelines import pipeline
  13. from modelscope.trainers import build_trainer
  14. from modelscope.utils.constant import ModelFile, Tasks
  15. from modelscope.utils.test_utils import test_level
  16. class TestFinetuneSequenceClassification(unittest.TestCase):
  17. inputs = {
  18. 'source_sentence': ["how long it take to get a master's degree"],
  19. 'sentences_to_compare': [
  20. "On average, students take about 18 to 24 months to complete a master's degree.",
  21. 'On the other hand, some students prefer to go at a slower pace and choose to take '
  22. 'several years to complete their studies.',
  23. 'It can take anywhere from two semesters'
  24. ]
  25. }
  26. def setUp(self):
  27. print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
  28. self.tmp_dir = tempfile.TemporaryDirectory().name
  29. if not os.path.exists(self.tmp_dir):
  30. os.makedirs(self.tmp_dir)
  31. def tearDown(self):
  32. shutil.rmtree(self.tmp_dir)
  33. super().tearDown()
  34. def finetune(self,
  35. model_id,
  36. train_dataset,
  37. eval_dataset,
  38. name=Trainers.nlp_text_ranking_trainer,
  39. cfg_modify_fn=None,
  40. **kwargs):
  41. kwargs = dict(
  42. model=model_id,
  43. train_dataset=train_dataset,
  44. eval_dataset=eval_dataset,
  45. work_dir=self.tmp_dir,
  46. cfg_modify_fn=cfg_modify_fn,
  47. **kwargs)
  48. os.environ['LOCAL_RANK'] = '0'
  49. trainer = build_trainer(name=name, default_args=kwargs)
  50. trainer.train()
  51. results_files = os.listdir(self.tmp_dir)
  52. self.assertIn(f'{trainer.timestamp}.log.json', results_files)
  53. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  54. def test_finetune_msmarco(self):
  55. def cfg_modify_fn(cfg):
  56. neg_sample = 4
  57. cfg.task = 'text-ranking'
  58. cfg['preprocessor'] = {'type': 'text-ranking'}
  59. cfg.train.optimizer.lr = 2e-5
  60. cfg['dataset'] = {
  61. 'train': {
  62. 'type': 'bert',
  63. 'query_sequence': 'query',
  64. 'pos_sequence': 'positive_passages',
  65. 'neg_sequence': 'negative_passages',
  66. 'text_fileds': ['title', 'text'],
  67. 'qid_field': 'query_id',
  68. 'neg_sample': neg_sample
  69. },
  70. 'val': {
  71. 'type': 'bert',
  72. 'query_sequence': 'query',
  73. 'pos_sequence': 'positive_passages',
  74. 'neg_sequence': 'negative_passages',
  75. 'text_fileds': ['title', 'text'],
  76. 'qid_field': 'query_id'
  77. },
  78. }
  79. cfg['evaluation']['dataloader']['batch_size_per_gpu'] = 30
  80. cfg.train.max_epochs = 1
  81. cfg.train.train_batch_size = 4
  82. cfg.train.lr_scheduler = {
  83. 'type': 'LinearLR',
  84. 'start_factor': 1.0,
  85. 'end_factor': 0.0,
  86. 'options': {
  87. 'by_epoch': False
  88. }
  89. }
  90. cfg.model['neg_sample'] = 4
  91. cfg.train.hooks = [{
  92. 'type': 'CheckpointHook',
  93. 'interval': 1
  94. }, {
  95. 'type': 'TextLoggerHook',
  96. 'interval': 1
  97. }, {
  98. 'type': 'IterTimerHook'
  99. }, {
  100. 'type': 'EvaluationHook',
  101. 'by_epoch': False,
  102. 'interval': 15
  103. }]
  104. return cfg
  105. # load dataset
  106. ds = MsDataset.load('passage-ranking-demo', 'zyznull')
  107. train_ds = ds['train'].to_hf_dataset()
  108. dev_ds = ds['dev'].to_hf_dataset()
  109. model_id = 'damo/nlp_corom_passage-ranking_english-base'
  110. self.finetune(
  111. model_id=model_id,
  112. train_dataset=train_ds,
  113. eval_dataset=dev_ds,
  114. cfg_modify_fn=cfg_modify_fn)
  115. output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)
  116. self.pipeline_text_ranking(output_dir)
  117. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  118. def test_finetune_dureader(self):
  119. def cfg_modify_fn(cfg):
  120. cfg.task = 'text-ranking'
  121. cfg['preprocessor'] = {'type': 'text-ranking'}
  122. cfg.train.optimizer.lr = 2e-5
  123. cfg['dataset'] = {
  124. 'train': {
  125. 'type': 'bert',
  126. 'query_sequence': 'query',
  127. 'pos_sequence': 'positive_passages',
  128. 'neg_sequence': 'negative_passages',
  129. 'text_fileds': ['text'],
  130. 'qid_field': 'query_id'
  131. },
  132. 'val': {
  133. 'type': 'bert',
  134. 'query_sequence': 'query',
  135. 'pos_sequence': 'positive_passages',
  136. 'neg_sequence': 'negative_passages',
  137. 'text_fileds': ['text'],
  138. 'qid_field': 'query_id'
  139. },
  140. }
  141. cfg['evaluation']['dataloader']['batch_size_per_gpu'] = 30
  142. cfg.train.max_epochs = 1
  143. cfg.train.train_batch_size = 4
  144. cfg.train.lr_scheduler = {
  145. 'type': 'LinearLR',
  146. 'start_factor': 1.0,
  147. 'end_factor': 0.0,
  148. 'options': {
  149. 'by_epoch': False
  150. }
  151. }
  152. cfg.train.hooks = [{
  153. 'type': 'CheckpointHook',
  154. 'interval': 1
  155. }, {
  156. 'type': 'TextLoggerHook',
  157. 'interval': 1
  158. }, {
  159. 'type': 'IterTimerHook'
  160. }, {
  161. 'type': 'EvaluationHook',
  162. 'by_epoch': False,
  163. 'interval': 5000
  164. }]
  165. return cfg
  166. # load dataset
  167. ds = MsDataset.load('dureader-retrieval-ranking', 'zyznull')
  168. train_ds = ds['train'].to_hf_dataset().shard(1000, index=0)
  169. dev_ds = ds['dev'].to_hf_dataset()
  170. model_id = 'damo/nlp_rom_passage-ranking_chinese-base'
  171. self.finetune(
  172. model_id=model_id,
  173. train_dataset=train_ds,
  174. eval_dataset=dev_ds,
  175. cfg_modify_fn=cfg_modify_fn)
  176. def pipeline_text_ranking(self, model_dir):
  177. model = Model.from_pretrained(model_dir)
  178. pipeline_ins = pipeline(task=Tasks.text_ranking, model=model)
  179. print(pipeline_ins(input=self.inputs))
  180. if __name__ == '__main__':
  181. unittest.main()