Add finetuning support for the mPLUG model on the image captioning and VQA tasks
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9858028 (target branch: master)
@@ -30,6 +30,8 @@ task_default_metrics = {
     Tasks.image_portrait_enhancement:
     [Metrics.image_portrait_enhancement_metric],
     Tasks.video_summarization: [Metrics.video_summarization_metric],
+    Tasks.image_captioning: [Metrics.text_gen_metric],
+    Tasks.visual_question_answering: [Metrics.text_gen_metric],
 }
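The two new entries above route both finetuning tasks to the text-generation metric. A minimal sketch of how such a task-to-default-metric table is typically consumed when building evaluation metrics; the helper name get_default_metrics and the literal task/metric strings are illustrative stand-ins, not part of this diff:

from typing import Dict, List

# Illustrative stand-ins for the Tasks/Metrics constants used above.
task_default_metrics: Dict[str, List[str]] = {
    'image-captioning': ['text-gen-metric'],
    'visual-question-answering': ['text-gen-metric'],
}

def get_default_metrics(task: str) -> List[str]:
    # Hypothetical helper: fall back to an empty list for unregistered tasks.
    return task_default_metrics.get(task, [])

print(get_default_metrics('image-captioning'))  # ['text-gen-metric']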
@@ -1969,71 +1969,6 @@ class MPlug(PreTrainedModel):
             [init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
         return torch.index_select(x, dim, order_index.to(x.device))
 
-    def rank_answer(self, question_states, question_atts, answer_ids,
-                    answer_atts, k):
-        num_ques = question_states.size(0)
-        start_ids = answer_ids[0, 0].repeat(num_ques, 1)  # bos token
-        start_output = self.text_decoder(
-            start_ids,
-            encoder_hidden_states=question_states,
-            encoder_attention_mask=question_atts,
-            return_dict=True,
-            reduction='none')
-        logits = start_output.logits[:, 0, :]  # first token's logit
-        # topk_probs: top-k probability
-        # topk_ids: [num_question, k]
-        answer_first_token = answer_ids[:, 1]
-        prob_first_token = F.softmax(
-            logits, dim=1).index_select(
-                dim=1, index=answer_first_token)
-        topk_probs, topk_ids = prob_first_token.topk(k, dim=1)
-        # answer input: [num_question*k, answer_len]
-        input_ids = []
-        input_atts = []
-        for b, topk_id in enumerate(topk_ids):
-            input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
-            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
-        input_ids = torch.cat(input_ids, dim=0)
-        input_atts = torch.cat(input_atts, dim=0)
-        targets_ids = input_ids.masked_fill(
-            input_ids == self.tokenizer.pad_token_id, -100)
-        # repeat encoder's output for top-k answers
-        question_states = self._tile(question_states, 0, k)
-        question_atts = self._tile(question_atts, 0, k)
-        output = self.text_decoder(
-            input_ids,
-            attention_mask=input_atts,
-            encoder_hidden_states=question_states,
-            encoder_attention_mask=question_atts,
-            labels=targets_ids,
-            return_dict=True,
-            reduction='none')
-        answer_loss = output.loss
-        answer_loss = answer_loss.view(input_ids.size(0), -1)
-        # topk_prob: first token probability
-        topk_probs = topk_probs.view(-1, 1)
-        log_probs = torch.cat([topk_probs.log(), -answer_loss], dim=1)
-        # re-calculate log probabilities for the answer sequences using chain rule
-        log_probs_sum = log_probs.sum(1)
-        log_probs_sum = log_probs_sum.view(num_ques, k)
-        topk_probs = F.softmax(log_probs_sum, dim=-1)
-        # get top-k after re-ranking
-        topk_probs, rerank_id = topk_probs.topk(k, dim=1)
-        topk_ids = torch.gather(topk_ids, 1, rerank_id)
-        return topk_ids, topk_probs
-
 
 class MPlugForVisualQuestionAnswering(MPlug):
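The deleted rank_answer re-ranked candidate answers by combining the first-token probability with the decoder's per-token losses for the remaining tokens, i.e. a chain-rule sum of log-probabilities over the whole answer. A minimal sketch of that re-ranking step with dummy tensors; shapes and values are illustrative only:

import torch
import torch.nn.functional as F

num_ques, k, ans_len = 2, 3, 5
topk_probs = torch.rand(num_ques * k, 1)             # P(first answer token)
answer_loss = torch.rand(num_ques * k, ans_len - 1)  # -log P of the later tokens

# Chain rule: sequence log-prob = log P(first token) + sum of log P(rest).
log_probs = torch.cat([topk_probs.log(), -answer_loss], dim=1)
log_probs_sum = log_probs.sum(1).view(num_ques, k)
rerank_probs = F.softmax(log_probs_sum, dim=-1)
rerank_probs, rerank_id = rerank_probs.topk(k, dim=1)
print(rerank_probs.shape, rerank_id.shape)  # torch.Size([2, 3]) torch.Size([2, 3])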
@@ -2111,6 +2046,8 @@ class MPlugForVisualQuestionAnswering(MPlug):
             merge_text_attention = torch.cat(
                 [image_atts, question.attention_mask], 1)
+            if k is None:
+                k = [1] * question_output.shape[0]
             question_states = []
             question_atts = []
             for b, n in enumerate(k):
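The new default above covers the finetuning case where no per-question answer count is supplied: every question is assumed to have exactly one answer, so the fused question states are repeated once each in the tiling loop that follows. A small sketch of the effect; tensor shapes are illustrative:

import torch

question_output = torch.rand(4, 10, 768)  # (batch, seq_len, hidden), illustrative
k = None
if k is None:
    k = [1] * question_output.shape[0]

question_states = []
for b, n in enumerate(k):
    question_states += [question_output[b]] * n
question_states = torch.stack(question_states, 0)
print(question_states.shape)  # torch.Size([4, 10, 768]) when every n == 1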
@@ -2177,6 +2114,8 @@ class MPlugForVisualQuestionAnswering(MPlug):
                 return_dict=True,
                 reduction='none',
             )
+            if weights is None:
+                weights = 1
             loss = weights * answer_output.loss
             loss = loss.sum() / image.size(0)
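With no per-answer weights (the single-answer finetuning case), the scalar fallback above leaves the weighted loss identical to a plain per-image average. A quick numeric check with made-up losses:

import torch

answer_loss = torch.tensor([0.5, 1.5, 1.0, 2.0])  # per-sample losses, illustrative
weights = None
if weights is None:
    weights = 1
loss = (weights * answer_loss).sum() / 4  # image.size(0) == 4 in this toy example
print(loss)  # tensor(1.2500)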
@@ -2262,50 +2201,17 @@ class MPLUGForImageCaption(MPlug):
         if train:
             answer_targets = answer.input_ids.masked_fill(
                 answer.input_ids == self.tokenizer.pad_token_id, -100)
-            text_output = self.text_encoder(
-                question.input_ids,
-                attention_mask=question.attention_mask,
-                return_dict=True)
-            text_embeds = text_output.last_hidden_state
-            fusion_output = self.fusion_encoder(
-                encoder_embeds=text_embeds,
-                attention_mask=question.attention_mask,
-                encoder_hidden_states=image_embeds,
-                encoder_attention_mask=image_atts,
-                return_dict=False)
-            image_output, question_output = fusion_output
-            question_output = torch.cat([image_output, question_output], 1)
-            merge_text_attention = torch.cat(
-                [image_atts, question.attention_mask], 1)
             answer_output = self.text_decoder(
                 answer.input_ids,
                 attention_mask=answer.attention_mask,
-                encoder_hidden_states=question_output,
-                encoder_attention_mask=merge_text_attention,
+                encoder_hidden_states=image_embeds,
+                encoder_attention_mask=image_atts,
                 labels=answer_targets,
                 return_dict=True,
                 reduction='none')
             loss = answer_output.loss
             return loss
         else:
-            text_output = self.text_encoder(
-                question.input_ids,
-                attention_mask=question.attention_mask,
-                return_dict=True)
-            text_embeds = text_output.last_hidden_state
-            fusion_output = self.fusion_encoder(
-                encoder_embeds=text_embeds,
-                attention_mask=question.attention_mask,
-                encoder_hidden_states=image_embeds,
-                encoder_attention_mask=image_atts,
-                return_dict=False)
-            image_output, question_output = fusion_output
-            question_output = torch.cat([image_output, question_output], 1)
-            merge_text_attention = torch.cat(
-                [image_atts, question.attention_mask], 1)
-            topk_ids, topk_probs = self.generation(question_output,
-                                                   merge_text_attention)
+            topk_ids, topk_probs = self.generation(image_embeds, image_atts)
             return topk_ids, topk_probs
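In the caption branch the empty question no longer takes a detour through the text and fusion encoders; the decoder, and generation likewise, now conditions on the image embeddings directly. The label masking kept in the training path follows the usual convention of ignoring pad positions in the cross-entropy loss. A minimal sketch with a made-up pad id and token ids:

import torch

pad_token_id = 0  # illustrative; the real value comes from self.tokenizer
answer_input_ids = torch.tensor([[101, 2023, 3746, 102, 0, 0]])
answer_targets = answer_input_ids.masked_fill(
    answer_input_ids == pad_token_id, -100)  # -100 is the default ignore_index
print(answer_targets)  # tensor([[ 101, 2023, 3746,  102, -100, -100]])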
@@ -1,4 +1,4 @@
-from typing import Dict
+from typing import Dict, List
 
 from modelscope.metainfo import Models
 from modelscope.models import TorchModel
@@ -25,12 +25,6 @@ class MPlugForAllTasks(TorchModel):
         self.model = MPlug.from_pretrained(model_dir)
         self.tokenizer = self.model.tokenizer
 
-    def train(self):
-        return self.model.train()
-
-    def eval(self):
-        return self.model.eval()
-
     def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
         """return the result by the model
@@ -45,13 +39,43 @@ class MPlugForAllTasks(TorchModel):
                 }
         """
-        topk_ids, _ = self.model(**input)
         replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
                                ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
                                ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))
-        pred_string = self.tokenizer.decode(topk_ids[0][0])
-        for _old, _new in replace_tokens_bert:
-            pred_string = pred_string.replace(_old, _new)
-        pred_string = pred_string.strip()
-        return pred_string
+
+        if not self.training and 'answer_input_ids' not in input:
+            topk_ids, _ = self.model(**input)
+            pred_string: str = self.tokenizer.decode(topk_ids[0][0])
+            for _old, _new in replace_tokens_bert:
+                pred_string = pred_string.replace(_old, _new)
+            pred_string = pred_string.strip()
+            return pred_string
+        else:
+            import addict
+            question = addict.Dict(
+                input_ids=input['question_input_ids'],
+                attention_mask=input['question_attention_mask'])
+            answer = addict.Dict(
+                input_ids=input['answer_input_ids'],
+                attention_mask=input['answer_attention_mask'])
+            output = self.model(
+                input['image'], question, answer, train=self.training)
+            if self.training:
+                return {'loss': output}
+            topk_ids, _ = output
+            preds: List[str] = [
+                self.tokenizer.decode(batch[0]) for batch in topk_ids
+            ]
+            for i in range(len(preds)):
+                for _old, _new in replace_tokens_bert:
+                    preds[i] = preds[i].replace(_old, _new)
+                preds[i] = preds[i].strip()
+            tgts: List[str] = [
+                self.tokenizer.decode(batch)
+                for batch in input['answer_input_ids'].cpu().numpy().tolist()
+            ]
+            for i in range(len(tgts)):
+                for _old, _new in replace_tokens_bert:
+                    tgts[i] = tgts[i].replace(_old, _new)
+                tgts[i] = tgts[i].strip()
+            return {'preds': preds, 'tgts': tgts}
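Both predictions and targets are decoded and then scrubbed of BERT special tokens with the replace table defined earlier. A standalone sketch of that cleanup; note the r' +' entry is applied with str.replace, i.e. as the literal two characters " +", not as a regex:

replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''), ('[unused1]', ''),
                       (r' +', ' '), ('[SEP]', ''), ('[unused2]', ''),
                       ('[CLS]', ''), ('[UNK]', ''))

def clean(decoded: str) -> str:
    # Hypothetical helper mirroring the loops above.
    for _old, _new in replace_tokens_bert:
        decoded = decoded.replace(_old, _new)
    return decoded.strip()

print(clean('[CLS] a dog runs on the grass [SEP] [PAD] [PAD]'))
# a dog runs on the grass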
@@ -60,5 +60,6 @@ class GPT3ForTextGeneration(TorchModel):
         sample_output = self.model.generate(**gen_params)
         return {
             OutputKeys.TEXT:
-            self.tokenizer.decode(sample_output[0], skip_special_tokens=True)
+            self.tokenizer.decode(sample_output[0],
+                                  skip_special_tokens=True).replace(' ', '')
         }
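The extra .replace(' ', '') presumably targets Chinese output, where the tokenizer leaves spaces between decoded sub-tokens. A one-line illustration; note that, applied unconditionally, it would also remove the spaces in English output:

decoded = '今 天 天 气 不 错'  # illustrative decoder output
print(decoded.replace(' ', ''))  # 今天天气不错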
@@ -29,20 +29,19 @@ class PalmForTextGeneration(TorchModel):
         self.generator = Translator(self.model)
 
     def _evaluate_postprocess(self, ids_list: List[List[int]]) -> List[str]:
-        replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
-                               ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
-                               ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))
+        replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''), ('[unused1]',
+                                                                  ''),
+                               (r' +', ' '), ('[SEP]', ''), ('[unused2]', ''),
+                               ('[CLS]', ''), ('[UNK]', ''), (' ', ''))
         replace_tokens_roberta = ((r' +', ' '), ('<mask>', '. '),
                                   ('<pad>', ''), ('<s>', ''), ('</s>', ''),
                                   ('<unk>', ' '), ('<q>', '. '))
+        replace_tokens = replace_tokens_roberta \
+            if self.model.config.encoder == 'roberta' else replace_tokens_bert
         strings = [self.tokenizer.decode(pred_ids) for pred_ids in ids_list]
-        for _old, _new in replace_tokens_bert:
+        for _old, _new in replace_tokens:
             strings = [s.replace(_old, _new) for s in strings]
-        for _old, _new in replace_tokens_roberta:
-            strings = [s.replace(_old, _new) for s in strings]
         for s in strings:
             s.strip()
         return strings
 
     def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
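The cleanup table is now chosen from the backbone type instead of applying both tables back to back, and the BERT table additionally drops all spaces. One thing worth flagging in the unchanged tail: `for s in strings: s.strip()` has no effect, since str.strip() returns a new string that is discarded. A small sketch of the selection plus an effective strip; the literal tables here are abbreviated for the example:

replace_tokens_bert = (('[PAD]', ''), ('[SEP]', ''), ('[CLS]', ''), (' ', ''))
replace_tokens_roberta = (('<pad>', ''), ('<s>', ''), ('</s>', ''), ('<q>', '. '))

encoder = 'roberta'  # illustrative; read from self.model.config.encoder in the diff
replace_tokens = replace_tokens_roberta \
    if encoder == 'roberta' else replace_tokens_bert

strings = ['<s> a summary of the article </s><pad>']
for _old, _new in replace_tokens:
    strings = [s.replace(_old, _new) for s in strings]
strings = [s.strip() for s in strings]  # comprehension so the strip actually sticks
print(strings)  # ['a summary of the article']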
@@ -9,7 +9,7 @@ from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.metainfo import Preprocessors
 from modelscope.pipelines.base import Input
 from modelscope.utils.config import Config
-from modelscope.utils.constant import Fields, ModelFile, Tasks
+from modelscope.utils.constant import Fields, ModeKeys, ModelFile, Tasks
 from .base import Preprocessor
 from .builder import PREPROCESSORS
 from .ofa import *  # noqa
@@ -91,9 +91,16 @@ class OfaPreprocessor(Preprocessor):
     Fields.multi_modal, module_name=Preprocessors.mplug_tasks_preprocessor)
 class MPlugPreprocessor(Preprocessor):
 
-    def __init__(self, model_dir: str, *args, **kwargs):
+    def __init__(self,
+                 model_dir: str,
+                 mode: str = ModeKeys.INFERENCE,
+                 tokenizer_max_length: int = 25,
+                 *args,
+                 **kwargs):
         super().__init__(*args, **kwargs)
         self.model_dir = model_dir
+        self.mode = mode
+        self.tokenizer_max_length = tokenizer_max_length
         self._tokenizer = None
         self._patch_resize_transform = None
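The constructor now carries a mode flag (defaulting to inference) and a tokenizer length cap, so the same preprocessor can serve both pipelines and the trainer. A hedged usage sketch; the import path for MPlugPreprocessor and the ModeKeys.TRAIN constant are assumed from the file's location and the imports above, and '/path/to/mplug' is a placeholder:

from modelscope.preprocessors.multi_modal import MPlugPreprocessor  # assumed path
from modelscope.utils.constant import ModeKeys

infer_pp = MPlugPreprocessor(model_dir='/path/to/mplug')  # ModeKeys.INFERENCE by default
train_pp = MPlugPreprocessor(
    model_dir='/path/to/mplug', mode=ModeKeys.TRAIN, tokenizer_max_length=25)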
@@ -128,40 +135,51 @@ class MPlugPreprocessor(Preprocessor):
     def __call__(self, *args, **kwargs):
         call_mapping = {
-            Tasks.visual_question_answering: self.vqa_call,
-            Tasks.image_captioning: self.caption_call
+            Tasks.visual_question_answering: self.image_text_call,
+            Tasks.image_captioning: self.image_text_call,
         }
 
         self.cfg = Config.from_file(
             osp.join(self.model_dir, ModelFile.CONFIGURATION))
         return call_mapping[self.cfg.task](*args, **kwargs)
 
-    def vqa_call(self, data: Union[tuple, Dict[str, Any]]) -> Dict[str, Any]:
-        image: Image.Image = data[0] if isinstance(data,
-                                                   tuple) else data['image']
-        question: str = data[1] if isinstance(data,
-                                              tuple) else data['question']
-        image = image.convert('RGB')
-        image = self.patch_resize_transform(image)
-        image = torch.stack([image], dim=0)
-        question = self.tokenizer([question.lower()],
-                                  padding='longest',
-                                  return_tensors='pt')
-        return {'image': image, 'question': question, 'train': False}
-
-    def caption_call(
+    def image_text_call(
             self, data: Union[Image.Image, tuple,
                               Dict[str, Any]]) -> Dict[str, Any]:
-        if isinstance(data, Image.Image):
+        if isinstance(data, (Image.Image, str)):
             image = data
         elif isinstance(data, tuple):
             image = data[0]
         else:
            image = data['image']
         if isinstance(image, str):
             image = Image.open(image)
+        question = '' if self.cfg.task != Tasks.visual_question_answering \
+            else data[1 if isinstance(data, tuple) else 'question']
         image = image.convert('RGB')
         image = self.patch_resize_transform(image)
-        image = torch.stack([image], dim=0)
-        question = self.tokenizer('', return_tensors='pt')
-        return {'image': image, 'question': question, 'train': False}
+        question = self.tokenizer(
+            question.lower(),
+            padding='max_length',
+            truncation=True,
+            max_length=self.tokenizer_max_length,
+            return_tensors='pt')
+
+        if self.mode == ModeKeys.INFERENCE:
+            image = torch.stack([image], dim=0)
+            return {'image': image, 'question': question, 'train': False}
+        else:
+            answer = data['answer']
+            answer = self.tokenizer(
+                answer,
+                padding='max_length',
+                truncation=True,
+                max_length=self.tokenizer_max_length,
+                return_tensors='pt')
+            return {
+                'image': image,
+                'question_input_ids': question.input_ids.squeeze(),
+                'question_attention_mask': question.attention_mask.squeeze(),
+                'answer_input_ids': answer.input_ids.squeeze(),
+                'answer_attention_mask': answer.attention_mask.squeeze(),
+            }
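In training mode the tokenizer output is squeezed so that each sample is returned without a leading batch dimension, leaving batching to the trainer's collator; the keys match what MPlugForAllTasks.forward reads ('question_input_ids', 'answer_input_ids', and their attention masks). A tiny illustration of the squeeze, with a dummy tensor standing in for the tokenizer output for a single string:

import torch

question_input_ids = torch.zeros(1, 25, dtype=torch.long)  # (1, max_length) for one string
print(question_input_ids.squeeze().shape)  # torch.Size([25])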
@@ -0,0 +1,128 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+import tempfile
+import unittest
+
+from PIL import Image
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.metainfo import Trainers
+from modelscope.models.multi_modal import MPlugForAllTasks
+from modelscope.msdatasets import MsDataset
+from modelscope.trainers import EpochBasedTrainer, build_trainer
+from modelscope.utils.constant import ModelFile
+from modelscope.utils.test_utils import test_level
+
+
+class TestFinetuneMPlug(unittest.TestCase):
+
+    def setUp(self):
+        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+        self.tmp_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(self.tmp_dir):
+            os.makedirs(self.tmp_dir)
+
+        datadict = MsDataset.load('coco_captions_small_slice')
+        self.train_dataset = MsDataset(datadict['train'].to_hf_dataset().map(
+            lambda _: {
+                'question': 'what the picture describes?'
+            }).rename_column('image:FILE',
+                             'image').rename_column('answer:Value', 'answer'))
+        self.test_dataset = MsDataset(datadict['test'].to_hf_dataset().map(
+            lambda _: {
+                'question': 'what the picture describes?'
+            }).rename_column('image:FILE',
+                             'image').rename_column('answer:Value', 'answer'))
+
+    def tearDown(self):
+        shutil.rmtree(self.tmp_dir)
+        super().tearDown()
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_trainer_with_caption(self):
+        kwargs = dict(
+            model='damo/mplug_image-captioning_coco_base_en',
+            train_dataset=self.train_dataset,
+            eval_dataset=self.test_dataset,
+            work_dir=self.tmp_dir)
+
+        trainer: EpochBasedTrainer = build_trainer(
+            name=Trainers.nlp_base_trainer, default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(3):
+            self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_trainer_with_caption_with_model_and_args(self):
+        tmp_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(tmp_dir):
+            os.makedirs(tmp_dir)
+
+        cache_path = snapshot_download(
+            'damo/mplug_image-captioning_coco_base_en')
+        model = MPlugForAllTasks.from_pretrained(cache_path)
+        kwargs = dict(
+            cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
+            model=model,
+            train_dataset=self.train_dataset,
+            eval_dataset=self.test_dataset,
+            max_epochs=2,
+            work_dir=self.tmp_dir)
+
+        trainer: EpochBasedTrainer = build_trainer(
+            name=Trainers.nlp_base_trainer, default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(2):
+            self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_trainer_with_vqa(self):
+        kwargs = dict(
+            model='damo/mplug_visual-question-answering_coco_large_en',
+            train_dataset=self.train_dataset,
+            eval_dataset=self.test_dataset,
+            work_dir=self.tmp_dir)
+
+        trainer: EpochBasedTrainer = build_trainer(
+            name=Trainers.nlp_base_trainer, default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(3):
+            self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_trainer_with_vqa_with_model_and_args(self):
+        tmp_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(tmp_dir):
+            os.makedirs(tmp_dir)
+
+        cache_path = snapshot_download(
+            'damo/mplug_visual-question-answering_coco_large_en')
+        model = MPlugForAllTasks.from_pretrained(cache_path)
+        kwargs = dict(
+            cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
+            model=model,
+            train_dataset=self.train_dataset,
+            eval_dataset=self.test_dataset,
+            max_epochs=2,
+            work_dir=self.tmp_dir)
+
+        trainer: EpochBasedTrainer = build_trainer(
+            name=Trainers.nlp_base_trainer, default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(2):
+            self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+
+if __name__ == '__main__':
+    unittest.main()
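The suite maps a constant question onto the caption slices and renames the columns to the keys the preprocessor expects, so the same data drives both the caption and VQA trainers. To run a single case directly (test level permitting), something like the following works; the module path below is assumed from the usual tests/ layout and may differ:

import unittest

# Assumed module path; adjust to where this file actually lives in the repo.
suite = unittest.TestLoader().loadTestsFromName(
    'tests.trainers.test_finetune_mplug.TestFinetuneMPlug.test_trainer_with_caption')
unittest.TextTestRunner(verbosity=2).run(suite)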