# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest

import torch

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.constant import DownloadMode, ModelFile
from modelscope.utils.test_utils import test_level


class TestDialogModelingTrainer(unittest.TestCase):

    model_id = 'damo/nlp_space_pretrained-dialog-model'
    output_dir = './dialog_finetune_result'

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_trainer_with_model_and_args(self):
        # Download the MultiWOZ 2.0 dataset and locate its training data
        # directory, reusing a previously downloaded copy if one exists.
        data_multiwoz = MsDataset.load(
            'MultiWoz2.0',
            download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
        data_dir = os.path.join(
            data_multiwoz._hf_ds.config_kwargs['split_config']['train'],
            'data')

        # Download the pretrained dialog model from the ModelScope hub.
        model_dir = snapshot_download(self.model_id)

        # Dialog fine-tuning config: this callback overrides the Trainer
        # section of the model's default configuration before training starts.
        def cfg_modify_fn(cfg):
            config = {
                'seed': 10,
                'gpu': 4,
                'use_data_distributed': False,
                'valid_metric_name': '-loss',
                'num_epochs': 60,
                'save_dir': self.output_dir,
                'token_loss': True,
                'batch_size': 32,
                'log_steps': 10,
                'valid_steps': 0,
                'save_checkpoint': True,
                'save_summary': False,
                'shuffle': True,
                'sort_pool_size': 0
            }
            cfg.Trainer = config
            # Only enable GPU training when CUDA is actually available.
            cfg.use_gpu = torch.cuda.is_available() and config['gpu'] >= 1
            return cfg

        # Trainer arguments: the model directory, the training config file
        # to load, the dataset location, and the config-modification callback.
        kwargs = dict(
            model_dir=model_dir,
            cfg_name='gen_train_config.json',
            data_dir=data_dir,
            cfg_modify_fn=cfg_modify_fn)

        trainer = build_trainer(
            name=Trainers.dialog_modeling_trainer, default_args=kwargs)
        trainer.train()

        # Training should have written a checkpoint; evaluate from it.
        checkpoint_path = os.path.join(self.output_dir,
                                       ModelFile.TORCH_MODEL_BIN_FILE)
        self.assertTrue(os.path.exists(checkpoint_path))
        trainer.evaluate(checkpoint_path=checkpoint_path)
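
# Standard unittest entry point (an assumption: the flattened source ends at
# trainer.evaluate, but ModelScope test modules conventionally close with this
# guard so the file can also be run directly rather than via a test runner).
if __name__ == '__main__':
    unittest.main()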