Add image-captioning task for the MPLUG model. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9745826 (master)
@@ -218,7 +218,7 @@ class Preprocessors(object): | |||
# multi-modal preprocessor | |||
ofa_tasks_preprocessor = 'ofa-tasks-preprocessor' | |||
mplug_visual_question_answering = 'mplug-visual-question-answering' | |||
mplug_tasks_preprocessor = 'mplug-tasks-preprocessor' | |||
class Metrics(object): | |||
@@ -9,8 +9,7 @@ if TYPE_CHECKING: | |||
from .gemm import GEMMForMultiModalEmbedding | |||
from .diffusion import DiffusionForTextToImageSynthesis | |||
from .mmr import VideoCLIPForMultiModalEmbedding | |||
from .mplug_for_visual_question_answering import \ | |||
MPlugForVisualQuestionAnswering | |||
from .mplug_for_all_tasks import MPlugForAllTasks | |||
from .ofa_for_all_tasks import OfaForAllTasks | |||
from .ofa_for_text_to_image_synthesis_model import \ | |||
OfaForTextToImageSynthesis | |||
@@ -21,8 +20,7 @@ else: | |||
'diffusion': ['DiffusionForTextToImageSynthesis'], | |||
'gemm': ['GEMMForMultiModalEmbedding'], | |||
'mmr': ['VideoCLIPForMultiModalEmbedding'], | |||
'mplug_for_visual_question_answering': | |||
['MPlugForVisualQuestionAnswering'], | |||
'mplug_for_all_tasks': ['MPlugForAllTasks'], | |||
'ofa_for_all_tasks': ['OfaForAllTasks'], | |||
'ofa_for_text_to_image_synthesis_model': | |||
['OfaForTextToImageSynthesis'] | |||
@@ -14,4 +14,4 @@ | |||
# limitations under the License. | |||
from .configuration_mplug import MPlugConfig | |||
from .modeling_mplug import CONFIG_NAME, MPlugForVisualQuestionAnswering | |||
from .modeling_mplug import CONFIG_NAME, MPlug |
@@ -15,14 +15,14 @@ | |||
# limitations under the License. | |||
""" MPLUG model configuration """ | |||
import os | |||
from collections import OrderedDict | |||
from typing import Any, Dict, Mapping, Union | |||
from typing import Any, Dict, Union | |||
import yaml | |||
from transformers import PretrainedConfig | |||
from transformers.onnx import OnnxConfig | |||
from transformers.utils import logging | |||
from modelscope.utils.constant import Tasks | |||
logger = logging.get_logger(__name__) | |||
@@ -32,6 +32,7 @@ class MPlugConfig(PretrainedConfig): | |||
def __init__( | |||
self, | |||
task=Tasks.visual_question_answering, | |||
bert_config='config_bert.json', | |||
image_res=504, | |||
batch_size_train=128, | |||
@@ -64,7 +65,9 @@ class MPlugConfig(PretrainedConfig): | |||
clip_transformer_heads=12, | |||
clip_transformer_layers=12, | |||
**kwargs): | |||
super().__init__(**kwargs) | |||
self.task = task | |||
self.bert_config = bert_config | |||
self.image_res = image_res | |||
self.batch_size_train = batch_size_train | |||
@@ -103,23 +106,3 @@ class MPlugConfig(PretrainedConfig): | |||
with open(yaml_file, 'r') as reader: | |||
config_dict = yaml.load(reader, Loader=yaml.Loader) | |||
return cls(**config_dict) | |||
class MPlugOnnxConfig(OnnxConfig): | |||
@property | |||
def inputs(self) -> Mapping[str, Mapping[int, str]]: | |||
return OrderedDict([ | |||
('input_ids', { | |||
0: 'batch', | |||
1: 'sequence' | |||
}), | |||
('attention_mask', { | |||
0: 'batch', | |||
1: 'sequence' | |||
}), | |||
('token_type_ids', { | |||
0: 'batch', | |||
1: 'sequence' | |||
}), | |||
]) |
@@ -1725,7 +1725,116 @@ class BertLMHeadModel(BertPreTrainedModel): | |||
return reordered_past | |||
class MPlugForVisualQuestionAnswering(PreTrainedModel): | |||
class BertPrefixModel(BertPreTrainedModel): | |||
_keys_to_ignore_on_load_unexpected = [r'pooler'] | |||
_keys_to_ignore_on_load_missing = [ | |||
r'position_ids', r'predictions.decoder.bias' | |||
] | |||
def __init__(self, config): | |||
super().__init__(config) | |||
self.bert = BertModel(config, add_pooling_layer=False) | |||
self.cls = BertOnlyMLMHead(config) | |||
self.init_weights() | |||
def get_output_embeddings(self): | |||
return self.cls.predictions.decoder | |||
def set_output_embeddings(self, new_embeddings): | |||
self.cls.predictions.decoder = new_embeddings | |||
@add_start_docstrings_to_model_forward( | |||
BERT_INPUTS_DOCSTRING.format('batch_size, sequence_length')) | |||
@add_code_sample_docstrings( | |||
processor_class=_TOKENIZER_FOR_DOC, | |||
checkpoint='bert-base-uncased', | |||
output_type=CausalLMOutputWithCrossAttentions, | |||
config_class=_CONFIG_FOR_DOC, | |||
) | |||
def forward( | |||
self, | |||
input_ids=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
inputs_embeds=None, | |||
encoder_hidden_states=None, | |||
encoder_attention_mask=None, | |||
labels=None, | |||
past_key_values=None, | |||
use_cache=None, | |||
output_attentions=None, | |||
output_hidden_states=None, | |||
return_dict=None, | |||
is_decoder=True, | |||
reduction='mean', | |||
soft_labels=None, | |||
alpha=0, | |||
return_logits=False, | |||
): | |||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |||
if labels is not None: | |||
use_cache = False | |||
outputs = self.bert( | |||
input_ids, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
position_ids=position_ids, | |||
head_mask=head_mask, | |||
inputs_embeds=inputs_embeds, | |||
encoder_hidden_states=encoder_hidden_states, | |||
encoder_attention_mask=encoder_attention_mask, | |||
past_key_values=past_key_values, | |||
use_cache=use_cache, | |||
output_attentions=output_attentions, | |||
output_hidden_states=output_hidden_states, | |||
return_dict=return_dict, | |||
is_decoder=is_decoder, | |||
) | |||
sequence_output = outputs[0] | |||
prediction_scores = self.cls(sequence_output) | |||
if return_logits: | |||
return prediction_scores[:, :-1, :].contiguous() | |||
lm_loss = None | |||
if labels is not None: | |||
# we are doing next-token prediction; shift prediction scores and input ids by one | |||
shifted_prediction_scores = prediction_scores[:, : | |||
-1, :].contiguous() | |||
labels = labels[:, 1:].contiguous() | |||
loss_fct = CrossEntropyLoss() | |||
lm_loss = loss_fct( | |||
shifted_prediction_scores.view(-1, self.config.vocab_size), | |||
labels.view(-1)) | |||
if soft_labels is not None: | |||
loss_distill = -torch.sum( | |||
F.log_softmax(shifted_prediction_scores, dim=1) * soft_labels, | |||
dim=-1) | |||
loss_distill = loss_distill[labels != -100].mean() | |||
lm_loss = (1 - alpha) * lm_loss + alpha * loss_distill | |||
if not return_dict: | |||
output = (prediction_scores, ) + outputs[2:] | |||
return ((lm_loss, ) + output) if lm_loss is not None else output | |||
return CausalLMOutputWithCrossAttentions( | |||
loss=lm_loss, | |||
logits=prediction_scores, | |||
past_key_values=outputs.past_key_values, | |||
hidden_states=outputs.hidden_states, | |||
attentions=outputs.attentions, | |||
cross_attentions=outputs.cross_attentions, | |||
) | |||
class MPlug(PreTrainedModel): | |||
config_class = MPlugConfig | |||
def __init__(self, config): | |||
@@ -1739,16 +1848,19 @@ class MPlugForVisualQuestionAnswering(PreTrainedModel): | |||
self.config_encoder, add_pooling_layer=False) | |||
self.fusion_encoder = FusionModel( | |||
self.config_fusion, add_pooling_layer=False) | |||
self.text_decoder = BertLMHeadModel(self.config_decoder) | |||
self.init_distill(config) | |||
self.beam_generator = TextGenerator(config, self.text_decoder) | |||
@classmethod | |||
def from_pretrained(cls, model_dir, load_checkpoint=True): | |||
config = MPlugConfig.from_yaml_file( | |||
from modelscope.utils.constant import Tasks | |||
task_mapping = { | |||
Tasks.visual_question_answering: MPlugForVisualQuestionAnswering, | |||
Tasks.image_captioning: MPLUGForImageCaption | |||
} | |||
config = cls.config_class.from_yaml_file( | |||
os.path.join(model_dir, CONFIG_NAME)) | |||
config.model_dir = model_dir | |||
model = cls(config) | |||
model = task_mapping[config.task](config) | |||
if load_checkpoint: | |||
checkpoint_path = os.path.join(model_dir, | |||
ModelFile.TORCH_MODEL_BIN_FILE) | |||
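Reviewer note: with this change `MPlug.from_pretrained` dispatches on `config.task` and returns a task-specific subclass instead of always building the VQA model. A minimal sketch of the intended call path, assuming a local checkpoint directory containing the YAML config and the torch weights (the path below is illustrative, not a real checkpoint):

```python
# Sketch only: shows the new task-based dispatch; `model_dir` is a hypothetical
# local checkpoint directory with the YAML config and torch weights in place.
from modelscope.models.multi_modal.mplug import MPlug

model_dir = '/path/to/mplug_checkpoint'  # assumption, not a real path
model = MPlug.from_pretrained(model_dir)
# config.task selects the concrete class:
#   Tasks.visual_question_answering -> MPlugForVisualQuestionAnswering
#   Tasks.image_captioning          -> MPLUGForImageCaption
print(type(model).__name__)
```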
@@ -1803,6 +1915,161 @@ class MPlugForVisualQuestionAnswering(PreTrainedModel): | |||
clip_model.visual.positional_embedding = pos_embed | |||
return clip_model | |||
def forward(self, *args, **kwargs): | |||
raise NotImplementedError | |||
def module_setting(self, config): | |||
bert_config_path = os.path.join(config.model_dir, config.bert_config) | |||
self.config_encoder = BertConfig.from_json_file(bert_config_path) | |||
self.config_encoder.num_hidden_layers = self.config_encoder.text_encoder_layers | |||
self.config_fusion = BertConfig.from_json_file(bert_config_path) | |||
self.config_decoder = BertConfig.from_json_file(bert_config_path) | |||
self.config_decoder.add_cross_attention = True | |||
self.config_decoder.num_hidden_layers = self.config_decoder.text_decode_layers | |||
self.large = False | |||
if self.config_encoder.hidden_size != config.vision_width: | |||
self.visn_fc = nn.Linear(config.vision_width, | |||
self.config_encoder.hidden_size) | |||
self.visn_layer_norm = nn.LayerNorm( | |||
self.config_encoder.hidden_size, eps=1e-12) | |||
self.dropout = nn.Dropout(self.config_encoder.hidden_dropout_prob) | |||
self.large = True | |||
@torch.no_grad() | |||
def copy_params(self): | |||
for model_pair in self.model_pairs: | |||
for param, param_m in zip(model_pair[0].parameters(), | |||
model_pair[1].parameters()): | |||
param_m.data.copy_(param.data) # initialize | |||
param_m.requires_grad = False # not update by gradient | |||
@torch.no_grad() | |||
def _momentum_update(self): | |||
for model_pair in self.model_pairs: | |||
for param, param_m in zip(model_pair[0].parameters(), | |||
model_pair[1].parameters()): | |||
param_m.data = param_m.data * self.momentum + param.data * ( | |||
1. - self.momentum) | |||
def generation(self, question_states, question_atts, out_size=1): | |||
encoder_inputs = [question_states, question_atts] | |||
topk_ids, topk_scores = self.beam_generator.translate_batch( | |||
encoder_inputs, out_size=out_size) | |||
return topk_ids, topk_scores | |||
@staticmethod | |||
def _tile(x, dim, n_tile): | |||
import numpy as np | |||
init_dim = x.size(dim) | |||
repeat_idx = [1] * x.dim() | |||
repeat_idx[dim] = n_tile | |||
x = x.repeat(*(repeat_idx)) | |||
order_index = torch.LongTensor( | |||
np.concatenate( | |||
[init_dim * np.arange(n_tile) + i for i in range(init_dim)])) | |||
return torch.index_select(x, dim, order_index.to(x.device)) | |||
def rank_answer(self, question_states, question_atts, answer_ids, | |||
answer_atts, k): | |||
num_ques = question_states.size(0) | |||
start_ids = answer_ids[0, 0].repeat(num_ques, 1) # bos token | |||
start_output = self.text_decoder( | |||
start_ids, | |||
encoder_hidden_states=question_states, | |||
encoder_attention_mask=question_atts, | |||
return_dict=True, | |||
reduction='none') | |||
logits = start_output.logits[:, 0, :] # first token's logit | |||
# topk_probs: top-k probability | |||
# topk_ids: [num_question, k] | |||
answer_first_token = answer_ids[:, 1] | |||
prob_first_token = F.softmax( | |||
logits, dim=1).index_select( | |||
dim=1, index=answer_first_token) | |||
topk_probs, topk_ids = prob_first_token.topk(k, dim=1) | |||
# answer input: [num_question*k, answer_len] | |||
input_ids = [] | |||
input_atts = [] | |||
for b, topk_id in enumerate(topk_ids): | |||
input_ids.append(answer_ids.index_select(dim=0, index=topk_id)) | |||
input_atts.append(answer_atts.index_select(dim=0, index=topk_id)) | |||
input_ids = torch.cat(input_ids, dim=0) | |||
input_atts = torch.cat(input_atts, dim=0) | |||
targets_ids = input_ids.masked_fill( | |||
input_ids == self.tokenizer.pad_token_id, -100) | |||
# repeat encoder's output for top-k answers | |||
question_states = self._tile(question_states, 0, k) | |||
question_atts = self._tile(question_atts, 0, k) | |||
output = self.text_decoder( | |||
input_ids, | |||
attention_mask=input_atts, | |||
encoder_hidden_states=question_states, | |||
encoder_attention_mask=question_atts, | |||
labels=targets_ids, | |||
return_dict=True, | |||
reduction='none') | |||
answer_loss = output.loss | |||
answer_loss = answer_loss.view(input_ids.size(0), -1) | |||
# topk_prob: first token probability | |||
topk_probs = topk_probs.view(-1, 1) | |||
log_probs = torch.cat([topk_probs.log(), -answer_loss], dim=1) | |||
# re-calculate log probabilities for the answer sequences using chain rule | |||
log_probs_sum = log_probs.sum(1) | |||
log_probs_sum = log_probs_sum.view(num_ques, k) | |||
topk_probs = F.softmax(log_probs_sum, dim=-1) | |||
# get top-k after re-ranking | |||
topk_probs, rerank_id = topk_probs.topk(k, dim=1) | |||
topk_ids = torch.gather(topk_ids, 1, rerank_id) | |||
return topk_ids, topk_probs | |||
class MPlugForVisualQuestionAnswering(MPlug): | |||
def __init__(self, config): | |||
super().__init__(config) | |||
self.text_decoder = BertLMHeadModel(self.config_decoder) | |||
self.beam_generator = TextGenerator(config, self.text_decoder) | |||
self.init_distill(config) | |||
def init_distill(self, config): | |||
self.distill = config.distill | |||
if self.distill: | |||
self.visual_encoder_m = self._initialize_clip(config) | |||
self.text_encoder_m = BertModel( | |||
self.config_encoder, add_pooling_layer=False) | |||
self.fusion_encoder_m = FusionModel( | |||
self.config_fusion, add_pooling_layer=False) | |||
self.text_decoder_m = BertLMHeadModel(self.config_decoder) | |||
self.model_pairs = [ | |||
[self.visual_encoder, self.visual_encoder_m], | |||
[self.text_encoder, self.text_encoder_m], | |||
[self.text_decoder, self.text_decoder_m], | |||
] | |||
if self.config_encoder.hidden_size != config.vision_width: | |||
self.visn_fc_m = nn.Linear(config.vision_width, | |||
self.config_encoder.hidden_size) | |||
self.visn_layer_norm_m = nn.LayerNorm( | |||
self.config_encoder.hidden_size, eps=1e-12) | |||
self.dropout_m = nn.Dropout( | |||
self.config_encoder.hidden_dropout_prob) | |||
self.model_pairs.extend( | |||
[[self.visn_fc, self.visn_fc_m], | |||
[self.visn_layer_norm, self.visn_layer_norm_m]]) | |||
self.copy_params() | |||
self.momentum = 0.995 | |||
def forward(self, | |||
image, | |||
question, | |||
@@ -1935,145 +2202,110 @@ class MPlugForVisualQuestionAnswering(PreTrainedModel): | |||
merge_text_attention) | |||
return topk_ids, topk_probs | |||
def module_setting(self, config): | |||
bert_config_path = os.path.join(config.model_dir, config.bert_config) | |||
self.config_encoder = BertConfig.from_json_file(bert_config_path) | |||
self.config_encoder.num_hidden_layers = self.config_encoder.text_encoder_layers | |||
self.config_fusion = BertConfig.from_json_file(bert_config_path) | |||
self.config_decoder = BertConfig.from_json_file(bert_config_path) | |||
self.config_decoder.add_cross_attention = True | |||
self.config_decoder.num_hidden_layers = self.config_decoder.text_decode_layers | |||
self.large = False | |||
if self.config_encoder.hidden_size != config.vision_width: | |||
self.visn_fc = nn.Linear(config.vision_width, | |||
self.config_encoder.hidden_size) | |||
self.visn_layer_norm = nn.LayerNorm( | |||
self.config_encoder.hidden_size, eps=1e-12) | |||
self.dropout = nn.Dropout(self.config_encoder.hidden_dropout_prob) | |||
self.large = True | |||
def init_distill(self, config): | |||
self.distill = config.distill | |||
if self.distill: | |||
self.visual_encoder_m = self._initialize_clip(config) | |||
self.text_encoder_m = BertModel( | |||
self.config_encoder, add_pooling_layer=False) | |||
self.fusion_encoder_m = FusionModel( | |||
self.config_fusion, add_pooling_layer=False) | |||
self.text_decoder_m = BertLMHeadModel(self.config_decoder) | |||
self.model_pairs = [ | |||
[self.visual_encoder, self.visual_encoder_m], | |||
[self.text_encoder, self.text_encoder_m], | |||
[self.text_decoder, self.text_decoder_m], | |||
] | |||
if self.config_encoder.hidden_size != config.vision_width: | |||
self.visn_fc_m = nn.Linear(config.vision_width, | |||
self.config_encoder.hidden_size) | |||
self.visn_layer_norm_m = nn.LayerNorm( | |||
self.config_encoder.hidden_size, eps=1e-12) | |||
self.dropout_m = nn.Dropout( | |||
self.config_encoder.hidden_dropout_prob) | |||
self.model_pairs.extend( | |||
[[self.visn_fc, self.visn_fc_m], | |||
[self.visn_layer_norm, self.visn_layer_norm_m]]) | |||
self.copy_params() | |||
self.momentum = 0.995 | |||
@torch.no_grad() | |||
def copy_params(self): | |||
for model_pair in self.model_pairs: | |||
for param, param_m in zip(model_pair[0].parameters(), | |||
model_pair[1].parameters()): | |||
param_m.data.copy_(param.data) # initialize | |||
param_m.requires_grad = False # not update by gradient | |||
@torch.no_grad() | |||
def _momentum_update(self): | |||
for model_pair in self.model_pairs: | |||
for param, param_m in zip(model_pair[0].parameters(), | |||
model_pair[1].parameters()): | |||
param_m.data = param_m.data * self.momentum + param.data * ( | |||
1. - self.momentum) | |||
def generation(self, question_states, question_atts): | |||
encoder_inputs = [question_states, question_atts] | |||
topk_ids, topk_scores = self.beam_generator.translate_batch( | |||
encoder_inputs) | |||
return topk_ids, topk_scores | |||
@staticmethod | |||
def _tile(x, dim, n_tile): | |||
import numpy as np | |||
init_dim = x.size(dim) | |||
repeat_idx = [1] * x.dim() | |||
repeat_idx[dim] = n_tile | |||
x = x.repeat(*(repeat_idx)) | |||
order_index = torch.LongTensor( | |||
np.concatenate( | |||
[init_dim * np.arange(n_tile) + i for i in range(init_dim)])) | |||
return torch.index_select(x, dim, order_index.to(x.device)) | |||
def rank_answer(self, question_states, question_atts, answer_ids, | |||
answer_atts, k): | |||
num_ques = question_states.size(0) | |||
start_ids = answer_ids[0, 0].repeat(num_ques, 1) # bos token | |||
start_output = self.text_decoder( | |||
start_ids, | |||
encoder_hidden_states=question_states, | |||
encoder_attention_mask=question_atts, | |||
return_dict=True, | |||
reduction='none') | |||
logits = start_output.logits[:, 0, :] # first token's logit | |||
class MPLUGForImageCaption(MPlug): | |||
# topk_probs: top-k probability | |||
# topk_ids: [num_question, k] | |||
answer_first_token = answer_ids[:, 1] | |||
prob_first_token = F.softmax( | |||
logits, dim=1).index_select( | |||
dim=1, index=answer_first_token) | |||
topk_probs, topk_ids = prob_first_token.topk(k, dim=1) | |||
# answer input: [num_question*k, answer_len] | |||
input_ids = [] | |||
input_atts = [] | |||
for b, topk_id in enumerate(topk_ids): | |||
input_ids.append(answer_ids.index_select(dim=0, index=topk_id)) | |||
input_atts.append(answer_atts.index_select(dim=0, index=topk_id)) | |||
input_ids = torch.cat(input_ids, dim=0) | |||
input_atts = torch.cat(input_atts, dim=0) | |||
targets_ids = input_ids.masked_fill( | |||
input_ids == self.tokenizer.pad_token_id, -100) | |||
# repeat encoder's output for top-k answers | |||
question_states = self._tile(question_states, 0, k) | |||
question_atts = self._tile(question_atts, 0, k) | |||
def __init__(self, config): | |||
super().__init__(config) | |||
self.text_decoder = BertPrefixModel(self.config_decoder) | |||
self.beam_generator = TextGenerator(config, self.text_decoder) | |||
output = self.text_decoder( | |||
input_ids, | |||
attention_mask=input_atts, | |||
encoder_hidden_states=question_states, | |||
encoder_attention_mask=question_atts, | |||
labels=targets_ids, | |||
return_dict=True, | |||
reduction='none') | |||
def beam_search(self, | |||
image, | |||
question, | |||
answer=None, | |||
train=True, | |||
out_size=5): | |||
image_embeds = self.visual_encoder.visual(image, skip_last_layer=True) | |||
if self.large: | |||
image_embeds = self.dropout( | |||
self.visn_layer_norm(self.visn_fc(image_embeds))) | |||
image_atts = torch.ones( | |||
image_embeds.size()[:-1], dtype=torch.long).to(image.device) | |||
text_output = self.text_encoder( | |||
question.input_ids, | |||
attention_mask=question.attention_mask, | |||
return_dict=True) | |||
text_embeds = text_output.last_hidden_state | |||
fusion_output = self.fusion_encoder( | |||
encoder_embeds=text_embeds, | |||
attention_mask=question.attention_mask, | |||
encoder_hidden_states=image_embeds, | |||
encoder_attention_mask=image_atts, | |||
return_dict=False) | |||
image_output, question_output = fusion_output | |||
question_output = torch.cat([image_output, question_output], 1) | |||
merge_text_attention = torch.cat([image_atts, question.attention_mask], | |||
1) | |||
topk_ids, topk_probs = self.generation( | |||
question_output, merge_text_attention, out_size=out_size) | |||
return topk_ids, topk_probs | |||
answer_loss = output.loss | |||
answer_loss = answer_loss.view(input_ids.size(0), -1) | |||
def forward(self, | |||
image, | |||
question, | |||
answer=None, | |||
train=True, | |||
out_size=5, | |||
scst=False): | |||
if (scst): | |||
return self.beam_search( | |||
image, question, answer, train=True, out_size=out_size) | |||
image = image.to(dtype=next(self.parameters()).dtype) | |||
image_embeds = self.visual_encoder.visual(image, skip_last_layer=True) | |||
if self.large: | |||
image_embeds = self.dropout( | |||
self.visn_layer_norm(self.visn_fc(image_embeds))) | |||
image_atts = torch.ones( | |||
image_embeds.size()[:-1], dtype=torch.long).to(image.device) | |||
# topk_prob: first token probability | |||
topk_probs = topk_probs.view(-1, 1) | |||
log_probs = torch.cat([topk_probs.log(), -answer_loss], dim=1) | |||
if train: | |||
answer_targets = answer.input_ids.masked_fill( | |||
answer.input_ids == self.tokenizer.pad_token_id, -100) | |||
text_output = self.text_encoder( | |||
question.input_ids, | |||
attention_mask=question.attention_mask, | |||
return_dict=True) | |||
text_embeds = text_output.last_hidden_state | |||
fusion_output = self.fusion_encoder( | |||
encoder_embeds=text_embeds, | |||
attention_mask=question.attention_mask, | |||
encoder_hidden_states=image_embeds, | |||
encoder_attention_mask=image_atts, | |||
return_dict=False) | |||
# re-calculate log probabilities for the answer sequences using chain rule | |||
log_probs_sum = log_probs.sum(1) | |||
log_probs_sum = log_probs_sum.view(num_ques, k) | |||
image_output, question_output = fusion_output | |||
topk_probs = F.softmax(log_probs_sum, dim=-1) | |||
# get top-k after re-ranking | |||
topk_probs, rerank_id = topk_probs.topk(k, dim=1) | |||
topk_ids = torch.gather(topk_ids, 1, rerank_id) | |||
question_output = torch.cat([image_output, question_output], 1) | |||
merge_text_attention = torch.cat( | |||
[image_atts, question.attention_mask], 1) | |||
return topk_ids, topk_probs | |||
answer_output = self.text_decoder( | |||
answer.input_ids, | |||
attention_mask=answer.attention_mask, | |||
encoder_hidden_states=question_output, | |||
encoder_attention_mask=merge_text_attention, | |||
labels=answer_targets, | |||
return_dict=True, | |||
reduction='none') | |||
loss = answer_output.loss | |||
return loss | |||
else: | |||
text_output = self.text_encoder( | |||
question.input_ids, | |||
attention_mask=question.attention_mask, | |||
return_dict=True) | |||
text_embeds = text_output.last_hidden_state | |||
fusion_output = self.fusion_encoder( | |||
encoder_embeds=text_embeds, | |||
attention_mask=question.attention_mask, | |||
encoder_hidden_states=image_embeds, | |||
encoder_attention_mask=image_atts, | |||
return_dict=False) | |||
image_output, question_output = fusion_output | |||
question_output = torch.cat([image_output, question_output], 1) | |||
merge_text_attention = torch.cat( | |||
[image_atts, question.attention_mask], 1) | |||
topk_ids, topk_probs = self.generation(question_output, | |||
merge_text_attention) | |||
return topk_ids, topk_probs |
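Reviewer note: the decoder loss in `BertPrefixModel.forward` is a standard shifted next-token cross-entropy, optionally mixed with a distillation term computed from momentum-teacher soft labels. A self-contained sketch of that pattern (shapes, vocabulary size, and the `alpha` weight below are made up for illustration; the real code masks ignored positions with -100 in the same way):

```python
# Simplified sketch of the shifted LM loss plus soft-label distillation mixing
# used by BertPrefixModel / BertLMHeadModel. Tensors are random placeholders.
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

batch, seq_len, vocab = 2, 6, 30522          # illustrative shapes
prediction_scores = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))
soft_labels = F.softmax(torch.randn(batch, seq_len - 1, vocab), dim=-1)
alpha = 0.4                                  # distillation weight (assumption)

# next-token prediction: drop the last logit, drop the first label
shifted_scores = prediction_scores[:, :-1, :].contiguous()
shifted_labels = labels[:, 1:].contiguous()
lm_loss = CrossEntropyLoss()(
    shifted_scores.view(-1, vocab), shifted_labels.view(-1))

# distillation: cross-entropy against the teacher's soft distribution
loss_distill = -(F.log_softmax(shifted_scores, dim=-1) * soft_labels).sum(-1)
loss_distill = loss_distill[shifted_labels != -100].mean()

loss = (1 - alpha) * lm_loss + alpha * loss_distill
print(float(loss))
```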
@@ -6,12 +6,13 @@ from modelscope.models.base import Tensor | |||
from modelscope.models.builder import MODELS | |||
from modelscope.utils.constant import Tasks | |||
__all__ = ['MPlugForVisualQuestionAnswering'] | |||
__all__ = ['MPlugForAllTasks'] | |||
@MODELS.register_module( | |||
Tasks.visual_question_answering, module_name=Models.mplug) | |||
class MPlugForVisualQuestionAnswering(TorchModel): | |||
@MODELS.register_module(Tasks.image_captioning, module_name=Models.mplug) | |||
class MPlugForAllTasks(TorchModel): | |||
def __init__(self, model_dir: str, *args, **kwargs): | |||
"""initialize the mplug model from the `model_dir` path. | |||
@@ -20,8 +21,8 @@ class MPlugForVisualQuestionAnswering(TorchModel): | |||
""" | |||
super().__init__(model_dir, *args, **kwargs) | |||
from modelscope.models.multi_modal.mplug import MPlugForVisualQuestionAnswering | |||
self.model = MPlugForVisualQuestionAnswering.from_pretrained(model_dir) | |||
from modelscope.models.multi_modal.mplug import MPlug | |||
self.model = MPlug.from_pretrained(model_dir) | |||
self.tokenizer = self.model.tokenizer | |||
def train(self): | |||
@@ -44,4 +45,13 @@ class MPlugForVisualQuestionAnswering(TorchModel): | |||
} | |||
""" | |||
return self.model(**input)[0] | |||
topk_ids, _ = self.model(**input) | |||
replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''), | |||
('[unused1]', ''), (r' +', ' '), ('[SEP]', ''), | |||
('[unused2]', ''), ('[CLS]', ''), ('[UNK]', '')) | |||
pred_string = self.tokenizer.decode(topk_ids[0][0]) | |||
for _old, _new in replace_tokens_bert: | |||
pred_string = pred_string.replace(_old, _new) | |||
pred_string = pred_string.strip() | |||
return pred_string |
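Reviewer note: `MPlugForAllTasks.forward` now decodes the best beam itself and strips the BERT special tokens before returning a plain string. A standalone sketch of that cleanup (the tokenizer name and the fake beam output are assumptions; `re.sub` is used here for the whitespace collapse):

```python
# Standalone sketch of the caption/answer post-processing: decode the top beam
# and strip BERT special tokens. The tokenizer and ids below are placeholders.
import re
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')  # assumption
topk_ids = [[tokenizer('a woman riding a bike')['input_ids']]]  # fake beam result

pred_string = tokenizer.decode(topk_ids[0][0])
for token in ('[unused0]', '[unused1]', '[unused2]',
              '[PAD]', '[SEP]', '[CLS]', '[UNK]'):
    pred_string = pred_string.replace(token, '')
pred_string = re.sub(r' +', ' ', pred_string).strip()
print(pred_string)  # -> 'a woman riding a bike'
```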
@@ -1,11 +1,15 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Any, Dict, Optional, Union | |||
import torch | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.models.multi_modal import OfaForAllTasks | |||
from modelscope.models.multi_modal import MPlugForAllTasks, OfaForAllTasks | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Model, Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import OfaPreprocessor, Preprocessor | |||
from modelscope.preprocessors import (MPlugPreprocessor, OfaPreprocessor, | |||
Preprocessor) | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.logger import get_logger | |||
@@ -35,9 +39,19 @@ class ImageCaptioningPipeline(Pipeline): | |||
else: | |||
raise NotImplementedError | |||
pipe_model.model.eval() | |||
if preprocessor is None and isinstance(pipe_model, OfaForAllTasks): | |||
preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir) | |||
if preprocessor is None: | |||
if isinstance(pipe_model, OfaForAllTasks): | |||
preprocessor = OfaPreprocessor(pipe_model.model_dir) | |||
elif isinstance(pipe_model, MPlugForAllTasks): | |||
preprocessor = MPlugPreprocessor(pipe_model.model_dir) | |||
super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs) | |||
def forward(self, inputs: Dict[str, Any], | |||
**forward_params) -> Dict[str, Any]: | |||
with torch.no_grad(): | |||
return super().forward(inputs, **forward_params) | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
return inputs | |||
if isinstance(self.model, OfaForAllTasks): | |||
return inputs | |||
return {OutputKeys.CAPTION: inputs} |
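Reviewer note: `ImageCaptioningPipeline` now selects `MPlugPreprocessor` automatically for MPlug checkpoints and runs `forward` under `torch.no_grad()`; passing a preprocessor explicitly still works. A usage sketch with the model id used by the new unit tests:

```python
# Usage sketch: explicit preprocessor construction; the model id and test image
# path are the ones used by the unit tests added in this change.
from PIL import Image
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.preprocessors import MPlugPreprocessor
from modelscope.utils.constant import Tasks

model = Model.from_pretrained('damo/mplug_image-captioning_coco_base_en')
preprocessor = MPlugPreprocessor(model.model_dir)
pipe = pipeline(Tasks.image_captioning, model=model, preprocessor=preprocessor)

result = pipe({'image': Image.open('data/test/images/image_mplug_vqa.jpg')})
print(result[OutputKeys.CAPTION])
```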
@@ -5,13 +5,12 @@ import torch | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.models import Model | |||
from modelscope.models.multi_modal import (MPlugForVisualQuestionAnswering, | |||
OfaForAllTasks) | |||
from modelscope.models.multi_modal import MPlugForAllTasks, OfaForAllTasks | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Pipeline, Tensor | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import (MPlugVisualQuestionAnsweringPreprocessor, | |||
OfaPreprocessor) | |||
from modelscope.preprocessors import (MPlugPreprocessor, OfaPreprocessor, | |||
Preprocessor) | |||
from modelscope.utils.constant import Tasks | |||
__all__ = ['VisualQuestionAnsweringPipeline'] | |||
@@ -23,9 +22,8 @@ __all__ = ['VisualQuestionAnsweringPipeline'] | |||
class VisualQuestionAnsweringPipeline(Pipeline): | |||
def __init__(self, | |||
model: Union[MPlugForVisualQuestionAnswering, str], | |||
preprocessor: Optional[ | |||
MPlugVisualQuestionAnsweringPreprocessor] = None, | |||
model: Union[Model, str], | |||
preprocessor: Optional[Preprocessor] = None, | |||
**kwargs): | |||
"""use `model` and `preprocessor` to create a visual question answering pipeline for prediction | |||
@@ -35,18 +33,12 @@ class VisualQuestionAnsweringPipeline(Pipeline): | |||
""" | |||
model = model if isinstance(model, | |||
Model) else Model.from_pretrained(model) | |||
self.tokenizer = None | |||
if preprocessor is None: | |||
if isinstance(model, OfaForAllTasks): | |||
preprocessor = OfaPreprocessor(model.model_dir) | |||
elif isinstance(model, MPlugForVisualQuestionAnswering): | |||
preprocessor = MPlugVisualQuestionAnsweringPreprocessor( | |||
model.model_dir) | |||
if isinstance(model, MPlugForVisualQuestionAnswering): | |||
model.eval() | |||
self.tokenizer = model.tokenizer | |||
else: | |||
model.model.eval() | |||
elif isinstance(model, MPlugForAllTasks): | |||
preprocessor = MPlugPreprocessor(model.model_dir) | |||
model.model.eval() | |||
super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
def forward(self, inputs: Dict[str, Any], | |||
@@ -64,14 +56,6 @@ class VisualQuestionAnsweringPipeline(Pipeline): | |||
Returns: | |||
Dict[str, str]: the prediction results | |||
""" | |||
if self.tokenizer is None: | |||
if isinstance(self.model, OfaForAllTasks): | |||
return inputs | |||
replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''), | |||
('[unused1]', ''), (r' +', ' '), ('[SEP]', ''), | |||
('[unused2]', ''), ('[CLS]', ''), ('[UNK]', '')) | |||
pred_string = self.tokenizer.decode(inputs[0][0]) | |||
for _old, _new in replace_tokens_bert: | |||
pred_string = pred_string.replace(_old, _new) | |||
pred_string.strip() | |||
return {OutputKeys.TEXT: pred_string} | |||
return {OutputKeys.TEXT: inputs} |
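Reviewer note: after this refactor `VisualQuestionAnsweringPipeline.postprocess` only wraps the already-decoded string, since the decoding moved into `MPlugForAllTasks`. A sketch of the equivalent direct call chain, assuming `MPlugForAllTasks.forward` takes the preprocessed dict as its single argument (model id and inputs follow the unit tests):

```python
# Sketch of the VQA call chain without the pipeline wrapper. The forward()
# signature (one dict argument) is an assumption based on this diff.
from PIL import Image
from modelscope.models import Model
from modelscope.preprocessors import MPlugPreprocessor

model = Model.from_pretrained(
    'damo/mplug_visual-question-answering_coco_large_en')
model.model.eval()  # the pipeline does this before inference
preprocessor = MPlugPreprocessor(model.model_dir)

inputs = preprocessor({
    'image': Image.open('data/test/images/image_mplug_vqa.jpg'),
    'question': 'What is the woman doing?',
})
answer = model.forward(inputs)  # MPlugForAllTasks returns a decoded string
print(answer)
```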
@@ -14,8 +14,7 @@ if TYPE_CHECKING: | |||
ImageInstanceSegmentationPreprocessor, | |||
ImageDenoisePreprocessor) | |||
from .kws import WavToLists | |||
from .multi_modal import (OfaPreprocessor, | |||
MPlugVisualQuestionAnsweringPreprocessor) | |||
from .multi_modal import (OfaPreprocessor, MPlugPreprocessor) | |||
from .nlp import (Tokenize, SequenceClassificationPreprocessor, | |||
TextGenerationPreprocessor, | |||
TokenClassificationPreprocessor, | |||
@@ -42,8 +41,7 @@ else: | |||
'ImageInstanceSegmentationPreprocessor', 'ImageDenoisePreprocessor' | |||
], | |||
'kws': ['WavToLists'], | |||
'multi_modal': | |||
['OfaPreprocessor', 'MPlugVisualQuestionAnsweringPreprocessor'], | |||
'multi_modal': ['OfaPreprocessor', 'MPlugPreprocessor'], | |||
'nlp': [ | |||
'Tokenize', 'SequenceClassificationPreprocessor', | |||
'TextGenerationPreprocessor', 'TokenClassificationPreprocessor', | |||
@@ -17,7 +17,7 @@ from .ofa.utils.collate import collate_fn | |||
__all__ = [ | |||
'OfaPreprocessor', | |||
'MPlugVisualQuestionAnsweringPreprocessor', | |||
'MPlugPreprocessor', | |||
] | |||
@@ -88,39 +88,55 @@ class OfaPreprocessor(Preprocessor): | |||
@PREPROCESSORS.register_module( | |||
Fields.multi_modal, | |||
module_name=Preprocessors.mplug_visual_question_answering) | |||
class MPlugVisualQuestionAnsweringPreprocessor(Preprocessor): | |||
Fields.multi_modal, module_name=Preprocessors.mplug_tasks_preprocessor) | |||
class MPlugPreprocessor(Preprocessor): | |||
def __init__(self, model_dir: str, *args, **kwargs): | |||
"""preprocess the data via 'bert-base-uncased' tokenizer and configuration | |||
""" | |||
from transformers import BertTokenizer | |||
from modelscope.models.multi_modal.mplug import CONFIG_NAME, MPlugConfig | |||
super().__init__(*args, **kwargs) | |||
self.model_dir = model_dir | |||
# tokenizer | |||
self.tokenizer = BertTokenizer.from_pretrained( | |||
osp.join(model_dir, ModelFile.VOCAB_FILE)) | |||
self._tokenizer = None | |||
self._patch_resize_transform = None | |||
# load configuration | |||
config = MPlugConfig.from_yaml_file(osp.join(model_dir, CONFIG_NAME)) | |||
@property | |||
def tokenizer(self): | |||
from transformers import BertTokenizer | |||
# Initialize transform | |||
from torchvision import transforms | |||
mean = (0.48145466, 0.4578275, 0.40821073) | |||
std = (0.26862954, 0.26130258, 0.27577711) | |||
if self._tokenizer is None: | |||
self._tokenizer = BertTokenizer.from_pretrained(self.model_dir) | |||
return self._tokenizer | |||
@property | |||
def patch_resize_transform(self): | |||
if self._patch_resize_transform is None: | |||
from torchvision import transforms | |||
from modelscope.models.multi_modal.mplug import CONFIG_NAME, MPlugConfig | |||
config = MPlugConfig.from_yaml_file( | |||
osp.join(self.model_dir, CONFIG_NAME)) | |||
mean = (0.48145466, 0.4578275, 0.40821073) | |||
std = (0.26862954, 0.26130258, 0.27577711) | |||
self._patch_resize_transform = transforms.Compose([ | |||
transforms.Resize((config.image_res, config.image_res), | |||
interpolation=Image.BICUBIC), | |||
transforms.ToTensor(), | |||
transforms.Normalize(mean=mean, std=std), | |||
]) | |||
return self._patch_resize_transform | |||
def __call__(self, *args, **kwargs): | |||
call_mapping = { | |||
Tasks.visual_question_answering: self.vqa_call, | |||
Tasks.image_captioning: self.caption_call | |||
} | |||
self.patch_resize_transform = transforms.Compose([ | |||
transforms.Resize((config.image_res, config.image_res), | |||
interpolation=Image.BICUBIC), | |||
transforms.ToTensor(), | |||
transforms.Normalize(mean=mean, std=std), | |||
]) | |||
self.cfg = Config.from_file( | |||
osp.join(self.model_dir, ModelFile.CONFIGURATION)) | |||
return call_mapping[self.cfg.task](*args, **kwargs) | |||
def __call__(self, data: Union[tuple, Dict[str, Any]]) -> Dict[str, Any]: | |||
def vqa_call(self, data: Union[tuple, Dict[str, Any]]) -> Dict[str, Any]: | |||
image: Image.Image = data[0] if isinstance(data, | |||
tuple) else data['image'] | |||
question: str = data[1] if isinstance(data, | |||
@@ -133,3 +149,19 @@ class MPlugVisualQuestionAnsweringPreprocessor(Preprocessor): | |||
return_tensors='pt') | |||
return {'image': image, 'question': question, 'train': False} | |||
def caption_call( | |||
self, data: Union[Image.Image, tuple, | |||
Dict[str, Any]]) -> Dict[str, Any]: | |||
if isinstance(data, Image.Image): | |||
image = data | |||
elif isinstance(data, tuple): | |||
image = data[0] | |||
else: | |||
image = data['image'] | |||
image = image.convert('RGB') | |||
image = self.patch_resize_transform(image) | |||
image = torch.stack([image], dim=0) | |||
question = self.tokenizer('', return_tensors='pt') | |||
return {'image': image, 'question': question, 'train': False} |
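Reviewer note: the preprocessor now builds its tokenizer and image transform lazily and dispatches on the task read from `configuration.json`. The image path keeps the same CLIP-style resize/normalize; a self-contained sketch of that transform (the resolution mirrors the `image_res=504` default from `MPlugConfig`, and the input image is a placeholder):

```python
# Self-contained sketch of the caption/VQA image transform used by
# MPlugPreprocessor (values mirror the diff; the input image is a placeholder).
import torch
from PIL import Image
from torchvision import transforms

image_res = 504  # MPlugConfig default per this diff
mean = (0.48145466, 0.4578275, 0.40821073)
std = (0.26862954, 0.26130258, 0.27577711)

patch_resize_transform = transforms.Compose([
    transforms.Resize((image_res, image_res), interpolation=Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

image = Image.new('RGB', (640, 480))            # placeholder image
batch = torch.stack([patch_resize_transform(image)], dim=0)
print(batch.shape)                              # torch.Size([1, 3, 504, 504])
```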
@@ -0,0 +1,59 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import unittest | |||
from PIL import Image | |||
from modelscope.models import Model | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines import pipeline | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.test_utils import test_level | |||
class MplugTasksTest(unittest.TestCase): | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_run_with_image_captioning_with_model(self): | |||
model = Model.from_pretrained( | |||
'damo/mplug_image-captioning_coco_base_en') | |||
pipeline_caption = pipeline( | |||
task=Tasks.image_captioning, | |||
model=model, | |||
) | |||
image = Image.open('data/test/images/image_mplug_vqa.jpg') | |||
result = pipeline_caption({'image': image}) | |||
print(result[OutputKeys.CAPTION]) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_with_image_captioning_with_name(self): | |||
pipeline_caption = pipeline( | |||
Tasks.image_captioning, | |||
model='damo/mplug_image-captioning_coco_base_en') | |||
image = Image.open('data/test/images/image_mplug_vqa.jpg') | |||
result = pipeline_caption({'image': image}) | |||
print(result[OutputKeys.CAPTION]) | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_run_with_visual_question_answering_with_model(self): | |||
model = Model.from_pretrained( | |||
'damo/mplug_visual-question-answering_coco_large_en') | |||
pipeline_vqa = pipeline(Tasks.visual_question_answering, model=model) | |||
image = Image.open('data/test/images/image_mplug_vqa.jpg') | |||
question = 'What is the woman doing?' | |||
input = {'image': image, 'question': question} | |||
result = pipeline_vqa(input) | |||
print(result) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_with_visual_question_answering_with_name(self): | |||
model = 'damo/mplug_visual-question-answering_coco_large_en' | |||
pipeline_vqa = pipeline(Tasks.visual_question_answering, model=model) | |||
image = Image.open('data/test/images/image_mplug_vqa.jpg') | |||
question = 'What is the woman doing?' | |||
input = {'image': image, 'question': question} | |||
result = pipeline_vqa(input) | |||
print(result) | |||
if __name__ == '__main__': | |||
unittest.main() |
@@ -1,65 +0,0 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import unittest | |||
from PIL import Image | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.models import Model | |||
from modelscope.models.multi_modal import MPlugForVisualQuestionAnswering | |||
from modelscope.pipelines import pipeline | |||
from modelscope.pipelines.multi_modal import VisualQuestionAnsweringPipeline | |||
from modelscope.preprocessors import MPlugVisualQuestionAnsweringPreprocessor | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.test_utils import test_level | |||
class VisualQuestionAnsweringTest(unittest.TestCase): | |||
def setUp(self): | |||
self.model_id = 'damo/mplug_visual-question-answering_coco_large_en' | |||
self.input_vqa = { | |||
'image': Image.open('data/test/images/image_mplug_vqa.jpg'), | |||
'question': 'What is the woman doing?', | |||
} | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run(self): | |||
cache_path = snapshot_download(self.model_id) | |||
preprocessor = MPlugVisualQuestionAnsweringPreprocessor(cache_path) | |||
model = MPlugForVisualQuestionAnswering(cache_path) | |||
pipeline1 = VisualQuestionAnsweringPipeline( | |||
model, preprocessor=preprocessor) | |||
pipeline2 = pipeline( | |||
Tasks.visual_question_answering, | |||
model=model, | |||
preprocessor=preprocessor) | |||
print(f"question: {self.input_vqa['question']}") | |||
print(f'pipeline1: {pipeline1(self.input_vqa)}') | |||
print(f'pipeline2: {pipeline2(self.input_vqa)}') | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_with_model_from_modelhub(self): | |||
model = Model.from_pretrained(self.model_id) | |||
preprocessor = MPlugVisualQuestionAnsweringPreprocessor( | |||
model.model_dir) | |||
pipeline_vqa = pipeline( | |||
task=Tasks.visual_question_answering, | |||
model=model, | |||
preprocessor=preprocessor) | |||
print(pipeline_vqa(self.input_vqa)) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_with_model_name(self): | |||
pipeline_vqa = pipeline( | |||
Tasks.visual_question_answering, model=self.model_id) | |||
print(pipeline_vqa(self.input_vqa)) | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_with_default_model(self): | |||
pipeline_vqa = pipeline(task=Tasks.visual_question_answering) | |||
print(pipeline_vqa(self.input_vqa)) | |||
if __name__ == '__main__': | |||
unittest.main() |