diff --git a/modelscope/models/multi_modal/mplug/modeling_mplug.py b/modelscope/models/multi_modal/mplug/modeling_mplug.py index ec491f1d..1d003f5c 100755 --- a/modelscope/models/multi_modal/mplug/modeling_mplug.py +++ b/modelscope/models/multi_modal/mplug/modeling_mplug.py @@ -1850,7 +1850,7 @@ class MPlug(PreTrainedModel): self.config_fusion, add_pooling_layer=False) @classmethod - def from_pretrained(cls, model_dir, load_checkpoint=True): + def from_pretrained(cls, model_dir, task=None, load_checkpoint=True): from modelscope.utils.constant import Tasks task_mapping = { @@ -1861,7 +1861,9 @@ class MPlug(PreTrainedModel): config = cls.config_class.from_yaml_file( os.path.join(model_dir, CONFIG_NAME)) config.model_dir = model_dir - model = task_mapping[config.task](config) + if task is None: + task = config.task + model = task_mapping[task](config) if load_checkpoint: checkpoint_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE) diff --git a/modelscope/models/multi_modal/mplug_for_all_tasks.py b/modelscope/models/multi_modal/mplug_for_all_tasks.py index 7de8d291..4d2a6ac2 100644 --- a/modelscope/models/multi_modal/mplug_for_all_tasks.py +++ b/modelscope/models/multi_modal/mplug_for_all_tasks.py @@ -20,7 +20,7 @@ __all__ = ['MPlugForAllTasks'] @MODELS.register_module(Tasks.image_text_retrieval, module_name=Models.mplug) class MPlugForAllTasks(TorchModel): - def __init__(self, model_dir: str, *args, **kwargs): + def __init__(self, model_dir: str, task=None, *args, **kwargs): """initialize the mplug model from the `model_dir` path. Args: model_dir (str): the model path. @@ -28,7 +28,7 @@ class MPlugForAllTasks(TorchModel): super().__init__(model_dir, *args, **kwargs) from modelscope.models.multi_modal.mplug import MPlug - self.model = MPlug.from_pretrained(model_dir) + self.model = MPlug.from_pretrained(model_dir, task=task) self.tokenizer = self.model.tokenizer def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: diff --git a/modelscope/trainers/multi_modal/mplug/mplug_trainer.py b/modelscope/trainers/multi_modal/mplug/mplug_trainer.py index def66220..fb456719 100644 --- a/modelscope/trainers/multi_modal/mplug/mplug_trainer.py +++ b/modelscope/trainers/multi_modal/mplug/mplug_trainer.py @@ -1,18 +1,29 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from collections.abc import Mapping +from typing import Optional, Union import torch +from torch import nn from modelscope.metainfo import Trainers +from modelscope.models import Model, TorchModel from modelscope.outputs import OutputKeys -from modelscope.trainers import NlpEpochBasedTrainer +from modelscope.trainers import EpochBasedTrainer from modelscope.trainers.builder import TRAINERS from modelscope.utils.file_utils import func_receive_dict_inputs @TRAINERS.register_module(module_name=Trainers.mplug) -class MPlugTrainer(NlpEpochBasedTrainer): +class MPlugTrainer(EpochBasedTrainer): + + def __init__(self, *args, **kwargs): + self.task: Optional[str] = kwargs.pop('task', None) + super().__init__(*args, **kwargs) + + def build_model(self) -> Union[nn.Module, TorchModel]: + return Model.from_pretrained( + self.model_dir, task=self.task, cfg_dict=self.cfg) def _decode(self, tokens): tokenizer = self.eval_preprocessor.tokenizer diff --git a/tests/trainers/test_finetune_mplug.py b/tests/trainers/test_finetune_mplug.py index 46664114..c64e1285 100644 --- a/tests/trainers/test_finetune_mplug.py +++ b/tests/trainers/test_finetune_mplug.py @@ -9,7 +9,7 @@ from modelscope.metainfo import Trainers from modelscope.models.multi_modal import MPlugForAllTasks from modelscope.msdatasets import MsDataset from modelscope.trainers import EpochBasedTrainer, build_trainer -from modelscope.utils.constant import ModelFile +from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.test_utils import test_level @@ -40,11 +40,12 @@ class TestFinetuneMPlug(unittest.TestCase): @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_trainer_with_caption(self): kwargs = dict( - model='damo/mplug_image-captioning_coco_base_en', + model='damo/mplug_backbone_base_en', train_dataset=self.train_dataset, eval_dataset=self.test_dataset, max_epochs=self.max_epochs, - work_dir=self.tmp_dir) + work_dir=self.tmp_dir, + task=Tasks.image_captioning) trainer: EpochBasedTrainer = build_trainer( name=Trainers.mplug, default_args=kwargs) @@ -52,9 +53,9 @@ class TestFinetuneMPlug(unittest.TestCase): @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_trainer_with_caption_with_model_and_args(self): - cache_path = snapshot_download( - 'damo/mplug_image-captioning_coco_base_en') - model = MPlugForAllTasks.from_pretrained(cache_path) + cache_path = snapshot_download('damo/mplug_backbone_base_en') + model = MPlugForAllTasks.from_pretrained( + cache_path, task=Tasks.image_captioning) kwargs = dict( cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), model=model, @@ -74,11 +75,12 @@ class TestFinetuneMPlug(unittest.TestCase): @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_trainer_with_vqa(self): kwargs = dict( - model='damo/mplug_visual-question-answering_coco_large_en', + model='damo/mplug_backbone_base_en', train_dataset=self.train_dataset, eval_dataset=self.test_dataset, max_epochs=self.max_epochs, - work_dir=self.tmp_dir) + work_dir=self.tmp_dir, + task=Tasks.visual_question_answering) trainer: EpochBasedTrainer = build_trainer( name=Trainers.mplug, default_args=kwargs) @@ -88,7 +90,8 @@ class TestFinetuneMPlug(unittest.TestCase): def test_trainer_with_vqa_with_model_and_args(self): cache_path = snapshot_download( 'damo/mplug_visual-question-answering_coco_large_en') - model = MPlugForAllTasks.from_pretrained(cache_path) + model = MPlugForAllTasks.from_pretrained( + cache_path, task=Tasks.visual_question_answering) kwargs = dict( cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), model=model, @@ -108,11 +111,12 @@ class TestFinetuneMPlug(unittest.TestCase): @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_trainer_with_retrieval(self): kwargs = dict( - model='damo/mplug_image-text-retrieval_flickr30k_large_en', + model='damo/mplug_backbone_base_en', train_dataset=self.train_dataset, eval_dataset=self.test_dataset, max_epochs=self.max_epochs, - work_dir=self.tmp_dir) + work_dir=self.tmp_dir, + task=Tasks.image_text_retrieval) trainer: EpochBasedTrainer = build_trainer( name=Trainers.mplug, default_args=kwargs) @@ -120,9 +124,9 @@ class TestFinetuneMPlug(unittest.TestCase): @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_trainer_with_retrieval_with_model_and_args(self): - cache_path = snapshot_download( - 'damo/mplug_image-text-retrieval_flickr30k_large_en') - model = MPlugForAllTasks.from_pretrained(cache_path) + cache_path = snapshot_download('damo/mplug_backbone_base_en') + model = MPlugForAllTasks.from_pretrained( + cache_path, task=Tasks.image_text_retrieval) kwargs = dict( cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), model=model,