diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 48d37eb2..d3e4904e 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -282,7 +282,7 @@ class Trainers(object): # multi-modal trainers clip_multi_modal_embedding = 'clip-multi-modal-embedding' - ofa_tasks = 'ofa' + ofa = 'ofa' # cv trainers image_instance_segmentation = 'image-instance-segmentation' diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py index 3c4ac58a..256c5243 100644 --- a/modelscope/preprocessors/multi_modal.py +++ b/modelscope/preprocessors/multi_modal.py @@ -74,9 +74,7 @@ class OfaPreprocessor(Preprocessor): data[key] = item return data - def _compatible_with_pretrain(self, data): - # 预训练的时候使用的image都是经过pil转换的,PIL save的时候一般会进行有损压缩,为了保证和预训练一致 - # 所以增加了这个逻辑 + def _ofa_input_compatibility_conversion(self, data): if 'image' in data and self.cfg.model.get('type', None) == 'ofa': if isinstance(data['image'], str): image = load_image(data['image']) @@ -95,7 +93,7 @@ class OfaPreprocessor(Preprocessor): data = input else: data = self._build_dict(input) - data = self._compatible_with_pretrain(data) + data = self._ofa_input_compatibility_conversion(data) sample = self.preprocess(data) str_data = dict() for k, v in data.items(): diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py index c287c182..02853925 100644 --- a/modelscope/trainers/multi_modal/ofa/ofa_trainer.py +++ b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py @@ -27,7 +27,7 @@ from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion, get_schedule) -@TRAINERS.register_module(module_name=Trainers.ofa_tasks) +@TRAINERS.register_module(module_name=Trainers.ofa) class OFATrainer(EpochBasedTrainer): def __init__( diff --git a/tests/trainers/test_ofa_trainer.py b/tests/trainers/test_ofa_trainer.py index 75b8cbbf..06003625 100644 --- a/tests/trainers/test_ofa_trainer.py +++ b/tests/trainers/test_ofa_trainer.py @@ -93,7 +93,7 @@ class TestOfaTrainer(unittest.TestCase): split='validation[:10]'), metrics=[Metrics.BLEU], cfg_file=config_file) - trainer = build_trainer(name=Trainers.ofa_tasks, default_args=args) + trainer = build_trainer(name=Trainers.ofa, default_args=args) trainer.train() self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE,