
[to #42322933] add vqa and caption finetuning for mplug

Add finetuning support for the mPLUG model on caption and VQA tasks.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9858028
master · hemu.zp, yingda.chen · 3 years ago
commit b92e2ca0a0
7 changed files with 226 additions and 148 deletions
1. modelscope/metrics/builder.py (+2, -0)
2. modelscope/models/multi_modal/mplug/modeling_mplug.py (+8, -102)
3. modelscope/models/multi_modal/mplug_for_all_tasks.py (+37, -13)
4. modelscope/models/nlp/gpt3/gpt3_for_text_generation.py (+2, -1)
5. modelscope/models/nlp/palm_v2/palm_for_text_generation.py (+7, -8)
6. modelscope/preprocessors/multi_modal.py (+42, -24)
7. tests/trainers/test_finetune_mplug.py (+128, -0)
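In short, the commit wires mPLUG captioning and VQA checkpoints into the standard ModelScope trainer. A minimal finetuning sketch, distilled from tests/trainers/test_finetune_mplug.py at the end of this diff; it assumes the dataset splits already expose 'image', 'question' and 'answer' columns (the test renames raw columns first), and work_dir is a placeholder:

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer

# The test below renames 'image:FILE'/'answer:Value' to 'image'/'answer'
# and adds a fixed 'question' column before handing the splits over.
datadict = MsDataset.load('coco_captions_small_slice')

trainer = build_trainer(
    name=Trainers.nlp_base_trainer,
    default_args=dict(
        model='damo/mplug_image-captioning_coco_base_en',
        train_dataset=datadict['train'],
        eval_dataset=datadict['test'],
        work_dir='/tmp/mplug_finetune'))  # placeholder path
trainer.train()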

modelscope/metrics/builder.py (+2, -0)

@@ -30,6 +30,8 @@ task_default_metrics = {
     Tasks.image_portrait_enhancement:
     [Metrics.image_portrait_enhancement_metric],
     Tasks.video_summarization: [Metrics.video_summarization_metric],
+    Tasks.image_captioning: [Metrics.text_gen_metric],
+    Tasks.visual_question_answering: [Metrics.text_gen_metric],
 }
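Both new tasks now default to the generic text-generation metric, so evaluation works without naming a metric in the configuration. A minimal sketch of the resulting lookup (import paths follow this file and modelscope.metainfo; treat them as assumptions):

from modelscope.metainfo import Metrics
from modelscope.metrics.builder import task_default_metrics
from modelscope.utils.constant import Tasks

# Both finetuning tasks resolve to the text-generation metric.
assert task_default_metrics[Tasks.image_captioning] == [
    Metrics.text_gen_metric
]
assert task_default_metrics[Tasks.visual_question_answering] == [
    Metrics.text_gen_metric
]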




modelscope/models/multi_modal/mplug/modeling_mplug.py (+8, -102)

@@ -1969,71 +1969,6 @@ class MPlug(PreTrainedModel):
             [init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
         return torch.index_select(x, dim, order_index.to(x.device))
 
-    def rank_answer(self, question_states, question_atts, answer_ids,
-                    answer_atts, k):
-
-        num_ques = question_states.size(0)
-        start_ids = answer_ids[0, 0].repeat(num_ques, 1)  # bos token
-
-        start_output = self.text_decoder(
-            start_ids,
-            encoder_hidden_states=question_states,
-            encoder_attention_mask=question_atts,
-            return_dict=True,
-            reduction='none')
-        logits = start_output.logits[:, 0, :]  # first token's logit
-
-        # topk_probs: top-k probability
-        # topk_ids: [num_question, k]
-        answer_first_token = answer_ids[:, 1]
-        prob_first_token = F.softmax(
-            logits, dim=1).index_select(
-                dim=1, index=answer_first_token)
-        topk_probs, topk_ids = prob_first_token.topk(k, dim=1)
-
-        # answer input: [num_question*k, answer_len]
-        input_ids = []
-        input_atts = []
-        for b, topk_id in enumerate(topk_ids):
-            input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
-            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
-        input_ids = torch.cat(input_ids, dim=0)
-        input_atts = torch.cat(input_atts, dim=0)
-
-        targets_ids = input_ids.masked_fill(
-            input_ids == self.tokenizer.pad_token_id, -100)
-
-        # repeat encoder's output for top-k answers
-        question_states = self._tile(question_states, 0, k)
-        question_atts = self._tile(question_atts, 0, k)
-
-        output = self.text_decoder(
-            input_ids,
-            attention_mask=input_atts,
-            encoder_hidden_states=question_states,
-            encoder_attention_mask=question_atts,
-            labels=targets_ids,
-            return_dict=True,
-            reduction='none')
-
-        answer_loss = output.loss
-        answer_loss = answer_loss.view(input_ids.size(0), -1)
-
-        # topk_prob: first token probability
-        topk_probs = topk_probs.view(-1, 1)
-        log_probs = torch.cat([topk_probs.log(), -answer_loss], dim=1)
-
-        # re-calculate log probabilities for the answer sequences using chain rule
-        log_probs_sum = log_probs.sum(1)
-        log_probs_sum = log_probs_sum.view(num_ques, k)
-
-        topk_probs = F.softmax(log_probs_sum, dim=-1)
-        # get top-k after re-ranking
-        topk_probs, rerank_id = topk_probs.topk(k, dim=1)
-        topk_ids = torch.gather(topk_ids, 1, rerank_id)
-
-        return topk_ids, topk_probs
 
 
 class MPlugForVisualQuestionAnswering(MPlug):
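For reference, the deleted rank_answer scored each of the k candidate answers per question by the chain rule: log p(answer) is the log-probability of the first answer token plus the per-token log-probabilities the decoder returns as negated losses under reduction='none', and candidates were then re-ranked by a softmax over those sums. A toy sketch of that scoring with random tensors (shapes are illustrative only):

import torch
import torch.nn.functional as F

num_ques, k, ans_len = 2, 3, 7
# p(first answer token | question), one row per (question, candidate) pair
first_token_probs = torch.rand(num_ques * k, 1)
# decoder loss with reduction='none': per-token negative log-likelihoods
per_token_nll = torch.rand(num_ques * k, ans_len)

# chain rule: log p(answer) = log p(a_1) + sum_t log p(a_t | a_<t)
log_probs_sum = torch.cat(
    [first_token_probs.log(), -per_token_nll], dim=1).sum(1)
log_probs_sum = log_probs_sum.view(num_ques, k)

# normalize over candidates, then re-rank within each question
topk_probs, rerank_id = F.softmax(log_probs_sum, dim=-1).topk(k, dim=1)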

@@ -2111,6 +2046,8 @@ class MPlugForVisualQuestionAnswering(MPlug):
             merge_text_attention = torch.cat(
                 [image_atts, question.attention_mask], 1)
 
+            if k is None:
+                k = [1] * question_output.shape[0]
             question_states = []
             question_atts = []
             for b, n in enumerate(k):
@@ -2177,6 +2114,8 @@ class MPlugForVisualQuestionAnswering(MPlug):
                 return_dict=True,
                 reduction='none',
             )
+            if weights is None:
+                weights = 1
             loss = weights * answer_output.loss
             loss = loss.sum() / image.size(0)

@@ -2262,50 +2201,17 @@ class MPLUGForImageCaption(MPlug):
         if train:
             answer_targets = answer.input_ids.masked_fill(
                 answer.input_ids == self.tokenizer.pad_token_id, -100)
-            text_output = self.text_encoder(
-                question.input_ids,
-                attention_mask=question.attention_mask,
-                return_dict=True)
-            text_embeds = text_output.last_hidden_state
-            fusion_output = self.fusion_encoder(
-                encoder_embeds=text_embeds,
-                attention_mask=question.attention_mask,
-                encoder_hidden_states=image_embeds,
-                encoder_attention_mask=image_atts,
-                return_dict=False)
-
-            image_output, question_output = fusion_output
-
-            question_output = torch.cat([image_output, question_output], 1)
-            merge_text_attention = torch.cat(
-                [image_atts, question.attention_mask], 1)
-
             answer_output = self.text_decoder(
                 answer.input_ids,
                 attention_mask=answer.attention_mask,
-                encoder_hidden_states=question_output,
-                encoder_attention_mask=merge_text_attention,
+                encoder_hidden_states=image_embeds,
+                encoder_attention_mask=image_atts,
                 labels=answer_targets,
                 return_dict=True,
                 reduction='none')
             loss = answer_output.loss
 
             return loss
         else:
-            text_output = self.text_encoder(
-                question.input_ids,
-                attention_mask=question.attention_mask,
-                return_dict=True)
-            text_embeds = text_output.last_hidden_state
-            fusion_output = self.fusion_encoder(
-                encoder_embeds=text_embeds,
-                attention_mask=question.attention_mask,
-                encoder_hidden_states=image_embeds,
-                encoder_attention_mask=image_atts,
-                return_dict=False)
-            image_output, question_output = fusion_output
-            question_output = torch.cat([image_output, question_output], 1)
-            merge_text_attention = torch.cat(
-                [image_atts, question.attention_mask], 1)
-            topk_ids, topk_probs = self.generation(question_output,
-                                                   merge_text_attention)
+            topk_ids, topk_probs = self.generation(image_embeds, image_atts)
             return topk_ids, topk_probs

modelscope/models/multi_modal/mplug_for_all_tasks.py (+37, -13)

@@ -1,4 +1,4 @@
-from typing import Dict
+from typing import Dict, List
 
 from modelscope.metainfo import Models
 from modelscope.models import TorchModel
@@ -25,12 +25,6 @@ class MPlugForAllTasks(TorchModel):
         self.model = MPlug.from_pretrained(model_dir)
         self.tokenizer = self.model.tokenizer
 
-    def train(self):
-        return self.model.train()
-
-    def eval(self):
-        return self.model.eval()
-
     def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
         """return the result by the model

@@ -45,13 +39,43 @@ class MPlugForAllTasks(TorchModel):
             }
         """
 
-        topk_ids, _ = self.model(**input)
         replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
                                ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
                                ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))
 
-        pred_string = self.tokenizer.decode(topk_ids[0][0])
-        for _old, _new in replace_tokens_bert:
-            pred_string = pred_string.replace(_old, _new)
-        pred_string = pred_string.strip()
-        return pred_string
+        if not self.training and 'answer_input_ids' not in input:
+            topk_ids, _ = self.model(**input)
+            pred_string: str = self.tokenizer.decode(topk_ids[0][0])
+            for _old, _new in replace_tokens_bert:
+                pred_string = pred_string.replace(_old, _new)
+            pred_string = pred_string.strip()
+            return pred_string
+        else:
+            import addict
+            question = addict.Dict(
+                input_ids=input['question_input_ids'],
+                attention_mask=input['question_attention_mask'])
+            answer = addict.Dict(
+                input_ids=input['answer_input_ids'],
+                attention_mask=input['answer_attention_mask'])
+            output = self.model(
+                input['image'], question, answer, train=self.training)
+            if self.training:
+                return {'loss': output}
+            topk_ids, _ = output
+            preds: List[str] = [
+                self.tokenizer.decode(batch[0]) for batch in topk_ids
+            ]
+            for i in range(len(preds)):
+                for _old, _new in replace_tokens_bert:
+                    preds[i] = preds[i].replace(_old, _new)
+                preds[i] = preds[i].strip()
+            tgts: List[str] = [
+                self.tokenizer.decode(batch)
+                for batch in input['answer_input_ids'].cpu().numpy().tolist()
+            ]
+            for i in range(len(tgts)):
+                for _old, _new in replace_tokens_bert:
+                    tgts[i] = tgts[i].replace(_old, _new)
+                tgts[i] = tgts[i].strip()
+            return {'preds': preds, 'tgts': tgts}
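forward now branches on self.training and the presence of answer_input_ids: the old single-sample inference path is kept, while finetuning batches go through the underlying mPLUG model. A sketch of what callers see, with hypothetical shapes, image size and vocabulary size (the key names match the preprocessor output below):

import torch

B, L = 4, 25  # illustrative batch size and tokenizer_max_length
batch = {
    'image': torch.randn(B, 3, 224, 224),  # assumed input resolution
    'question_input_ids': torch.randint(0, 30000, (B, L)),
    'question_attention_mask': torch.ones(B, L, dtype=torch.long),
    'answer_input_ids': torch.randint(0, 30000, (B, L)),
    'answer_attention_mask': torch.ones(B, L, dtype=torch.long),
}

# model = MPlugForAllTasks.from_pretrained(...)  # as in the tests below
# model.train(); model(batch)  # -> {'loss': tensor}, the backprop target
# model.eval();  model(batch)  # -> {'preds': [...], 'tgts': [...]} for metrics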

modelscope/models/nlp/gpt3/gpt3_for_text_generation.py (+2, -1)

@@ -60,5 +60,6 @@ class GPT3ForTextGeneration(TorchModel):
         sample_output = self.model.generate(**gen_params)
         return {
             OutputKeys.TEXT:
-            self.tokenizer.decode(sample_output[0], skip_special_tokens=True)
+            self.tokenizer.decode(sample_output[0],
+                                  skip_special_tokens=True).replace(' ', '')
         }
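The appended .replace(' ', '') strips the spaces that a character-level, BERT-style Chinese tokenizer inserts when decoding; that this checkpoint targets Chinese is an assumption, not something the diff states. A toy illustration with a made-up decoded string:

decoded = '今 天 天 气 不 错'  # hypothetical tokenizer.decode output
print(decoded.replace(' ', ''))  # -> 今天天气不错

Note that this would also delete legitimate spaces in English output, so it only makes sense for languages written without spaces.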

modelscope/models/nlp/palm_v2/palm_for_text_generation.py (+7, -8)

@@ -29,20 +29,19 @@ class PalmForTextGeneration(TorchModel):
         self.generator = Translator(self.model)
 
     def _evaluate_postprocess(self, ids_list: List[List[int]]) -> List[str]:
-        replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
-                               ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
-                               ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))
+        replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
+                               ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
+                               ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''),
+                               (' ', ''))
         replace_tokens_roberta = ((r' +', ' '), ('<mask>', '. '),
                                   ('<pad>', ''), ('<s>', ''), ('</s>', ''),
                                   ('<unk>', ' '), ('<q>', '. '))
 
+        replace_tokens = replace_tokens_roberta \
+            if self.model.config.encoder == 'roberta' else replace_tokens_bert
         strings = [self.tokenizer.decode(pred_ids) for pred_ids in ids_list]
-        for _old, _new in replace_tokens_bert:
+        for _old, _new in replace_tokens:
             strings = [s.replace(_old, _new) for s in strings]
-        for _old, _new in replace_tokens_roberta:
-            strings = [s.replace(_old, _new) for s in strings]
-        for s in strings:
-            s.strip()
+        strings = [s.strip() for s in strings]
         return strings
 
     def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:


modelscope/preprocessors/multi_modal.py (+42, -24)

@@ -9,7 +9,7 @@ from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.metainfo import Preprocessors
 from modelscope.pipelines.base import Input
 from modelscope.utils.config import Config
-from modelscope.utils.constant import Fields, ModelFile, Tasks
+from modelscope.utils.constant import Fields, ModeKeys, ModelFile, Tasks
 from .base import Preprocessor
 from .builder import PREPROCESSORS
 from .ofa import *  # noqa
@@ -91,9 +91,16 @@ class OfaPreprocessor(Preprocessor):
     Fields.multi_modal, module_name=Preprocessors.mplug_tasks_preprocessor)
 class MPlugPreprocessor(Preprocessor):
 
-    def __init__(self, model_dir: str, *args, **kwargs):
+    def __init__(self,
+                 model_dir: str,
+                 mode: str = ModeKeys.INFERENCE,
+                 tokenizer_max_length: int = 25,
+                 *args,
+                 **kwargs):
         super().__init__(*args, **kwargs)
         self.model_dir = model_dir
+        self.mode = mode
+        self.tokenizer_max_length = tokenizer_max_length
 
         self._tokenizer = None
         self._patch_resize_transform = None
@@ -128,40 +135,51 @@ class MPlugPreprocessor(Preprocessor):
 
     def __call__(self, *args, **kwargs):
         call_mapping = {
-            Tasks.visual_question_answering: self.vqa_call,
-            Tasks.image_captioning: self.caption_call
+            Tasks.visual_question_answering: self.image_text_call,
+            Tasks.image_captioning: self.image_text_call,
         }
 
         self.cfg = Config.from_file(
             osp.join(self.model_dir, ModelFile.CONFIGURATION))
         return call_mapping[self.cfg.task](*args, **kwargs)
 
-    def vqa_call(self, data: Union[tuple, Dict[str, Any]]) -> Dict[str, Any]:
-        image: Image.Image = data[0] if isinstance(data,
-                                                   tuple) else data['image']
-        question: str = data[1] if isinstance(data,
-                                              tuple) else data['question']
-        image = image.convert('RGB')
-        image = self.patch_resize_transform(image)
-        image = torch.stack([image], dim=0)
-        question = self.tokenizer([question.lower()],
-                                  padding='longest',
-                                  return_tensors='pt')
-
-        return {'image': image, 'question': question, 'train': False}
-
-    def caption_call(
+    def image_text_call(
             self, data: Union[Image.Image, tuple,
                               Dict[str, Any]]) -> Dict[str, Any]:
-        if isinstance(data, Image.Image):
+        if isinstance(data, (Image.Image, str)):
             image = data
         elif isinstance(data, tuple):
             image = data[0]
         else:
             image = data['image']
+        if isinstance(image, str):
+            image = Image.open(image)
+        question = '' if self.cfg.task != Tasks.visual_question_answering \
+            else data[1 if isinstance(data, tuple) else 'question']
         image = image.convert('RGB')
         image = self.patch_resize_transform(image)
-        image = torch.stack([image], dim=0)
-        question = self.tokenizer('', return_tensors='pt')
-
-        return {'image': image, 'question': question, 'train': False}
+        question = self.tokenizer(
+            question.lower(),
+            padding='max_length',
+            truncation=True,
+            max_length=self.tokenizer_max_length,
+            return_tensors='pt')
+
+        if self.mode == ModeKeys.INFERENCE:
+            image = torch.stack([image], dim=0)
+            return {'image': image, 'question': question, 'train': False}
+        else:
+            answer = data['answer']
+            answer = self.tokenizer(
+                answer,
+                padding='max_length',
+                truncation=True,
+                max_length=self.tokenizer_max_length,
+                return_tensors='pt')
+            return {
+                'image': image,
+                'question_input_ids': question.input_ids.squeeze(),
+                'question_attention_mask': question.attention_mask.squeeze(),
+                'answer_input_ids': answer.input_ids.squeeze(),
+                'answer_attention_mask': answer.attention_mask.squeeze(),
+            }
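A sketch of driving the unified preprocessor in train mode (the model directory and image path are placeholders; in train mode the sample must also carry an 'answer'):

from modelscope.preprocessors.multi_modal import MPlugPreprocessor
from modelscope.utils.constant import ModeKeys

# model_dir must contain configuration.json; the path is a placeholder.
preprocessor = MPlugPreprocessor(
    '/path/to/mplug_model_dir', mode=ModeKeys.TRAIN)

sample = {
    'image': '/path/to/image.jpg',  # str paths are now accepted
    'question': 'what the picture describes?',
    'answer': 'a dog runs on the beach',
}
features = preprocessor(sample)
# -> 'image' tensor plus squeezed question_/answer_ input_ids and
#    attention_mask, each padded/truncated to tokenizer_max_length (25)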

tests/trainers/test_finetune_mplug.py (+128, -0)

@@ -0,0 +1,128 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest

from PIL import Image

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.models.multi_modal import MPlugForAllTasks
from modelscope.msdatasets import MsDataset
from modelscope.trainers import EpochBasedTrainer, build_trainer
from modelscope.utils.constant import ModelFile
from modelscope.utils.test_utils import test_level


class TestFinetuneMPlug(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

        datadict = MsDataset.load('coco_captions_small_slice')
        self.train_dataset = MsDataset(datadict['train'].to_hf_dataset().map(
            lambda _: {
                'question': 'what the picture describes?'
            }).rename_column('image:FILE',
                             'image').rename_column('answer:Value', 'answer'))
        self.test_dataset = MsDataset(datadict['test'].to_hf_dataset().map(
            lambda _: {
                'question': 'what the picture describes?'
            }).rename_column('image:FILE',
                             'image').rename_column('answer:Value', 'answer'))

    def tearDown(self):
        shutil.rmtree(self.tmp_dir)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer_with_caption(self):
        kwargs = dict(
            model='damo/mplug_image-captioning_coco_base_en',
            train_dataset=self.train_dataset,
            eval_dataset=self.test_dataset,
            work_dir=self.tmp_dir)

        trainer: EpochBasedTrainer = build_trainer(
            name=Trainers.nlp_base_trainer, default_args=kwargs)
        trainer.train()
        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
        for i in range(3):
            self.assertIn(f'epoch_{i+1}.pth', results_files)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_trainer_with_caption_with_model_and_args(self):
        tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(tmp_dir):
            os.makedirs(tmp_dir)

        cache_path = snapshot_download(
            'damo/mplug_image-captioning_coco_base_en')
        model = MPlugForAllTasks.from_pretrained(cache_path)
        kwargs = dict(
            cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
            model=model,
            train_dataset=self.train_dataset,
            eval_dataset=self.test_dataset,
            max_epochs=2,
            work_dir=self.tmp_dir)

        trainer: EpochBasedTrainer = build_trainer(
            name=Trainers.nlp_base_trainer, default_args=kwargs)
        trainer.train()
        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
        for i in range(2):
            self.assertIn(f'epoch_{i+1}.pth', results_files)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer_with_vqa(self):
        kwargs = dict(
            model='damo/mplug_visual-question-answering_coco_large_en',
            train_dataset=self.train_dataset,
            eval_dataset=self.test_dataset,
            work_dir=self.tmp_dir)

        trainer: EpochBasedTrainer = build_trainer(
            name=Trainers.nlp_base_trainer, default_args=kwargs)
        trainer.train()
        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
        for i in range(3):
            self.assertIn(f'epoch_{i+1}.pth', results_files)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_trainer_with_vqa_with_model_and_args(self):
        tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(tmp_dir):
            os.makedirs(tmp_dir)

        cache_path = snapshot_download(
            'damo/mplug_visual-question-answering_coco_large_en')
        model = MPlugForAllTasks.from_pretrained(cache_path)
        kwargs = dict(
            cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
            model=model,
            train_dataset=self.train_dataset,
            eval_dataset=self.test_dataset,
            max_epochs=2,
            work_dir=self.tmp_dir)

        trainer: EpochBasedTrainer = build_trainer(
            name=Trainers.nlp_base_trainer, default_args=kwargs)
        trainer.train()
        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
        for i in range(2):
            self.assertIn(f'epoch_{i+1}.pth', results_files)


if __name__ == '__main__':
    unittest.main()
