Browse Source

fix all comments

master
行嗔 2 years ago
parent
commit
2288a0fdf3
4 changed files with 5 additions and 7 deletions
  1. +1
    -1
      modelscope/metainfo.py
  2. +2
    -4
      modelscope/preprocessors/multi_modal.py
  3. +1
    -1
      modelscope/trainers/multi_modal/ofa/ofa_trainer.py
  4. +1
    -1
      tests/trainers/test_ofa_trainer.py

+ 1
- 1
modelscope/metainfo.py View File

@@ -282,7 +282,7 @@ class Trainers(object):

# multi-modal trainers
clip_multi_modal_embedding = 'clip-multi-modal-embedding'
ofa_tasks = 'ofa'
ofa = 'ofa'

# cv trainers
image_instance_segmentation = 'image-instance-segmentation'


+ 2
- 4
modelscope/preprocessors/multi_modal.py View File

@@ -74,9 +74,7 @@ class OfaPreprocessor(Preprocessor):
data[key] = item
return data

def _compatible_with_pretrain(self, data):
# 预训练的时候使用的image都是经过pil转换的,PIL save的时候一般会进行有损压缩,为了保证和预训练一致
# 所以增加了这个逻辑
def _ofa_input_compatibility_conversion(self, data):
if 'image' in data and self.cfg.model.get('type', None) == 'ofa':
if isinstance(data['image'], str):
image = load_image(data['image'])
@@ -95,7 +93,7 @@ class OfaPreprocessor(Preprocessor):
data = input
else:
data = self._build_dict(input)
data = self._compatible_with_pretrain(data)
data = self._ofa_input_compatibility_conversion(data)
sample = self.preprocess(data)
str_data = dict()
for k, v in data.items():


+ 1
- 1
modelscope/trainers/multi_modal/ofa/ofa_trainer.py View File

@@ -27,7 +27,7 @@ from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion,
get_schedule)


@TRAINERS.register_module(module_name=Trainers.ofa_tasks)
@TRAINERS.register_module(module_name=Trainers.ofa)
class OFATrainer(EpochBasedTrainer):

def __init__(


+ 1
- 1
tests/trainers/test_ofa_trainer.py View File

@@ -93,7 +93,7 @@ class TestOfaTrainer(unittest.TestCase):
split='validation[:10]'),
metrics=[Metrics.BLEU],
cfg_file=config_file)
trainer = build_trainer(name=Trainers.ofa_tasks, default_args=args)
trainer = build_trainer(name=Trainers.ofa, default_args=args)
trainer.train()

self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE,


Loading…
Cancel
Save