diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index ca2e21d6..13656fad 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -218,7 +218,7 @@ class Preprocessors(object):

     # multi-modal preprocessor
     ofa_tasks_preprocessor = 'ofa-tasks-preprocessor'
-    mplug_visual_question_answering = 'mplug-visual-question-answering'
+    mplug_tasks_preprocessor = 'mplug-tasks-preprocessor'


 class Metrics(object):
diff --git a/modelscope/models/multi_modal/__init__.py b/modelscope/models/multi_modal/__init__.py
index 6c40a3da..112b3a58 100644
--- a/modelscope/models/multi_modal/__init__.py
+++ b/modelscope/models/multi_modal/__init__.py
@@ -9,8 +9,7 @@ if TYPE_CHECKING:
     from .gemm import GEMMForMultiModalEmbedding
     from .diffusion import DiffusionForTextToImageSynthesis
     from .mmr import VideoCLIPForMultiModalEmbedding
-    from .mplug_for_visual_question_answering import \
-        MPlugForVisualQuestionAnswering
+    from .mplug_for_all_tasks import MPlugForAllTasks
     from .ofa_for_all_tasks import OfaForAllTasks
     from .ofa_for_text_to_image_synthesis_model import \
         OfaForTextToImageSynthesis
@@ -21,8 +20,7 @@ else:
         'diffusion': ['DiffusionForTextToImageSynthesis'],
         'gemm': ['GEMMForMultiModalEmbedding'],
         'mmr': ['VideoCLIPForMultiModalEmbedding'],
-        'mplug_for_visual_question_answering':
-        ['MPlugForVisualQuestionAnswering'],
+        'mplug_for_all_tasks': ['MPlugForAllTasks'],
         'ofa_for_all_tasks': ['OfaForAllTasks'],
         'ofa_for_text_to_image_synthesis_model':
         ['OfaForTextToImageSynthesis']
diff --git a/modelscope/models/multi_modal/mplug/__init__.py b/modelscope/models/multi_modal/mplug/__init__.py
index a145fc0c..955c87e2 100644
--- a/modelscope/models/multi_modal/mplug/__init__.py
+++ b/modelscope/models/multi_modal/mplug/__init__.py
@@ -14,4 +14,4 @@
 # limitations under the License.

 from .configuration_mplug import MPlugConfig
-from .modeling_mplug import CONFIG_NAME, MPlugForVisualQuestionAnswering
+from .modeling_mplug import CONFIG_NAME, MPlug
diff --git a/modelscope/models/multi_modal/mplug/configuration_mplug.py b/modelscope/models/multi_modal/mplug/configuration_mplug.py
index 6b2914c4..c275ed15 100644
--- a/modelscope/models/multi_modal/mplug/configuration_mplug.py
+++ b/modelscope/models/multi_modal/mplug/configuration_mplug.py
@@ -15,14 +15,14 @@
 # limitations under the License.
 """ MPLUG model configuration """
 import os
-from collections import OrderedDict
-from typing import Any, Dict, Mapping, Union
+from typing import Any, Dict, Union

 import yaml
 from transformers import PretrainedConfig
-from transformers.onnx import OnnxConfig
 from transformers.utils import logging

+from modelscope.utils.constant import Tasks
+
 logger = logging.get_logger(__name__)


@@ -32,6 +32,7 @@ class MPlugConfig(PretrainedConfig):

     def __init__(
         self,
+        task=Tasks.visual_question_answering,
         bert_config='config_bert.json',
         image_res=504,
         batch_size_train=128,
@@ -64,7 +65,9 @@ class MPlugConfig(PretrainedConfig):
         clip_transformer_heads=12,
         clip_transformer_layers=12,
         **kwargs):
+        super().__init__(**kwargs)
+
+        self.task = task
         self.bert_config = bert_config
         self.image_res = image_res
         self.batch_size_train = batch_size_train
@@ -103,23 +106,3 @@ class MPlugConfig(PretrainedConfig):
         with open(yaml_file, 'r') as reader:
             config_dict = yaml.load(reader, Loader=yaml.Loader)
         return cls(**config_dict)
-
-
-class MPlugOnnxConfig(OnnxConfig):
-
-    @property
-    def inputs(self) -> Mapping[str, Mapping[int, str]]:
-        return OrderedDict([
-            ('input_ids', {
-                0: 'batch',
-                1: 'sequence'
-            }),
-            ('attention_mask', {
-                0: 'batch',
-                1: 'sequence'
-            }),
-            ('token_type_ids', {
-                0: 'batch',
-                1: 'sequence'
-            }),
-        ])
diff --git a/modelscope/models/multi_modal/mplug/modeling_mplug.py b/modelscope/models/multi_modal/mplug/modeling_mplug.py
index 79fab718..50622cc0 100755
--- a/modelscope/models/multi_modal/mplug/modeling_mplug.py
+++ b/modelscope/models/multi_modal/mplug/modeling_mplug.py
@@ -1725,7 +1725,116 @@ class BertLMHeadModel(BertPreTrainedModel):
         return reordered_past


-class MPlugForVisualQuestionAnswering(PreTrainedModel):
+class BertPrefixModel(BertPreTrainedModel):
+
+    _keys_to_ignore_on_load_unexpected = [r'pooler']
+    _keys_to_ignore_on_load_missing = [
+        r'position_ids', r'predictions.decoder.bias'
+    ]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.bert = BertModel(config, add_pooling_layer=False)
+        self.cls = BertOnlyMLMHead(config)
+
+        self.init_weights()
+
+    def get_output_embeddings(self):
+        return self.cls.predictions.decoder
+
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+
+    @add_start_docstrings_to_model_forward(
+        BERT_INPUTS_DOCSTRING.format('batch_size, sequence_length'))
+    @add_code_sample_docstrings(
+        processor_class=_TOKENIZER_FOR_DOC,
+        checkpoint='bert-base-uncased',
+        output_type=CausalLMOutputWithCrossAttentions,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        labels=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        is_decoder=True,
+        reduction='mean',
+        soft_labels=None,
+        alpha=0,
+        return_logits=False,
+    ):
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if labels is not None:
+            use_cache = False
+
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            is_decoder=is_decoder,
+        )
+
+        sequence_output = outputs[0]
+        prediction_scores = self.cls(sequence_output)
+
+        if return_logits:
+            return prediction_scores[:, :-1, :].contiguous()
+
+        lm_loss = None
+        if labels is not None:
+            # we are doing next-token prediction; shift prediction scores and input ids by one
+            shifted_prediction_scores = prediction_scores[:, :
+                                                          -1, :].contiguous()
+            labels = labels[:, 1:].contiguous()
+            loss_fct = CrossEntropyLoss()
+            lm_loss = loss_fct(
+                shifted_prediction_scores.view(-1, self.config.vocab_size),
+                labels.view(-1))
+        if soft_labels is not None:
+            loss_distill = -torch.sum(
+                F.log_softmax(shifted_prediction_scores, dim=1) * soft_labels,
+                dim=-1)
+            loss_distill = loss_distill[labels != -100].mean()
+            lm_loss = (1 - alpha) * lm_loss + alpha * loss_distill
+
+        if not return_dict:
+            output = (prediction_scores, ) + outputs[2:]
+            return ((lm_loss, ) + output) if lm_loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=lm_loss,
+            logits=prediction_scores,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+
+class MPlug(PreTrainedModel):
     config_class = MPlugConfig

     def __init__(self, config):
@@ -1739,16 +1848,19 @@ class MPlugForVisualQuestionAnswering(PreTrainedModel):
             self.config_encoder, add_pooling_layer=False)
         self.fusion_encoder = FusionModel(
             self.config_fusion, add_pooling_layer=False)
-        self.text_decoder = BertLMHeadModel(self.config_decoder)
-        self.init_distill(config)
-        self.beam_generator = TextGenerator(config, self.text_decoder)

     @classmethod
     def from_pretrained(cls, model_dir, load_checkpoint=True):
-        config = MPlugConfig.from_yaml_file(
+        from modelscope.utils.constant import Tasks
+
+        task_mapping = {
+            Tasks.visual_question_answering: MPlugForVisualQuestionAnswering,
+            Tasks.image_captioning: MPLUGForImageCaption
+        }
+        config = cls.config_class.from_yaml_file(
             os.path.join(model_dir, CONFIG_NAME))
         config.model_dir = model_dir
-        model = cls(config)
+        model = task_mapping[config.task](config)
         if load_checkpoint:
             checkpoint_path = os.path.join(model_dir,
                                            ModelFile.TORCH_MODEL_BIN_FILE)
@@ -1803,6 +1915,161 @@ class MPlugForVisualQuestionAnswering(PreTrainedModel):
             clip_model.visual.positional_embedding = pos_embed
         return clip_model

+    def forward(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def module_setting(self, config):
+        bert_config_path = os.path.join(config.model_dir, config.bert_config)
+        self.config_encoder = BertConfig.from_json_file(bert_config_path)
+        self.config_encoder.num_hidden_layers = self.config_encoder.text_encoder_layers
+        self.config_fusion = BertConfig.from_json_file(bert_config_path)
+        self.config_decoder = BertConfig.from_json_file(bert_config_path)
+        self.config_decoder.add_cross_attention = True
+        self.config_decoder.num_hidden_layers = self.config_decoder.text_decode_layers
+        self.large = False
+        if self.config_encoder.hidden_size != config.vision_width:
+            self.visn_fc = nn.Linear(config.vision_width,
+                                     self.config_encoder.hidden_size)
+            self.visn_layer_norm = nn.LayerNorm(
+                self.config_encoder.hidden_size, eps=1e-12)
+            self.dropout = nn.Dropout(self.config_encoder.hidden_dropout_prob)
+            self.large = True
+
+    @torch.no_grad()
+    def copy_params(self):
+        for model_pair in self.model_pairs:
+            for param, param_m in zip(model_pair[0].parameters(),
+                                      model_pair[1].parameters()):
+                param_m.data.copy_(param.data)  # initialize
+                param_m.requires_grad = False  # not update by gradient
+
+    @torch.no_grad()
+    def _momentum_update(self):
+        for model_pair in self.model_pairs:
+            for param, param_m in zip(model_pair[0].parameters(),
+                                      model_pair[1].parameters()):
+                param_m.data = param_m.data * self.momentum + param.data * (
+                    1. - self.momentum)
+
+    def generation(self, question_states, question_atts, out_size=1):
+        encoder_inputs = [question_states, question_atts]
+        topk_ids, topk_scores = self.beam_generator.translate_batch(
+            encoder_inputs, out_size=out_size)
+        return topk_ids, topk_scores
+
+    @staticmethod
+    def _tile(x, dim, n_tile):
+        import numpy as np
+        init_dim = x.size(dim)
+        repeat_idx = [1] * x.dim()
+        repeat_idx[dim] = n_tile
+        x = x.repeat(*(repeat_idx))
+        order_index = torch.LongTensor(
+            np.concatenate(
+                [init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
+        return torch.index_select(x, dim, order_index.to(x.device))
+
+    def rank_answer(self, question_states, question_atts, answer_ids,
+                    answer_atts, k):
+
+        num_ques = question_states.size(0)
+        start_ids = answer_ids[0, 0].repeat(num_ques, 1)  # bos token
+
+        start_output = self.text_decoder(
+            start_ids,
+            encoder_hidden_states=question_states,
+            encoder_attention_mask=question_atts,
+            return_dict=True,
+            reduction='none')
+        logits = start_output.logits[:, 0, :]  # first token's logit
+
+        # topk_probs: top-k probability
+        # topk_ids: [num_question, k]
+        answer_first_token = answer_ids[:, 1]
+        prob_first_token = F.softmax(
+            logits, dim=1).index_select(
+                dim=1, index=answer_first_token)
+        topk_probs, topk_ids = prob_first_token.topk(k, dim=1)
+
+        # answer input: [num_question*k, answer_len]
+        input_ids = []
+        input_atts = []
+        for b, topk_id in enumerate(topk_ids):
+            input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
+            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
+        input_ids = torch.cat(input_ids, dim=0)
+        input_atts = torch.cat(input_atts, dim=0)
+
+        targets_ids = input_ids.masked_fill(
+            input_ids == self.tokenizer.pad_token_id, -100)
+
+        # repeat encoder's output for top-k answers
+        question_states = self._tile(question_states, 0, k)
+        question_atts = self._tile(question_atts, 0, k)
+
+        output = self.text_decoder(
+            input_ids,
+            attention_mask=input_atts,
+            encoder_hidden_states=question_states,
+            encoder_attention_mask=question_atts,
+            labels=targets_ids,
+            return_dict=True,
+            reduction='none')
+
+        answer_loss = output.loss
+        answer_loss = answer_loss.view(input_ids.size(0), -1)
+
+        # topk_prob: first token probability
+        topk_probs = topk_probs.view(-1, 1)
+        log_probs = torch.cat([topk_probs.log(), -answer_loss], dim=1)
+
+        # re-calculate log probabilities for the answer sequences using chain rule
+        log_probs_sum = log_probs.sum(1)
+        log_probs_sum = log_probs_sum.view(num_ques, k)
+
+        topk_probs = F.softmax(log_probs_sum, dim=-1)
+        # get top-k after re-ranking
+        topk_probs, rerank_id = topk_probs.topk(k, dim=1)
+        topk_ids = torch.gather(topk_ids, 1, rerank_id)
+
+        return topk_ids, topk_probs
+
+
+class MPlugForVisualQuestionAnswering(MPlug):
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.text_decoder = BertLMHeadModel(self.config_decoder)
+        self.beam_generator = TextGenerator(config, self.text_decoder)
+        self.init_distill(config)
+
+    def init_distill(self, config):
+        self.distill = config.distill
+        if self.distill:
+            self.visual_encoder_m = self._initialize_clip(config)
+            self.text_encoder_m = BertModel(
+                self.config_encoder, add_pooling_layer=False)
+            self.fusion_encoder_m = FusionModel(
+                self.config_fusion, add_pooling_layer=False)
+            self.text_decoder_m = BertLMHeadModel(self.config_decoder)
+            self.model_pairs = [
+                [self.visual_encoder, self.visual_encoder_m],
+                [self.text_encoder, self.text_encoder_m],
+                [self.text_decoder, self.text_decoder_m],
+            ]
+            if self.config_encoder.hidden_size != config.vision_width:
+                self.visn_fc_m = nn.Linear(config.vision_width,
+                                           self.config_encoder.hidden_size)
+                self.visn_layer_norm_m = nn.LayerNorm(
+                    self.config_encoder.hidden_size, eps=1e-12)
+                self.dropout_m = nn.Dropout(
+                    self.config_encoder.hidden_dropout_prob)
+                self.model_pairs.extend(
+                    [[self.visn_fc, self.visn_fc_m],
+                     [self.visn_layer_norm, self.visn_layer_norm_m]])
+            self.copy_params()
+            self.momentum = 0.995
+
     def forward(self,
                 image,
                 question,
@@ -1935,145 +2202,110 @@ class MPlugForVisualQuestionAnswering(PreTrainedModel):
                                               merge_text_attention)
         return topk_ids, topk_probs

-    def module_setting(self, config):
-        bert_config_path = os.path.join(config.model_dir, config.bert_config)
-        self.config_encoder = BertConfig.from_json_file(bert_config_path)
-        self.config_encoder.num_hidden_layers = self.config_encoder.text_encoder_layers
-        self.config_fusion = BertConfig.from_json_file(bert_config_path)
-        self.config_decoder = BertConfig.from_json_file(bert_config_path)
-        self.config_decoder.add_cross_attention = True
-        self.config_decoder.num_hidden_layers = self.config_decoder.text_decode_layers
-        self.large = False
-        if self.config_encoder.hidden_size != config.vision_width:
-            self.visn_fc = nn.Linear(config.vision_width,
-                                     self.config_encoder.hidden_size)
-            self.visn_layer_norm = nn.LayerNorm(
-                self.config_encoder.hidden_size, eps=1e-12)
-            self.dropout = nn.Dropout(self.config_encoder.hidden_dropout_prob)
-            self.large = True
-
-    def init_distill(self, config):
-        self.distill = config.distill
-        if self.distill:
-            self.visual_encoder_m = self._initialize_clip(config)
-            self.text_encoder_m = BertModel(
-                self.config_encoder, add_pooling_layer=False)
-            self.fusion_encoder_m = FusionModel(
-                self.config_fusion, add_pooling_layer=False)
-            self.text_decoder_m = BertLMHeadModel(self.config_decoder)
-            self.model_pairs = [
-                [self.visual_encoder, self.visual_encoder_m],
-                [self.text_encoder, self.text_encoder_m],
-                [self.text_decoder, self.text_decoder_m],
-            ]
-            if self.config_encoder.hidden_size != config.vision_width:
-                self.visn_fc_m = nn.Linear(config.vision_width,
-                                           self.config_encoder.hidden_size)
-                self.visn_layer_norm_m = nn.LayerNorm(
-                    self.config_encoder.hidden_size, eps=1e-12)
-                self.dropout_m = nn.Dropout(
-                    self.config_encoder.hidden_dropout_prob)
-                self.model_pairs.extend(
-                    [[self.visn_fc, self.visn_fc_m],
-                     [self.visn_layer_norm, self.visn_layer_norm_m]])
-            self.copy_params()
-            self.momentum = 0.995
-
-    @torch.no_grad()
-    def copy_params(self):
-        for model_pair in self.model_pairs:
-            for param, param_m in zip(model_pair[0].parameters(),
-                                      model_pair[1].parameters()):
-                param_m.data.copy_(param.data)  # initialize
-                param_m.requires_grad = False  # not update by gradient
-
-    @torch.no_grad()
-    def _momentum_update(self):
-        for model_pair in self.model_pairs:
-            for param, param_m in zip(model_pair[0].parameters(),
-                                      model_pair[1].parameters()):
-                param_m.data = param_m.data * self.momentum + param.data * (
-                    1. - self.momentum)
-
-    def generation(self, question_states, question_atts):
-        encoder_inputs = [question_states, question_atts]
-        topk_ids, topk_scores = self.beam_generator.translate_batch(
-            encoder_inputs)
-        return topk_ids, topk_scores
-
-    @staticmethod
-    def _tile(x, dim, n_tile):
-        import numpy as np
-        init_dim = x.size(dim)
-        repeat_idx = [1] * x.dim()
-        repeat_idx[dim] = n_tile
-        x = x.repeat(*(repeat_idx))
-        order_index = torch.LongTensor(
-            np.concatenate(
-                [init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
-        return torch.index_select(x, dim, order_index.to(x.device))
-
-    def rank_answer(self, question_states, question_atts, answer_ids,
-                    answer_atts, k):
-
-        num_ques = question_states.size(0)
-        start_ids = answer_ids[0, 0].repeat(num_ques, 1)  # bos token
-        start_output = self.text_decoder(
-            start_ids,
-            encoder_hidden_states=question_states,
-            encoder_attention_mask=question_atts,
-            return_dict=True,
-            reduction='none')
-        logits = start_output.logits[:, 0, :]  # first token's logit
+class MPLUGForImageCaption(MPlug):

-        # topk_probs: top-k probability
-        # topk_ids: [num_question, k]
-        answer_first_token = answer_ids[:, 1]
-        prob_first_token = F.softmax(
-            logits, dim=1).index_select(
-                dim=1, index=answer_first_token)
-        topk_probs, topk_ids = prob_first_token.topk(k, dim=1)
-
-        # answer input: [num_question*k, answer_len]
-        input_ids = []
-        input_atts = []
-        for b, topk_id in enumerate(topk_ids):
-            input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
-            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
-        input_ids = torch.cat(input_ids, dim=0)
-        input_atts = torch.cat(input_atts, dim=0)
-
-        targets_ids = input_ids.masked_fill(
-            input_ids == self.tokenizer.pad_token_id, -100)
-
-        # repeat encoder's output for top-k answers
-        question_states = self._tile(question_states, 0, k)
-        question_atts = self._tile(question_atts, 0, k)
+    def __init__(self, config):
+        super().__init__(config)
+        self.text_decoder = BertPrefixModel(self.config_decoder)
+        self.beam_generator = TextGenerator(config, self.text_decoder)

-        output = self.text_decoder(
-            input_ids,
-            attention_mask=input_atts,
-            encoder_hidden_states=question_states,
-            encoder_attention_mask=question_atts,
-            labels=targets_ids,
-            return_dict=True,
-            reduction='none')
+    def beam_search(self,
+                    image,
+                    question,
+                    answer=None,
+                    train=True,
+                    out_size=5):
+        image_embeds = self.visual_encoder.visual(image, skip_last_layer=True)
+        if self.large:
+            image_embeds = self.dropout(
+                self.visn_layer_norm(self.visn_fc(image_embeds)))
+        image_atts = torch.ones(
+            image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+        text_output = self.text_encoder(
+            question.input_ids,
+            attention_mask=question.attention_mask,
+            return_dict=True)
+        text_embeds = text_output.last_hidden_state
+        fusion_output = self.fusion_encoder(
+            encoder_embeds=text_embeds,
+            attention_mask=question.attention_mask,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_atts,
+            return_dict=False)
+        image_output, question_output = fusion_output
+        question_output = torch.cat([image_output, question_output], 1)
+        merge_text_attention = torch.cat([image_atts, question.attention_mask],
+                                         1)
+        topk_ids, topk_probs = self.generation(
+            question_output, merge_text_attention, out_size=out_size)
+        return topk_ids, topk_probs

-        answer_loss = output.loss
-        answer_loss = answer_loss.view(input_ids.size(0), -1)
+    def forward(self,
+                image,
+                question,
+                answer=None,
+                train=True,
+                out_size=5,
+                scst=False):
+        if (scst):
+            return self.beam_search(
+                image, question, answer, train=True, out_size=out_size)
+        image = image.to(dtype=next(self.parameters()).dtype)
+        image_embeds = self.visual_encoder.visual(image, skip_last_layer=True)
+        if self.large:
+            image_embeds = self.dropout(
+                self.visn_layer_norm(self.visn_fc(image_embeds)))
+        image_atts = torch.ones(
+            image_embeds.size()[:-1], dtype=torch.long).to(image.device)

-        # topk_prob: first token probability
-        topk_probs = topk_probs.view(-1, 1)
-        log_probs = torch.cat([topk_probs.log(), -answer_loss], dim=1)
+        if train:
+            answer_targets = answer.input_ids.masked_fill(
+                answer.input_ids == self.tokenizer.pad_token_id, -100)
+            text_output = self.text_encoder(
+                question.input_ids,
+                attention_mask=question.attention_mask,
+                return_dict=True)
+            text_embeds = text_output.last_hidden_state
+            fusion_output = self.fusion_encoder(
+                encoder_embeds=text_embeds,
+                attention_mask=question.attention_mask,
+                encoder_hidden_states=image_embeds,
+                encoder_attention_mask=image_atts,
+                return_dict=False)

-        # re-calculate log probabilities for the answer sequences using chain rule
-        log_probs_sum = log_probs.sum(1)
-        log_probs_sum = log_probs_sum.view(num_ques, k)
+            image_output, question_output = fusion_output

-        topk_probs = F.softmax(log_probs_sum, dim=-1)
-        # get top-k after re-ranking
-        topk_probs, rerank_id = topk_probs.topk(k, dim=1)
-        topk_ids = torch.gather(topk_ids, 1, rerank_id)
+            question_output = torch.cat([image_output, question_output], 1)
+            merge_text_attention = torch.cat(
+                [image_atts, question.attention_mask], 1)

-        return topk_ids, topk_probs
+            answer_output = self.text_decoder(
+                answer.input_ids,
+                attention_mask=answer.attention_mask,
+                encoder_hidden_states=question_output,
+                encoder_attention_mask=merge_text_attention,
+                labels=answer_targets,
+                return_dict=True,
+                reduction='none')
+            loss = answer_output.loss
+            return loss
+        else:
+            text_output = self.text_encoder(
+                question.input_ids,
+                attention_mask=question.attention_mask,
+                return_dict=True)
+            text_embeds = text_output.last_hidden_state
+            fusion_output = self.fusion_encoder(
+                encoder_embeds=text_embeds,
+                attention_mask=question.attention_mask,
+                encoder_hidden_states=image_embeds,
+                encoder_attention_mask=image_atts,
+                return_dict=False)
+            image_output, question_output = fusion_output
+            question_output = torch.cat([image_output, question_output], 1)
+            merge_text_attention = torch.cat(
+                [image_atts, question.attention_mask], 1)
+            topk_ids, topk_probs = self.generation(question_output,
+                                                   merge_text_attention)
+            return topk_ids, topk_probs
diff --git a/modelscope/models/multi_modal/mplug_for_visual_question_answering.py b/modelscope/models/multi_modal/mplug_for_all_tasks.py
similarity index 60%
rename from modelscope/models/multi_modal/mplug_for_visual_question_answering.py
rename to modelscope/models/multi_modal/mplug_for_all_tasks.py
index 88875fda..bb5a9c46 100644
--- a/modelscope/models/multi_modal/mplug_for_visual_question_answering.py
+++ b/modelscope/models/multi_modal/mplug_for_all_tasks.py
@@ -6,12 +6,13 @@ from modelscope.models.base import Tensor
 from modelscope.models.builder import MODELS
 from modelscope.utils.constant import Tasks

-__all__ = ['MPlugForVisualQuestionAnswering']
+__all__ = ['MPlugForAllTasks']


 @MODELS.register_module(
     Tasks.visual_question_answering, module_name=Models.mplug)
-class MPlugForVisualQuestionAnswering(TorchModel):
+@MODELS.register_module(Tasks.image_captioning, module_name=Models.mplug)
+class MPlugForAllTasks(TorchModel):

     def __init__(self, model_dir: str, *args, **kwargs):
         """initialize the mplug model from the `model_dir` path.
@@ -20,8 +21,8 @@ class MPlugForVisualQuestionAnswering(TorchModel):
         """
         super().__init__(model_dir, *args, **kwargs)
-        from modelscope.models.multi_modal.mplug import MPlugForVisualQuestionAnswering
-        self.model = MPlugForVisualQuestionAnswering.from_pretrained(model_dir)
+        from modelscope.models.multi_modal.mplug import MPlug
+        self.model = MPlug.from_pretrained(model_dir)
         self.tokenizer = self.model.tokenizer

     def train(self):
@@ -44,4 +45,13 @@
                 }
         """
-        return self.model(**input)[0]
+        topk_ids, _ = self.model(**input)
+        replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
+                               ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
+                               ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))
+
+        pred_string = self.tokenizer.decode(topk_ids[0][0])
+        for _old, _new in replace_tokens_bert:
+            pred_string = pred_string.replace(_old, _new)
+        pred_string = pred_string.strip()
+        return pred_string
diff --git a/modelscope/pipelines/multi_modal/image_captioning_pipeline.py b/modelscope/pipelines/multi_modal/image_captioning_pipeline.py
index 2028e7dc..99cccee1 100644
--- a/modelscope/pipelines/multi_modal/image_captioning_pipeline.py
+++ b/modelscope/pipelines/multi_modal/image_captioning_pipeline.py
@@ -1,11 +1,15 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from typing import Any, Dict, Optional, Union

+import torch
+
 from modelscope.metainfo import Pipelines
-from modelscope.models.multi_modal import OfaForAllTasks
+from modelscope.models.multi_modal import MPlugForAllTasks, OfaForAllTasks
+from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Model, Pipeline
 from modelscope.pipelines.builder import PIPELINES
-from modelscope.preprocessors import OfaPreprocessor, Preprocessor
+from modelscope.preprocessors import (MPlugPreprocessor, OfaPreprocessor,
+                                      Preprocessor)
 from modelscope.utils.constant import Tasks
 from modelscope.utils.logger import get_logger
@@ -35,9 +39,19 @@ class ImageCaptioningPipeline(Pipeline):
         else:
             raise NotImplementedError
         pipe_model.model.eval()
-        if preprocessor is None and isinstance(pipe_model, OfaForAllTasks):
-            preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir)
+        if preprocessor is None:
+            if isinstance(pipe_model, OfaForAllTasks):
+                preprocessor = OfaPreprocessor(pipe_model.model_dir)
+            elif isinstance(pipe_model, MPlugForAllTasks):
+                preprocessor = MPlugPreprocessor(pipe_model.model_dir)
         super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)

+    def forward(self, inputs: Dict[str, Any],
+                **forward_params) -> Dict[str, Any]:
+        with torch.no_grad():
+            return super().forward(inputs, **forward_params)
+
     def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
-        return inputs
+        if isinstance(self.model, OfaForAllTasks):
+            return inputs
+        return {OutputKeys.CAPTION: inputs}
diff --git a/modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py b/modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py
index 9c694500..b2442a3e 100644
--- a/modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py
+++ b/modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py
@@ -5,13 +5,12 @@ import torch

 from modelscope.metainfo import Pipelines
 from modelscope.models import Model
-from modelscope.models.multi_modal import (MPlugForVisualQuestionAnswering,
-                                           OfaForAllTasks)
+from modelscope.models.multi_modal import MPlugForAllTasks, OfaForAllTasks
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Pipeline, Tensor
 from modelscope.pipelines.builder import PIPELINES
-from modelscope.preprocessors import (MPlugVisualQuestionAnsweringPreprocessor,
-                                      OfaPreprocessor)
+from modelscope.preprocessors import (MPlugPreprocessor, OfaPreprocessor,
+                                      Preprocessor)
 from modelscope.utils.constant import Tasks

 __all__ = ['VisualQuestionAnsweringPipeline']
@@ -23,9 +22,8 @@ __all__ = ['VisualQuestionAnsweringPipeline']
 class VisualQuestionAnsweringPipeline(Pipeline):

     def __init__(self,
-                 model: Union[MPlugForVisualQuestionAnswering, str],
-                 preprocessor: Optional[
-                     MPlugVisualQuestionAnsweringPreprocessor] = None,
+                 model: Union[Model, str],
+                 preprocessor: Optional[Preprocessor] = None,
                  **kwargs):
         """use `model` and `preprocessor` to create a visual question answering
         pipeline for prediction
@@ -35,18 +33,12 @@ class VisualQuestionAnsweringPipeline(Pipeline):
         """
         model = model if isinstance(model,
                                     Model) else Model.from_pretrained(model)
-        self.tokenizer = None
         if preprocessor is None:
             if isinstance(model, OfaForAllTasks):
                 preprocessor = OfaPreprocessor(model.model_dir)
-            elif isinstance(model, MPlugForVisualQuestionAnswering):
-                preprocessor = MPlugVisualQuestionAnsweringPreprocessor(
-                    model.model_dir)
-        if isinstance(model, MPlugForVisualQuestionAnswering):
-            model.eval()
-            self.tokenizer = model.tokenizer
-        else:
-            model.model.eval()
+            elif isinstance(model, MPlugForAllTasks):
+                preprocessor = MPlugPreprocessor(model.model_dir)
+        model.model.eval()
         super().__init__(model=model, preprocessor=preprocessor, **kwargs)

     def forward(self, inputs: Dict[str, Any],
@@ -64,14 +56,6 @@ class VisualQuestionAnsweringPipeline(Pipeline):
         Returns:
             Dict[str, str]: the prediction results
         """
-        if self.tokenizer is None:
+        if isinstance(self.model, OfaForAllTasks):
             return inputs
-        replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
-                               ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
-                               ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))
-
-        pred_string = self.tokenizer.decode(inputs[0][0])
-        for _old, _new in replace_tokens_bert:
-            pred_string = pred_string.replace(_old, _new)
-        pred_string.strip()
-        return {OutputKeys.TEXT: pred_string}
+        return {OutputKeys.TEXT: inputs}
diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py
index c5c6a33c..0328b91a 100644
--- a/modelscope/preprocessors/__init__.py
+++ b/modelscope/preprocessors/__init__.py
@@ -14,8 +14,7 @@ if TYPE_CHECKING:
         ImageInstanceSegmentationPreprocessor, ImageDenoisePreprocessor)
     from .kws import WavToLists
-    from .multi_modal import (OfaPreprocessor,
-                              MPlugVisualQuestionAnsweringPreprocessor)
+    from .multi_modal import (OfaPreprocessor, MPlugPreprocessor)
     from .nlp import (Tokenize, SequenceClassificationPreprocessor,
                       TextGenerationPreprocessor,
                       TokenClassificationPreprocessor,
@@ -42,8 +41,7 @@ else:
             'ImageInstanceSegmentationPreprocessor', 'ImageDenoisePreprocessor'
         ],
         'kws': ['WavToLists'],
-        'multi_modal':
-        ['OfaPreprocessor', 'MPlugVisualQuestionAnsweringPreprocessor'],
+        'multi_modal': ['OfaPreprocessor', 'MPlugPreprocessor'],
         'nlp': [
             'Tokenize', 'SequenceClassificationPreprocessor',
             'TextGenerationPreprocessor', 'TokenClassificationPreprocessor',
diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py
index 7665e8b7..56b10c3a 100644
--- a/modelscope/preprocessors/multi_modal.py
+++ b/modelscope/preprocessors/multi_modal.py
@@ -17,7 +17,7 @@ from .ofa.utils.collate import collate_fn

 __all__ = [
     'OfaPreprocessor',
-    'MPlugVisualQuestionAnsweringPreprocessor',
+    'MPlugPreprocessor',
 ]


@@ -88,39 +88,55 @@ class OfaPreprocessor(Preprocessor):


 @PREPROCESSORS.register_module(
-    Fields.multi_modal,
-    module_name=Preprocessors.mplug_visual_question_answering)
-class MPlugVisualQuestionAnsweringPreprocessor(Preprocessor):
+    Fields.multi_modal, module_name=Preprocessors.mplug_tasks_preprocessor)
+class MPlugPreprocessor(Preprocessor):

     def __init__(self, model_dir: str, *args, **kwargs):
-        """preprocess the data via 'bert-base-uncased' tokenizer and configuration
-
-        """
-        from transformers import BertTokenizer
-        from modelscope.models.multi_modal.mplug import CONFIG_NAME, MPlugConfig
-
         super().__init__(*args, **kwargs)
+        self.model_dir = model_dir

-        # tokenizer
-        self.tokenizer = BertTokenizer.from_pretrained(
-            osp.join(model_dir, ModelFile.VOCAB_FILE))
+        self._tokenizer = None
+        self._patch_resize_transform = None

-        # load configuration
-        config = MPlugConfig.from_yaml_file(osp.join(model_dir, CONFIG_NAME))
+    @property
+    def tokenizer(self):
+        from transformers import BertTokenizer

-        # Initialize transform
-        from torchvision import transforms
-        mean = (0.48145466, 0.4578275, 0.40821073)
-        std = (0.26862954, 0.26130258, 0.27577711)
+        if self._tokenizer is None:
+            self._tokenizer = BertTokenizer.from_pretrained(self.model_dir)
+        return self._tokenizer
+
+    @property
+    def patch_resize_transform(self):
+        if self._patch_resize_transform is None:
+            from torchvision import transforms
+            from modelscope.models.multi_modal.mplug import CONFIG_NAME, MPlugConfig
+
+            config = MPlugConfig.from_yaml_file(
+                osp.join(self.model_dir, CONFIG_NAME))
+
+            mean = (0.48145466, 0.4578275, 0.40821073)
+            std = (0.26862954, 0.26130258, 0.27577711)
+
+            self._patch_resize_transform = transforms.Compose([
+                transforms.Resize((config.image_res, config.image_res),
+                                  interpolation=Image.BICUBIC),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=mean, std=std),
+            ])
+        return self._patch_resize_transform
+
+    def __call__(self, *args, **kwargs):
+        call_mapping = {
+            Tasks.visual_question_answering: self.vqa_call,
+            Tasks.image_captioning: self.caption_call
+        }

-        self.patch_resize_transform = transforms.Compose([
-            transforms.Resize((config.image_res, config.image_res),
-                              interpolation=Image.BICUBIC),
-            transforms.ToTensor(),
-            transforms.Normalize(mean=mean, std=std),
-        ])
+        self.cfg = Config.from_file(
+            osp.join(self.model_dir, ModelFile.CONFIGURATION))
+        return call_mapping[self.cfg.task](*args, **kwargs)

-    def __call__(self, data: Union[tuple, Dict[str, Any]]) -> Dict[str, Any]:
+    def vqa_call(self, data: Union[tuple, Dict[str, Any]]) -> Dict[str, Any]:
         image: Image.Image = data[0] if isinstance(data,
                                                    tuple) else data['image']
         question: str = data[1] if isinstance(data,
@@ -133,3 +149,19 @@ class MPlugVisualQuestionAnsweringPreprocessor(Preprocessor):
             return_tensors='pt')

         return {'image': image, 'question': question, 'train': False}
+
+    def caption_call(
+            self, data: Union[Image.Image, tuple,
+                              Dict[str, Any]]) -> Dict[str, Any]:
+        if isinstance(data, Image.Image):
+            image = data
+        elif isinstance(data, tuple):
+            image = data[0]
+        else:
+            image = data['image']
+        image = image.convert('RGB')
+        image = self.patch_resize_transform(image)
+        image = torch.stack([image], dim=0)
+        question = self.tokenizer('', return_tensors='pt')
+
+        return {'image': image, 'question': question, 'train': False}
diff --git a/tests/pipelines/test_mplug_tasks.py b/tests/pipelines/test_mplug_tasks.py
new file mode 100644
index 00000000..4b8a813a
--- /dev/null
+++ b/tests/pipelines/test_mplug_tasks.py
@@ -0,0 +1,59 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from PIL import Image
+
+from modelscope.models import Model
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class MplugTasksTest(unittest.TestCase):
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_image_captioning_with_model(self):
+        model = Model.from_pretrained(
+            'damo/mplug_image-captioning_coco_base_en')
+        pipeline_caption = pipeline(
+            task=Tasks.image_captioning,
+            model=model,
+        )
+        image = Image.open('data/test/images/image_mplug_vqa.jpg')
+        result = pipeline_caption({'image': image})
+        print(result[OutputKeys.CAPTION])
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_image_captioning_with_name(self):
+        pipeline_caption = pipeline(
+            Tasks.image_captioning,
+            model='damo/mplug_image-captioning_coco_base_en')
+        image = Image.open('data/test/images/image_mplug_vqa.jpg')
+        result = pipeline_caption({'image': image})
+        print(result[OutputKeys.CAPTION])
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_visual_question_answering_with_model(self):
+        model = Model.from_pretrained(
+            'damo/mplug_visual-question-answering_coco_large_en')
+        pipeline_vqa = pipeline(Tasks.visual_question_answering, model=model)
+        image = Image.open('data/test/images/image_mplug_vqa.jpg')
+        question = 'What is the woman doing?'
+        input = {'image': image, 'question': question}
+        result = pipeline_vqa(input)
+        print(result)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_visual_question_answering_with_name(self):
+        model = 'damo/mplug_visual-question-answering_coco_large_en'
+        pipeline_vqa = pipeline(Tasks.visual_question_answering, model=model)
+        image = Image.open('data/test/images/image_mplug_vqa.jpg')
+        question = 'What is the woman doing?'
+        input = {'image': image, 'question': question}
+        result = pipeline_vqa(input)
+        print(result)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/pipelines/test_visual_question_answering.py b/tests/pipelines/test_visual_question_answering.py
deleted file mode 100644
index 748a86b9..00000000
--- a/tests/pipelines/test_visual_question_answering.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-
-import unittest
-
-from PIL import Image
-
-from modelscope.hub.snapshot_download import snapshot_download
-from modelscope.models import Model
-from modelscope.models.multi_modal import MPlugForVisualQuestionAnswering
-from modelscope.pipelines import pipeline
-from modelscope.pipelines.multi_modal import VisualQuestionAnsweringPipeline
-from modelscope.preprocessors import MPlugVisualQuestionAnsweringPreprocessor
-from modelscope.utils.constant import Tasks
-from modelscope.utils.test_utils import test_level
-
-
-class VisualQuestionAnsweringTest(unittest.TestCase):
-
-    def setUp(self):
-        self.model_id = 'damo/mplug_visual-question-answering_coco_large_en'
-        self.input_vqa = {
-            'image': Image.open('data/test/images/image_mplug_vqa.jpg'),
-            'question': 'What is the woman doing?',
-        }
-
-    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
-    def test_run(self):
-        cache_path = snapshot_download(self.model_id)
-        preprocessor = MPlugVisualQuestionAnsweringPreprocessor(cache_path)
-        model = MPlugForVisualQuestionAnswering(cache_path)
-        pipeline1 = VisualQuestionAnsweringPipeline(
-            model, preprocessor=preprocessor)
-        pipeline2 = pipeline(
-            Tasks.visual_question_answering,
-            model=model,
-            preprocessor=preprocessor)
-        print(f"question: {self.input_vqa['question']}")
-        print(f'pipeline1: {pipeline1(self.input_vqa)}')
-        print(f'pipeline2: {pipeline2(self.input_vqa)}')
-
-    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
-    def test_run_with_model_from_modelhub(self):
-        model = Model.from_pretrained(self.model_id)
-        preprocessor = MPlugVisualQuestionAnsweringPreprocessor(
-            model.model_dir)
-        pipeline_vqa = pipeline(
-            task=Tasks.visual_question_answering,
-            model=model,
-            preprocessor=preprocessor)
-        print(pipeline_vqa(self.input_vqa))
-
-    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
-    def test_run_with_model_name(self):
-        pipeline_vqa = pipeline(
-            Tasks.visual_question_answering, model=self.model_id)
-        print(pipeline_vqa(self.input_vqa))
-
-    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
-    def test_run_with_default_model(self):
-        pipeline_vqa = pipeline(task=Tasks.visual_question_answering)
-        print(pipeline_vqa(self.input_vqa))
-
-
-if __name__ == '__main__':
-    unittest.main()
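Usage sketch (not part of the patch): a minimal illustration of how the unified MPlugForAllTasks model is driven through the two pipelines this change touches. It only reuses names that already appear above (pipeline, Tasks, OutputKeys, and the model ids from tests/pipelines/test_mplug_tasks.py); the test image path is the one used by the new tests and should be adjusted to a local file.

# Illustrative only; mirrors tests/pipelines/test_mplug_tasks.py from the diff above.
from PIL import Image

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

image = Image.open('data/test/images/image_mplug_vqa.jpg')

# Image captioning: ImageCaptioningPipeline picks MPlugPreprocessor for MPlugForAllTasks
# and its postprocess() wraps the decoded string as {OutputKeys.CAPTION: ...}.
captioner = pipeline(
    Tasks.image_captioning,
    model='damo/mplug_image-captioning_coco_base_en')
print(captioner({'image': image})[OutputKeys.CAPTION])

# Visual question answering: the same model class, registered for both tasks;
# VisualQuestionAnsweringPipeline returns {OutputKeys.TEXT: ...} for MPlug models.
vqa = pipeline(
    Tasks.visual_question_answering,
    model='damo/mplug_visual-question-answering_coco_large_en')
print(vqa({'image': image, 'question': 'What is the woman doing?'})[OutputKeys.TEXT])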