diff --git a/tests/trainers/test_ofa_trainer.py b/tests/trainers/test_ofa_trainer.py index a678865a..0516e569 100644 --- a/tests/trainers/test_ofa_trainer.py +++ b/tests/trainers/test_ofa_trainer.py @@ -5,6 +5,7 @@ import unittest import json +from modelscope.metainfo import Trainers from modelscope.msdatasets import MsDataset from modelscope.trainers import build_trainer from modelscope.utils.constant import DownloadMode, ModelFile @@ -95,7 +96,7 @@ class TestOfaTrainer(unittest.TestCase): split='test[:20]', download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS), cfg_file=config_file) - trainer = build_trainer(name='ofa', default_args=args) + trainer = build_trainer(name=Trainers.ofa, default_args=args) trainer.train() self.assertIn(