From 2288a0fdf34e3c64e39b44d116bd7eabc9f66440 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A1=8C=E5=97=94?= Date: Tue, 25 Oct 2022 10:18:33 +0800 Subject: [PATCH] fix all comments --- modelscope/metainfo.py | 2 +- modelscope/preprocessors/multi_modal.py | 6 ++---- modelscope/trainers/multi_modal/ofa/ofa_trainer.py | 2 +- tests/trainers/test_ofa_trainer.py | 2 +- 4 files changed, 5 insertions(+), 7 deletions(-) diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 48d37eb2..d3e4904e 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -282,7 +282,7 @@ class Trainers(object): # multi-modal trainers clip_multi_modal_embedding = 'clip-multi-modal-embedding' - ofa_tasks = 'ofa' + ofa = 'ofa' # cv trainers image_instance_segmentation = 'image-instance-segmentation' diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py index 3c4ac58a..256c5243 100644 --- a/modelscope/preprocessors/multi_modal.py +++ b/modelscope/preprocessors/multi_modal.py @@ -74,9 +74,7 @@ class OfaPreprocessor(Preprocessor): data[key] = item return data - def _compatible_with_pretrain(self, data): - # 预训练的时候使用的image都是经过pil转换的,PIL save的时候一般会进行有损压缩,为了保证和预训练一致 - # 所以增加了这个逻辑 + def _ofa_input_compatibility_conversion(self, data): if 'image' in data and self.cfg.model.get('type', None) == 'ofa': if isinstance(data['image'], str): image = load_image(data['image']) @@ -95,7 +93,7 @@ class OfaPreprocessor(Preprocessor): data = input else: data = self._build_dict(input) - data = self._compatible_with_pretrain(data) + data = self._ofa_input_compatibility_conversion(data) sample = self.preprocess(data) str_data = dict() for k, v in data.items(): diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py index c287c182..02853925 100644 --- a/modelscope/trainers/multi_modal/ofa/ofa_trainer.py +++ b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py @@ -27,7 +27,7 @@ from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion, get_schedule) -@TRAINERS.register_module(module_name=Trainers.ofa_tasks) +@TRAINERS.register_module(module_name=Trainers.ofa) class OFATrainer(EpochBasedTrainer): def __init__( diff --git a/tests/trainers/test_ofa_trainer.py b/tests/trainers/test_ofa_trainer.py index 75b8cbbf..06003625 100644 --- a/tests/trainers/test_ofa_trainer.py +++ b/tests/trainers/test_ofa_trainer.py @@ -93,7 +93,7 @@ class TestOfaTrainer(unittest.TestCase): split='validation[:10]'), metrics=[Metrics.BLEU], cfg_file=config_file) - trainer = build_trainer(name=Trainers.ofa_tasks, default_args=args) + trainer = build_trainer(name=Trainers.ofa, default_args=args) trainer.train() self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE,