diff --git a/tests/trainers/test_dialog_modeling_trainer.py b/tests/trainers/test_dialog_modeling_trainer.py index 2937ad7e..9d9fd11b 100644 --- a/tests/trainers/test_dialog_modeling_trainer.py +++ b/tests/trainers/test_dialog_modeling_trainer.py @@ -17,7 +17,7 @@ class TestDialogModelingTrainer(unittest.TestCase): model_id = 'damo/nlp_space_pretrained-dialog-model' output_dir = './dialog_fintune_result' - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_trainer_with_model_and_args(self): # download data set data_multiwoz = MsDataset.load( @@ -33,13 +33,13 @@ class TestDialogModelingTrainer(unittest.TestCase): def cfg_modify_fn(cfg): config = { 'seed': 10, - 'gpu': 4, + 'gpu': 1, 'use_data_distributed': False, 'valid_metric_name': '-loss', 'num_epochs': 60, 'save_dir': self.output_dir, 'token_loss': True, - 'batch_size': 32, + 'batch_size': 4, 'log_steps': 10, 'valid_steps': 0, 'save_checkpoint': True, @@ -71,3 +71,7 @@ class TestDialogModelingTrainer(unittest.TestCase): assert os.path.exists(checkpoint_path) trainer.evaluate(checkpoint_path=checkpoint_path) """ + + +if __name__ == '__main__': + unittest.main()