@@ -17,7 +17,7 @@ class TestDialogModelingTrainer(unittest.TestCase):
     model_id = 'damo/nlp_space_pretrained-dialog-model'
     output_dir = './dialog_fintune_result'
-    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_trainer_with_model_and_args(self):
         # download data set
         data_multiwoz = MsDataset.load(
@@ -33,13 +33,13 @@ class TestDialogModelingTrainer(unittest.TestCase):
         def cfg_modify_fn(cfg):
             config = {
                 'seed': 10,
-                'gpu': 4,
+                'gpu': 1,
                 'use_data_distributed': False,
                 'valid_metric_name': '-loss',
                 'num_epochs': 60,
                 'save_dir': self.output_dir,
                 'token_loss': True,
-                'batch_size': 32,
+                'batch_size': 4,
                 'log_steps': 10,
                 'valid_steps': 0,
                 'save_checkpoint': True,
@@ -71,3 +71,7 @@ class TestDialogModelingTrainer(unittest.TestCase):
+        assert os.path.exists(checkpoint_path)
+        trainer.evaluate(checkpoint_path=checkpoint_path)
+        """
+

 if __name__ == '__main__':
     unittest.main()
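
Note: cfg_modify_fn, whose values are tuned in the second hunk, is ModelScope's standard hook for overriding the training configuration shipped with a model: it receives the loaded config object, mutates it, and returns it, and is passed to the trainer factory along with the model id and work directory. The sketch below is a minimal illustration of that pattern only; the exact trainer name, keyword arguments, and config fields used by this test are not visible in the diff, so the ones here are stand-ins rather than the test's real wiring.

# Illustrative only: how a cfg_modify_fn is typically wired into a
# ModelScope trainer. The config field and the kwargs below are assumptions.
from modelscope.trainers import build_trainer


def cfg_modify_fn(cfg):
    # Override selected training settings before training starts;
    # the trainer uses whatever config object this function returns.
    cfg.train.max_epochs = 60  # assumed field; mirrors 'num_epochs' above
    return cfg


kwargs = dict(
    model='damo/nlp_space_pretrained-dialog-model',  # model_id from the test
    work_dir='./dialog_fintune_result',              # output_dir from the test
    cfg_modify_fn=cfg_modify_fn)

# The test builds a dialog-specific trainer and supplies the downloaded
# MultiWOZ data; the dataset arguments are omitted in this sketch.
trainer = build_trainer(default_args=kwargs)
trainer.train()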