Browse Source

[to #42322933] reduce the GPU usage of dialog trianer

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10955485
master^2
ly119399 wenmeng.zwm 3 years ago
parent
commit
2f17daa23f
1 changed files with 7 additions and 3 deletions
  1. +7
    -3
      tests/trainers/test_dialog_modeling_trainer.py

+ 7
- 3
tests/trainers/test_dialog_modeling_trainer.py View File

@@ -17,7 +17,7 @@ class TestDialogModelingTrainer(unittest.TestCase):
model_id = 'damo/nlp_space_pretrained-dialog-model'
output_dir = './dialog_fintune_result'

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_trainer_with_model_and_args(self):
# download data set
data_multiwoz = MsDataset.load(
@@ -33,13 +33,13 @@ class TestDialogModelingTrainer(unittest.TestCase):
def cfg_modify_fn(cfg):
config = {
'seed': 10,
'gpu': 4,
'gpu': 1,
'use_data_distributed': False,
'valid_metric_name': '-loss',
'num_epochs': 60,
'save_dir': self.output_dir,
'token_loss': True,
'batch_size': 32,
'batch_size': 4,
'log_steps': 10,
'valid_steps': 0,
'save_checkpoint': True,
@@ -71,3 +71,7 @@ class TestDialogModelingTrainer(unittest.TestCase):
assert os.path.exists(checkpoint_path)
trainer.evaluate(checkpoint_path=checkpoint_path)
"""


if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save