@@ -282,7 +282,7 @@ class Trainers(object): | |||||
# multi-modal trainers | # multi-modal trainers | ||||
clip_multi_modal_embedding = 'clip-multi-modal-embedding' | clip_multi_modal_embedding = 'clip-multi-modal-embedding' | ||||
ofa_tasks = 'ofa' | |||||
ofa = 'ofa' | |||||
# cv trainers | # cv trainers | ||||
image_instance_segmentation = 'image-instance-segmentation' | image_instance_segmentation = 'image-instance-segmentation' | ||||
@@ -74,9 +74,7 @@ class OfaPreprocessor(Preprocessor): | |||||
data[key] = item | data[key] = item | ||||
return data | return data | ||||
def _compatible_with_pretrain(self, data): | |||||
# 预训练的时候使用的image都是经过pil转换的,PIL save的时候一般会进行有损压缩,为了保证和预训练一致 | |||||
# 所以增加了这个逻辑 | |||||
def _ofa_input_compatibility_conversion(self, data): | |||||
if 'image' in data and self.cfg.model.get('type', None) == 'ofa': | if 'image' in data and self.cfg.model.get('type', None) == 'ofa': | ||||
if isinstance(data['image'], str): | if isinstance(data['image'], str): | ||||
image = load_image(data['image']) | image = load_image(data['image']) | ||||
@@ -95,7 +93,7 @@ class OfaPreprocessor(Preprocessor): | |||||
data = input | data = input | ||||
else: | else: | ||||
data = self._build_dict(input) | data = self._build_dict(input) | ||||
data = self._compatible_with_pretrain(data) | |||||
data = self._ofa_input_compatibility_conversion(data) | |||||
sample = self.preprocess(data) | sample = self.preprocess(data) | ||||
str_data = dict() | str_data = dict() | ||||
for k, v in data.items(): | for k, v in data.items(): | ||||
@@ -27,7 +27,7 @@ from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion, | |||||
get_schedule) | get_schedule) | ||||
@TRAINERS.register_module(module_name=Trainers.ofa_tasks) | |||||
@TRAINERS.register_module(module_name=Trainers.ofa) | |||||
class OFATrainer(EpochBasedTrainer): | class OFATrainer(EpochBasedTrainer): | ||||
def __init__( | def __init__( | ||||
@@ -93,7 +93,7 @@ class TestOfaTrainer(unittest.TestCase): | |||||
split='validation[:10]'), | split='validation[:10]'), | ||||
metrics=[Metrics.BLEU], | metrics=[Metrics.BLEU], | ||||
cfg_file=config_file) | cfg_file=config_file) | ||||
trainer = build_trainer(name=Trainers.ofa_tasks, default_args=args) | |||||
trainer = build_trainer(name=Trainers.ofa, default_args=args) | |||||
trainer.train() | trainer.train() | ||||
self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, | self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, | ||||