You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_dialog_intent_trainer.py 3.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import shutil
  4. import tempfile
  5. import unittest
  6. import json
  7. from modelscope.hub.snapshot_download import snapshot_download
  8. from modelscope.metainfo import Trainers
  9. from modelscope.msdatasets import MsDataset
  10. from modelscope.trainers import build_trainer
  11. from modelscope.utils.config import Config
  12. from modelscope.utils.constant import DownloadMode, ModelFile, Tasks
  13. from modelscope.utils.test_utils import test_level
  14. class TestDialogIntentTrainer(unittest.TestCase):
  15. def setUp(self):
  16. self.save_dir = tempfile.TemporaryDirectory().name
  17. if not os.path.exists(self.save_dir):
  18. os.mkdir(self.save_dir)
  19. def tearDown(self):
  20. shutil.rmtree(self.save_dir)
  21. super().tearDown()
  22. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  23. def test_trainer_with_model_and_args(self):
  24. model_id = 'damo/nlp_space_pretrained-dialog-model'
  25. data_banking = MsDataset.load('banking77')
  26. self.data_dir = data_banking._hf_ds.config_kwargs['split_config'][
  27. 'train']
  28. self.model_dir = snapshot_download(model_id)
  29. self.debugging = True
  30. kwargs = dict(
  31. model_dir=self.model_dir,
  32. cfg_name='intent_train_config.json',
  33. cfg_modify_fn=self.cfg_modify_fn)
  34. trainer = build_trainer(
  35. name=Trainers.dialog_intent_trainer, default_args=kwargs)
  36. trainer.train()
  37. def cfg_modify_fn(self, cfg):
  38. config = {
  39. 'num_intent': 77,
  40. 'BPETextField': {
  41. 'vocab_path': '',
  42. 'data_name': 'banking77',
  43. 'data_root': self.data_dir,
  44. 'understand': True,
  45. 'generation': False,
  46. 'max_len': 256
  47. },
  48. 'Dataset': {
  49. 'data_dir': self.data_dir,
  50. 'with_contrastive': False,
  51. 'trigger_role': 'user',
  52. 'trigger_data': 'banking'
  53. },
  54. 'Trainer': {
  55. 'can_norm': True,
  56. 'seed': 11,
  57. 'gpu': 1,
  58. 'save_dir': self.save_dir,
  59. 'batch_size_label': 128,
  60. 'batch_size_nolabel': 0,
  61. 'log_steps': 20
  62. },
  63. 'Model': {
  64. 'init_checkpoint': self.model_dir,
  65. 'model': 'IntentUnifiedTransformer',
  66. 'example': False,
  67. 'num_intent': 77,
  68. 'with_rdrop': True,
  69. 'num_turn_embeddings': 21,
  70. 'dropout': 0.25,
  71. 'kl_ratio': 5.0,
  72. 'embed_dropout': 0.25,
  73. 'attn_dropout': 0.25,
  74. 'ff_dropout': 0.25,
  75. 'with_pool': False,
  76. 'warmup_steps': -1
  77. }
  78. }
  79. cfg.BPETextField.vocab_path = os.path.join(self.model_dir,
  80. ModelFile.VOCAB_FILE)
  81. cfg.num_intent = 77
  82. cfg.Trainer.update(config['Trainer'])
  83. cfg.BPETextField.update(config['BPETextField'])
  84. cfg.Dataset.update(config['Dataset'])
  85. cfg.Model.update(config['Model'])
  86. if self.debugging:
  87. cfg.Trainer.save_checkpoint = False
  88. cfg.Trainer.num_epochs = 1
  89. cfg.Trainer.batch_size_label = 64
  90. return cfg
  91. if __name__ == '__main__':
  92. unittest.main()