
[to #42322933] Add mplug pretrained model

Add the task-agnostic mPLUG pre-trained backbone (damo/mplug_backbone_base_en) for finetuning: `MPlug.from_pretrained` and `MPlugTrainer` now accept an explicit `task`, so one backbone checkpoint can be finetuned for image captioning, visual question answering, or image-text retrieval.
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10963691
hemu.zp, yingda.chen · 2 years ago
commit 346da3d489
4 changed files with 37 additions and 20 deletions
  1. modelscope/models/multi_modal/mplug/modeling_mplug.py (+4 / -2)
  2. modelscope/models/multi_modal/mplug_for_all_tasks.py (+2 / -2)
  3. modelscope/trainers/multi_modal/mplug/mplug_trainer.py (+13 / -2)
  4. tests/trainers/test_finetune_mplug.py (+18 / -14)
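
Taken together, these changes let a single task-agnostic backbone checkpoint be finetuned for any supported task by passing `task` explicitly. A minimal usage sketch, mirroring the updated tests (the dataset variables and epoch count are placeholders, not part of this commit):

from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from modelscope.utils.constant import Tasks

# train_dataset / eval_dataset stand in for real MsDataset objects.
kwargs = dict(
    model='damo/mplug_backbone_base_en',  # task-agnostic backbone
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    max_epochs=3,
    work_dir='./mplug_work_dir',
    task=Tasks.image_captioning)  # new: select the task head explicitly

trainer = build_trainer(name=Trainers.mplug, default_args=kwargs)
trainer.train()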

modelscope/models/multi_modal/mplug/modeling_mplug.py (+4 / -2)

@@ -1850,7 +1850,7 @@ class MPlug(PreTrainedModel):
             self.config_fusion, add_pooling_layer=False)
 
     @classmethod
-    def from_pretrained(cls, model_dir, load_checkpoint=True):
+    def from_pretrained(cls, model_dir, task=None, load_checkpoint=True):
         from modelscope.utils.constant import Tasks
 
         task_mapping = {
@@ -1861,7 +1861,9 @@ class MPlug(PreTrainedModel):
         config = cls.config_class.from_yaml_file(
             os.path.join(model_dir, CONFIG_NAME))
         config.model_dir = model_dir
-        model = task_mapping[config.task](config)
+        if task is None:
+            task = config.task
+        model = task_mapping[task](config)
         if load_checkpoint:
             checkpoint_path = os.path.join(model_dir,
                                            ModelFile.TORCH_MODEL_BIN_FILE)
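
The effect of this hunk: an explicit `task` argument now overrides the task recorded in the checkpoint's config.yaml, with the old behaviour as the fallback. A hedged sketch of the dispatch pattern (the head classes are placeholders; the real entries of `task_mapping` are elided in this diff):

class CaptionHead: ...  # placeholder for the real captioning model class
class VqaHead: ...      # placeholder for the real VQA model class

task_mapping = {'image-captioning': CaptionHead,
                'visual-question-answering': VqaHead}

def pick_head(config_task, task=None):
    # An explicit task wins; otherwise fall back to the task stored
    # in the checkpoint's configuration.
    if task is None:
        task = config_task
    return task_mapping[task]

assert pick_head('image-captioning') is CaptionHead
assert pick_head('image-captioning', task='visual-question-answering') is VqaHead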


modelscope/models/multi_modal/mplug_for_all_tasks.py (+2 / -2)

@@ -20,7 +20,7 @@ __all__ = ['MPlugForAllTasks']
 @MODELS.register_module(Tasks.image_text_retrieval, module_name=Models.mplug)
 class MPlugForAllTasks(TorchModel):
 
-    def __init__(self, model_dir: str, *args, **kwargs):
+    def __init__(self, model_dir: str, task=None, *args, **kwargs):
         """initialize the mplug model from the `model_dir` path.
 
         Args:
             model_dir (str): the model path.
@@ -28,7 +28,7 @@ class MPlugForAllTasks(TorchModel):
 
         super().__init__(model_dir, *args, **kwargs)
         from modelscope.models.multi_modal.mplug import MPlug
-        self.model = MPlug.from_pretrained(model_dir)
+        self.model = MPlug.from_pretrained(model_dir, task=task)
         self.tokenizer = self.model.tokenizer
 
     def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
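
`MPlugForAllTasks` simply forwards the new argument to `MPlug.from_pretrained`, so callers can attach any supported head to the backbone. For instance, mirroring the updated tests below:

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.multi_modal import MPlugForAllTasks
from modelscope.utils.constant import Tasks

# Load the task-agnostic backbone and attach the VQA head.
cache_path = snapshot_download('damo/mplug_backbone_base_en')
model = MPlugForAllTasks.from_pretrained(
    cache_path, task=Tasks.visual_question_answering)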


modelscope/trainers/multi_modal/mplug/mplug_trainer.py (+13 / -2)

@@ -1,18 +1,29 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
 from collections.abc import Mapping
+from typing import Optional, Union
 
 import torch
+from torch import nn
 
 from modelscope.metainfo import Trainers
+from modelscope.models import Model, TorchModel
 from modelscope.outputs import OutputKeys
-from modelscope.trainers import NlpEpochBasedTrainer
+from modelscope.trainers import EpochBasedTrainer
 from modelscope.trainers.builder import TRAINERS
 from modelscope.utils.file_utils import func_receive_dict_inputs
 
 
 @TRAINERS.register_module(module_name=Trainers.mplug)
-class MPlugTrainer(NlpEpochBasedTrainer):
+class MPlugTrainer(EpochBasedTrainer):
 
+    def __init__(self, *args, **kwargs):
+        self.task: Optional[str] = kwargs.pop('task', None)
+        super().__init__(*args, **kwargs)
+
+    def build_model(self) -> Union[nn.Module, TorchModel]:
+        return Model.from_pretrained(
+            self.model_dir, task=self.task, cfg_dict=self.cfg)
+
     def _decode(self, tokens):
         tokenizer = self.eval_preprocessor.tokenizer
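
Note the ordering in `__init__`: `task` is popped from `kwargs` before `super().__init__()` runs, both so the base class never sees an unexpected keyword and, assuming the base trainer builds the model inside its own `__init__`, so `self.task` is already set when `build_model` fires. A toy sketch of that ordering (the base class here is a stand-in, not EpochBasedTrainer):

class Base:
    def __init__(self, **kwargs):
        self.model_dir = kwargs['model']
        self.model = self.build_model()  # runs during __init__

class TaskTrainer(Base):
    def __init__(self, *args, **kwargs):
        self.task = kwargs.pop('task', None)  # must happen before super()
        super().__init__(*args, **kwargs)

    def build_model(self):
        return f'model for task={self.task}'

t = TaskTrainer(model='some/dir', task='image-captioning')
assert t.model == 'model for task=image-captioning'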


tests/trainers/test_finetune_mplug.py (+18 / -14)

@@ -9,7 +9,7 @@ from modelscope.metainfo import Trainers
 from modelscope.models.multi_modal import MPlugForAllTasks
 from modelscope.msdatasets import MsDataset
 from modelscope.trainers import EpochBasedTrainer, build_trainer
-from modelscope.utils.constant import ModelFile
+from modelscope.utils.constant import ModelFile, Tasks
 from modelscope.utils.test_utils import test_level


@@ -40,11 +40,12 @@ class TestFinetuneMPlug(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_trainer_with_caption(self):
         kwargs = dict(
-            model='damo/mplug_image-captioning_coco_base_en',
+            model='damo/mplug_backbone_base_en',
             train_dataset=self.train_dataset,
             eval_dataset=self.test_dataset,
             max_epochs=self.max_epochs,
-            work_dir=self.tmp_dir)
+            work_dir=self.tmp_dir,
+            task=Tasks.image_captioning)
 
         trainer: EpochBasedTrainer = build_trainer(
             name=Trainers.mplug, default_args=kwargs)
@@ -52,9 +53,9 @@ class TestFinetuneMPlug(unittest.TestCase):
 
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_trainer_with_caption_with_model_and_args(self):
-        cache_path = snapshot_download(
-            'damo/mplug_image-captioning_coco_base_en')
-        model = MPlugForAllTasks.from_pretrained(cache_path)
+        cache_path = snapshot_download('damo/mplug_backbone_base_en')
+        model = MPlugForAllTasks.from_pretrained(
+            cache_path, task=Tasks.image_captioning)
         kwargs = dict(
             cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
             model=model,
@@ -74,11 +75,12 @@ class TestFinetuneMPlug(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_trainer_with_vqa(self):
         kwargs = dict(
-            model='damo/mplug_visual-question-answering_coco_large_en',
+            model='damo/mplug_backbone_base_en',
             train_dataset=self.train_dataset,
             eval_dataset=self.test_dataset,
             max_epochs=self.max_epochs,
-            work_dir=self.tmp_dir)
+            work_dir=self.tmp_dir,
+            task=Tasks.visual_question_answering)
 
         trainer: EpochBasedTrainer = build_trainer(
             name=Trainers.mplug, default_args=kwargs)
@@ -88,7 +90,8 @@ class TestFinetuneMPlug(unittest.TestCase):
     def test_trainer_with_vqa_with_model_and_args(self):
         cache_path = snapshot_download(
             'damo/mplug_visual-question-answering_coco_large_en')
-        model = MPlugForAllTasks.from_pretrained(cache_path)
+        model = MPlugForAllTasks.from_pretrained(
+            cache_path, task=Tasks.visual_question_answering)
         kwargs = dict(
             cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
             model=model,
@@ -108,11 +111,12 @@ class TestFinetuneMPlug(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_trainer_with_retrieval(self):
         kwargs = dict(
-            model='damo/mplug_image-text-retrieval_flickr30k_large_en',
+            model='damo/mplug_backbone_base_en',
             train_dataset=self.train_dataset,
             eval_dataset=self.test_dataset,
             max_epochs=self.max_epochs,
-            work_dir=self.tmp_dir)
+            work_dir=self.tmp_dir,
+            task=Tasks.image_text_retrieval)
 
         trainer: EpochBasedTrainer = build_trainer(
             name=Trainers.mplug, default_args=kwargs)
@@ -120,9 +124,9 @@ class TestFinetuneMPlug(unittest.TestCase):
 
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_trainer_with_retrieval_with_model_and_args(self):
-        cache_path = snapshot_download(
-            'damo/mplug_image-text-retrieval_flickr30k_large_en')
-        model = MPlugForAllTasks.from_pretrained(cache_path)
+        cache_path = snapshot_download('damo/mplug_backbone_base_en')
+        model = MPlugForAllTasks.from_pretrained(
+            cache_path, task=Tasks.image_text_retrieval)
         kwargs = dict(
             cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
             model=model,

