
add 6 ofa tasks

Branch: master
yichang.zyc, 3 years ago
commit d1d2c96dd9
41 changed files with 2290 additions and 270 deletions
  1. +3 -0    data/test/images/image_classification.png
  2. +3 -0    data/test/images/visual_grounding.png
  3. +3 -0    data/test/images/visual_question_answering.png
  4. +4 -0    modelscope/metainfo.py
  5. +1 -2    modelscope/models/multi_modal/__init__.py
  6. +0 -86   modelscope/models/multi_modal/image_captioning_model.py
  7. +4 -43   modelscope/models/multi_modal/ofa/generate/sequence_generator.py
  8. +13 -0   modelscope/models/multi_modal/ofa/utils/constant.py
  9. +19 -0   modelscope/models/multi_modal/ofa/utils/utils.py
  10. +259 -0  modelscope/models/multi_modal/ofa_for_all_tasks.py
  11. +0 -53   modelscope/models/multi_modal/ofa_for_image_captioning_model.py
  12. +2 -1    modelscope/pipelines/cv/__init__.py
  13. +32 -6   modelscope/pipelines/cv/image_classification_pipeline.py
  14. +6 -2    modelscope/pipelines/multi_modal/__init__.py
  15. +7 -5    modelscope/pipelines/multi_modal/image_captioning_pipeline.py
  16. +42 -0   modelscope/pipelines/multi_modal/visual_entailment_pipeline.py
  17. +42 -0   modelscope/pipelines/multi_modal/visual_grounding_pipeline.py
  18. +11 -5   modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py
  19. +4 -0    modelscope/pipelines/nlp/__init__.py
  20. +42 -0   modelscope/pipelines/nlp/summarization_pipeline.py
  21. +42 -0   modelscope/pipelines/nlp/text_classification_pipeline.py
  22. +3 -5    modelscope/preprocessors/__init__.py
  23. +25 -39  modelscope/preprocessors/multi_modal.py
  24. +8 -0    modelscope/preprocessors/ofa/__init__.py
  25. +117 -0  modelscope/preprocessors/ofa/base.py
  26. +42 -0   modelscope/preprocessors/ofa/image_captioning.py
  27. +43 -0   modelscope/preprocessors/ofa/image_classification.py
  28. +37 -0   modelscope/preprocessors/ofa/summarization.py
  29. +38 -0   modelscope/preprocessors/ofa/text_classification.py
  30. +0 -0    modelscope/preprocessors/ofa/utils/__init__.py
  31. +109 -0  modelscope/preprocessors/ofa/utils/collate.py
  32. +42 -0   modelscope/preprocessors/ofa/utils/random_help.py
  33. +557 -0  modelscope/preprocessors/ofa/utils/transforms.py
  34. +357 -0  modelscope/preprocessors/ofa/utils/vision_helper.py
  35. +62 -0   modelscope/preprocessors/ofa/visual_entailment.py
  36. +50 -0   modelscope/preprocessors/ofa/visual_grounding.py
  37. +52 -0   modelscope/preprocessors/ofa/visual_question_answering.py
  38. +1 -0    modelscope/utils/constant.py
  39. +29 -0   modelscope/utils/trie.py
  40. +0 -23   tests/pipelines/test_image_captioning.py
  41. +179 -0  tests/pipelines/test_ofa_tasks.py

+3 -0  data/test/images/image_classification.png

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8bdb9627c3a40897e84ee186b2a959f272790571644224e1d2efca443f867e12
size 202823

+3 -0  data/test/images/visual_grounding.png

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8b89734b9c9d89342e58fbe406d3b9bdc8e07447cb170a4ae2743000471fc969
size 23069

+3 -0  data/test/images/visual_question_answering.png

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d53e9fbdd129b234dcbec9b9fe6a15a0e05820e802a873f95955574267bbd2ff
size 121141

+4 -0  modelscope/metainfo.py

@@ -69,6 +69,7 @@ class Pipelines(object):
action_recognition = 'TAdaConv_action-recognition'
animal_recognation = 'resnet101-animal_recog'
cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
image_classification = 'image-classification'
face_detection = 'resnet-face-detection-scrfd10gkps'
live_category = 'live-category'
general_image_classification = 'vit-base_image-classification_ImageNet-labels'
@@ -92,6 +93,7 @@ class Pipelines(object):
text_generation = 'text-generation'
sentiment_analysis = 'sentiment-analysis'
sentiment_classification = 'sentiment-classification'
text_classification = 'text-classification'
fill_mask = 'fill-mask'
csanmt_translation = 'csanmt-translation'
nli = 'nli'
@@ -113,6 +115,8 @@ class Pipelines(object):
multi_modal_embedding = 'multi-modal-embedding'
generative_multi_modal_embedding = 'generative-multi-modal-embedding'
visual_question_answering = 'visual-question-answering'
visual_grounding = 'visual-grounding'
visual_entailment = 'visual-entailment'
text_to_image_synthesis = 'text-to-image-synthesis'
video_multi_modal_embedding = 'video-multi-modal-embedding'



+1 -2  modelscope/models/multi_modal/__init__.py

@@ -11,7 +11,6 @@ if TYPE_CHECKING:
from .mmr import VideoCLIPForMultiModalEmbedding
from .mplug_for_visual_question_answering import \
MPlugForVisualQuestionAnswering
from .ofa_for_image_captioning_model import OfaForImageCaptioning

else:
_import_structure = {
@@ -21,7 +20,7 @@ else:
'mmr': ['VideoCLIPForMultiModalEmbedding'],
'mplug_for_visual_question_answering':
['MPlugForVisualQuestionAnswering'],
'ofa_for_image_captioning_model': ['OfaForImageCaptioning']
'ofa_for_all_tasks': ['OfaForAllTasks']
}

import sys


+0 -86  modelscope/models/multi_modal/image_captioning_model.py

@@ -1,86 +0,0 @@
import os.path as osp
from typing import Any, Dict

import torch.cuda
from PIL import Image

from modelscope.metainfo import Models
from modelscope.models.base import Model
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks

__all__ = ['OfaForImageCaptioning']


@MODELS.register_module(Tasks.image_captioning, module_name=Models.ofa)
class OfaForImageCaptioning(Model):

def __init__(self, model_dir, *args, **kwargs):
super().__init__(model_dir=model_dir, *args, **kwargs)
ckpt_name = ModelFile.TORCH_MODEL_FILE
local_model = osp.join(model_dir, ckpt_name)
bpe_dir = model_dir
# turn on cuda if GPU is available
from fairseq import checkpoint_utils, tasks, utils
from ofa.tasks.mm_tasks import CaptionTask
from ofa.utils.eval_utils import eval_caption
self.eval_caption = eval_caption
tasks.register_task('caption', CaptionTask)
if torch.cuda.is_available():
self._device = torch.device('cuda')
else:
self._device = torch.device('cpu')
self.use_fp16 = kwargs[
'use_fp16'] if 'use_fp16' in kwargs and torch.cuda.is_available()\
else False
overrides = {
'bpe_dir': bpe_dir,
'eval_cider': False,
'beam': 5,
'max_len_b': 16,
'no_repeat_ngram_size': 3,
'seed': 7
}
models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
utils.split_paths(local_model), arg_overrides=overrides)
# Move models to GPU
for model in models:
model.eval()
model.to(self._device)
if self.use_fp16:
model.half()
model.prepare_for_inference_(cfg)
self.models = models
# Initialize generator
self.generator = task.build_generator(models, cfg.generation)

# Initialize transform
from torchvision import transforms
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]

self.patch_resize_transform = transforms.Compose([
lambda image: image.convert('RGB'),
transforms.Resize(
(cfg.task.patch_image_size, cfg.task.patch_image_size),
interpolation=Image.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
])
self.task = task

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
import fairseq.utils
if torch.cuda.is_available():
input = fairseq.utils.move_to_cuda(input, device=self._device)
results, _ = self.eval_caption(self.task, self.generator, self.models,
input)
from modelscope.outputs import OutputKeys
return {
'image_id': results[0]['image_id'],
OutputKeys.CAPTION: results[0][OutputKeys.CAPTION]
}

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
# What should we do here ?
return inputs

+4 -43  modelscope/models/multi_modal/ofa/generate/sequence_generator.py

@@ -194,13 +194,6 @@ class SequenceGenerator(nn.Module):
bos_token: Optional[int] = None,
):
model = EnsembleModel(models)
# incremental_states = torch.jit.annotate(
# List[Dict[str, Dict[str, Optional[Tensor]]]],
# [
# torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
# for i in range(model.models_size)
# ],
# )
incremental_states = torch.jit.annotate(
List[Tuple[Tuple[torch.Tensor]]],
[
@@ -208,8 +201,6 @@ class SequenceGenerator(nn.Module):
for i in range(model.models_size)
],
)
# print("incremental_states",incremental_states)
# print("incremental_states[0]",incremental_states[0])
net_input = sample['net_input']

if 'src_tokens' in net_input:
@@ -281,7 +272,6 @@ class SequenceGenerator(nn.Module):
tokens = (torch.zeros(bsz * beam_size,
max_len + 2).to(src_tokens).long().fill_(
self.pad)) # +2 for eos and pad
# tokens[:, 0] = self.eos if bos_token is None else bos_token
tokens[:, 0] = self.bos
attn: Optional[Tensor] = None

@@ -335,7 +325,7 @@ class SequenceGenerator(nn.Module):
corr.unsqueeze(-1) * beam_size)
original_batch_idxs = original_batch_idxs[batch_idxs]
model.reorder_incremental_state(incremental_states,
reorder_state) # todo
reorder_state)
encoder_outs = model.reorder_encoder_out(
encoder_outs, reorder_state)

@@ -479,7 +469,6 @@ class SequenceGenerator(nn.Module):
batch_mask = torch.ones(
bsz, dtype=torch.bool, device=cand_indices.device)
batch_mask[finalized_sents] = False
# TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
batch_idxs = torch.arange(
bsz, device=cand_indices.device).masked_select(batch_mask)

@@ -833,7 +822,7 @@ class EnsembleModel(nn.Module):

# decode each model
if self.has_incremental_states():
decoder_out = model.decoder.forward( # todo: model inputs differ
decoder_out = model.decoder.forward(
input_ids=tokens,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
@@ -846,7 +835,7 @@ class EnsembleModel(nn.Module):
else:
if hasattr(model, 'decoder'):
# decoder_out = model.decoder.forward(tokens, code_masks=code_mask, encoder_out=encoder_out)
decoder_out = model.decoder.forward( # todo: model inputs differ
decoder_out = model.decoder.forward(
input_ids=tokens,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
@@ -855,32 +844,9 @@ class EnsembleModel(nn.Module):
src_pos_embed=src_pos_embed)
else:
decoder_out = model.forward(tokens)
# print('#### decoder_out ####', decoder_out)
# print('#### decoder_out ####', decoder_out.keys())
# for k,v in decoder_out.items():
# print(k)
# if isinstance(v, Tensor):
# print(v.shape)
# elif k == "past_key_values":
# print(len(v))
# print([v[0][i].shape for i in range(len(v[0]))])
# else:
# print(len(v))
# print([v[i].shape for i in range(len(v))])

attn: Optional[Tensor] = None
decoder_len = len(decoder_out)
# if decoder_len > 1 and decoder_out[1] is not None:
# if isinstance(decoder_out[1], Tensor):
# attn = decoder_out[1]
# else:
# attn_holder = decoder_out[1]["attn"]
# if isinstance(attn_holder, Tensor):
# attn = attn_holder
# elif attn_holder is not None:
# attn = attn_holder[0]
# if attn is not None:
# attn = attn[:, -1, :]

if 'cross_attentions' in decoder_out:
attn = decoder_out['cross_attentions'][-1].transpose(1, 0)
@@ -888,11 +854,6 @@ class EnsembleModel(nn.Module):
if attn is not None:
attn = attn[:, -1, :]

# decoder_out_tuple = (
# decoder_out[0][:, -1:, :].div_(temperature),
# None if decoder_len <= 1 else decoder_out[1],
# )

decoder_out_tuple = (
decoder_out[0][:, -1:, :].div_(temperature),
None if decoder_len <= 1 else attn,
@@ -993,5 +954,5 @@ class EnsembleModel(nn.Module):
if not self.has_incremental_states():
return
for i, model in enumerate(self.models):
model.decoder.reorder_incremental_state_scripting( # todo
model.decoder.reorder_incremental_state_scripting(
incremental_states[i], new_order)

+13 -0  modelscope/models/multi_modal/ofa/utils/constant.py

@@ -0,0 +1,13 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks

OFA_TASK_KEY_MAPPING = {
Tasks.image_captioning: OutputKeys.CAPTION,
Tasks.summarization: OutputKeys.TEXT,
Tasks.visual_question_answering: OutputKeys.TEXT,
Tasks.visual_grounding: OutputKeys.BOXES,
Tasks.text_classification: (OutputKeys.SCORES, OutputKeys.LABELS),
Tasks.image_classification: OutputKeys.LABELS,
Tasks.visual_entailment: (OutputKeys.SCORES, OutputKeys.LABELS),
}

+19 -0  modelscope/models/multi_modal/ofa/utils/utils.py

@@ -0,0 +1,19 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Optional

import torch


def expand_mask(mask: torch.Tensor,
dtype: torch.dtype,
tgt_len: Optional[int] = None):
r"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len

expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len,
src_len).to(dtype)
return expanded_mask.masked_fill(expanded_mask.bool(),
torch.finfo(dtype).min)
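A quick shape check for the helper above (a minimal sketch; the import path is the one added by this commit, the sample mask is made up, and nonzero entries act as padded positions because masked_fill targets the truthy entries of the expanded mask):

import torch

from modelscope.models.multi_modal.ofa.utils.utils import expand_mask

# hypothetical 2-sample batch with seq_len 4; 1 marks a padded position
pad_mask = torch.tensor([[0, 0, 1, 1],
                         [0, 0, 0, 1]])
bias = expand_mask(pad_mask, torch.float32)
print(bias.shape)      # torch.Size([2, 1, 4, 4])
print(bias[0, 0, 0])   # padded columns hold torch.finfo(torch.float32).min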

+259 -0  modelscope/models/multi_modal/ofa_for_all_tasks.py

@@ -0,0 +1,259 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
from os import path as osp
from typing import Any, Dict

import json
import torch.cuda
import torch.nn.functional as F

from modelscope.metainfo import Models
from modelscope.models.base import Model, Tensor
from modelscope.models.builder import MODELS
from modelscope.outputs import OutputKeys
from modelscope.preprocessors.ofa.utils.collate import collate_tokens
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile
from modelscope.utils.trie import Trie
from .ofa import OFAModel, OFATokenizer
from .ofa.generate import sequence_generator as sg
from .ofa.generate.utils import move_to_device
from .ofa.utils.constant import OFA_TASK_KEY_MAPPING, Tasks
from .ofa.utils.utils import expand_mask

__all__ = ['OfaForAllTasks']


@MODELS.register_module(Tasks.image_captioning, module_name=Models.ofa)
@MODELS.register_module(Tasks.visual_grounding, module_name=Models.ofa)
@MODELS.register_module(
Tasks.visual_question_answering, module_name=Models.ofa)
@MODELS.register_module(Tasks.visual_entailment, module_name=Models.ofa)
@MODELS.register_module(Tasks.image_classification, module_name=Models.ofa)
@MODELS.register_module(Tasks.summarization, module_name=Models.ofa)
@MODELS.register_module(Tasks.text_classification, module_name=Models.ofa)
class OfaForAllTasks(Model):

def __init__(self, model_dir, *args, **kwargs):
super().__init__(model_dir=model_dir, *args, **kwargs)
model = OFAModel.from_pretrained(model_dir)
self.cfg = Config.from_file(
osp.join(model_dir, ModelFile.CONFIGURATION))
self.model = model.module if hasattr(model, 'module') else model
self.tokenizer = OFATokenizer.from_pretrained(model_dir)
self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)])
self.tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)])
self.cfg.update({'num_bins': 1000, 'num_codes': 8192})
self.batch_size = self.cfg.model.get('batch_size', 1)
self.val_batch_size = self.cfg.model.get('valid_batch_size',
self.batch_size)
self.gen_type = self.cfg.model.get('gen_type', 'generation')
assert self.gen_type in ['generation', 'traverse'], \
'model.gen_type must be in ["generation", "traverse"]'
self._device = torch.device('cuda') if torch.cuda.is_available() \
else torch.device('cpu')
self.eos_item = torch.LongTensor([self.tokenizer.eos_token_id
]).to(self._device)
self.index2ans = {}
self.ans2label_dict = {}
self.load_ans2label()
# Initialize generator
sg_args = {
'tokenizer': self.tokenizer,
'beam_size': 5,
'max_len_b': 16,
'min_len': 1,
'no_repeat_ngram_size': 3,
'constraint_range': None
}
if hasattr(self.cfg.model, 'beam_search'):
sg_args.update(self.cfg.model.beam_search)
if len(self.ans2label_dict) > 0:
self.constraint_trie = Trie(self.tokenizer.eos_token_id)
self.val_ans_l = []
self.val_masks_l = []
self.build_trie()
sg_args['constraint_trie'] = self.constraint_trie
self.model.to(self._device)
self.generator = sg.SequenceGenerator(**sg_args)
inference_d = {
'generation': self._text_gen_inference,
'traverse': self._traverse_inference,
}
self.task_inference_mapping = {
Tasks.image_captioning: self._text_gen_inference,
Tasks.summarization: self._text_gen_inference,
Tasks.visual_grounding: self._visual_grounding_inference,
Tasks.visual_entailment: inference_d[self.gen_type],
Tasks.visual_question_answering: inference_d[self.gen_type],
Tasks.text_classification: inference_d[self.gen_type],
Tasks.image_classification: inference_d[self.gen_type],
}

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
ret = self.task_inference_mapping[self.cfg.task](input)
ret['samples'] = input['samples']
for key in [
OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
OutputKeys.LABELS, OutputKeys.SCORES
]:
if key in ret and len(ret[key]) == 1:
ret[key] = ret[key][0]
if key not in ret:
ret[key] = None
return ret

def postprocess(self, input: Dict[str, Tensor],
**kwargs) -> Dict[str, Tensor]:
return input

def _text_gen_inference(self, input):
input = move_to_device(input, self._device)
gen_output = self.generator.generate([self.model], input)
gen = [gen_output[i][0]['tokens'] for i in range(len(gen_output))]
result = self.tokenizer.batch_decode(gen, skip_special_tokens=True)
# text generation tasks have no score
ret = {OFA_TASK_KEY_MAPPING[self.cfg.task]: result}
if self.cfg.task.endswith('classification'):
ret[OutputKeys.SCORES] = [1.0] * len(result)
return ret

def _visual_grounding_inference(self, input):
input = move_to_device(input, self._device)
gen_output = self.generator.generate([self.model], input)
tokens = [gen_output[i][0]['tokens'] for i in range(len(gen_output))]
region_coord_l = list()
for i in range(len(tokens)):
region_coord_l.append(tokens[i][:-1]
- len(self.tokenizer.get_vocab().items())
+ self.cfg.num_bins)
region_tensor = torch.stack(region_coord_l, dim=0)
region_tensor = region_tensor / (
self.cfg.num_bins - 1) * self.cfg.model.get('max_image_size', 512)
region_tensor[:, ::2] /= input['w_resize_ratios']
region_tensor[:, 1::2] /= input['h_resize_ratios']
return {
OutputKeys.BOXES: move_to_device(region_tensor,
torch.device('cpu')),
OutputKeys.SCORES: [1.0] * region_tensor.shape[0]
}

def _traverse_inference(self, input):
input = move_to_device(input, self._device)
encoder_input = dict()
for key in input['net_input'].keys():
encoder_input[key] = input['net_input'][key]
encoder_out = self.model.encoder(**encoder_input)
valid_result = []
for val_ans, val_masks in zip(self.val_ans_l, self.val_masks_l):
valid_size = len(val_ans)
valid_tgt_items = [
torch.cat([
torch.tensor(decoder_prompt[1:]), valid_answer,
self.eos_item
]) for decoder_prompt in input['decoder_prompts']
for valid_answer in val_ans
]
valid_prev_items = [
torch.cat([torch.tensor(decoder_prompt), valid_answer])
for decoder_prompt in input['decoder_prompts']
for valid_answer in val_ans
]
valid_constraint_mask_items = [
torch.cat([
torch.zeros(
len(decoder_prompt) - 1,
valid_constraint_mask.size(1)).bool().to(self._device),
valid_constraint_mask], dim=0) # yapf: disable
for decoder_prompt in input['decoder_prompts'] # yapf: disable
for valid_constraint_mask in val_masks] # yapf: disable
valid_tgt = collate_tokens(
valid_tgt_items,
pad_idx=self.tokenizer.pad_token_id).to(self._device)
valid_prev_output = collate_tokens(
valid_prev_items,
pad_idx=self.tokenizer.pad_token_id).to(self._device)
val_masks = collate_tokens(
valid_constraint_mask_items,
pad_idx=self.tokenizer.pad_token_id).to(self._device)
new_encoder_out = {
'last_hidden_state':
encoder_out['last_hidden_state'].repeat_interleave(
valid_size, dim=0),
'padding_mask':
encoder_out['padding_mask'].repeat_interleave(
valid_size, dim=0),
'position_embedding':
encoder_out['position_embedding'].repeat_interleave(
valid_size, dim=0)
}
encoder_attention_mask = expand_mask(
new_encoder_out['padding_mask'],
new_encoder_out['last_hidden_state'].dtype,
valid_prev_output.shape[-1])

decoder_out = self.model.decoder(
valid_prev_output,
encoder_hidden_states=new_encoder_out['last_hidden_state'],
encoder_attention_mask=encoder_attention_mask,
src_pos_embed=new_encoder_out['position_embedding'])

decoder_out[0].masked_fill_(~val_masks, -math.inf)
lprobs = self.model.get_normalized_probs(
decoder_out, log_probs=True)
scores = lprobs.gather(
dim=-1, index=valid_tgt.unsqueeze(-1)).squeeze(-1)
scores = scores.masked_fill(
valid_tgt.eq(self.tokenizer.pad_token_id), 0)
scores = scores.masked_fill((~val_masks).all(2), 0)
scores = scores.sum(1)
scores = scores.view(-1, valid_size)
valid_result.append(scores)
valid_result = torch.cat(valid_result, dim=-1)
predicts = valid_result.argmax(1).tolist()
probs = F.softmax(valid_result, dim=-1)
hyps = [self.index2ans[predict_index] for predict_index in predicts]
scores = [
float(prob[idx].cpu().detach().numpy())
for prob, idx in zip(probs, predicts)
]
return {OutputKeys.LABELS: hyps, OutputKeys.SCORES: scores}

def build_trie(self):
answer_item_list = []

for i, answer in enumerate(self.ans2label_dict.keys()):
answer_item = self.tokenizer(
' ' + answer, return_tensors='pt',
add_special_tokens=False).input_ids.squeeze(0)
answer_item_list.append(answer_item)
self.index2ans[i] = answer
self.constraint_trie.insert([self.tokenizer.bos_token_id]
+ answer_item.tolist()
+ [self.tokenizer.eos_token_id])

constraint_mask_list = []
for answer_item in answer_item_list:
constraint_mask = torch.zeros(
(len(answer_item) + 1,
len(self.tokenizer.get_vocab()))).bool()
for i in range(len(answer_item) + 1):
constraint_prefix_token = [self.tokenizer.bos_token_id
] + answer_item[:i].tolist()
constraint_nodes = self.constraint_trie.get_next_layer(
constraint_prefix_token)
constraint_mask[i][constraint_nodes] = True
constraint_mask_list.append(constraint_mask)

for i in range(0, len(answer_item_list), self.val_batch_size):
self.val_ans_l += [answer_item_list[i:i + self.val_batch_size]]
self.val_masks_l += [
constraint_mask_list[i:i + self.val_batch_size]
]
self.val_ans_l = move_to_device(self.val_ans_l, self._device)
self.val_masks_l = move_to_device(self.val_masks_l, self._device)

def load_ans2label(self):
if self.cfg.model.get('answer2label', None):
filename = osp.join(self.model_dir, self.cfg.model.answer2label)
self.ans2label_dict = json.load(open(filename))
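For reference, a minimal sketch of calling the unified model directly, outside any pipeline. The checkpoint directory and image path are placeholders, and it assumes the checkpoint's configuration.json names the image_captioning task; the dict key follows OfaImageCaptioningPreprocessor added later in this commit.

from modelscope.models.base import Model
from modelscope.outputs import OutputKeys
from modelscope.preprocessors import OfaPreprocessor

model = Model.from_pretrained('/path/to/ofa_image-caption_checkpoint')  # hypothetical local dir
preprocessor = OfaPreprocessor(model_dir=model.model_dir)
sample = preprocessor({'image': '/path/to/test.jpg'})                   # placeholder image
print(model(sample)[OutputKeys.CAPTION])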

+0 -53  modelscope/models/multi_modal/ofa_for_image_captioning_model.py

@@ -1,53 +0,0 @@
from typing import Any, Dict

import torch.cuda

from modelscope.metainfo import Models
from modelscope.models.base import Model
from modelscope.models.builder import MODELS
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks
from .ofa import OFAModel, OFATokenizer
from .ofa.generate import sequence_generator as sg
from .ofa.generate.utils import move_to_device

__all__ = ['OfaForImageCaptioning']


@MODELS.register_module(Tasks.image_captioning, module_name=Models.ofa)
class OfaForImageCaptioning(Model):

def __init__(self, model_dir, *args, **kwargs):
super().__init__(model_dir=model_dir, *args, **kwargs)
model = OFAModel.from_pretrained(model_dir)

self.model = model.module if hasattr(model, 'module') else model
self.tokenizer = OFATokenizer.from_pretrained(model_dir)
self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)])
self.tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)])
self._device = torch.device('cuda') if torch.cuda.is_available() \
else torch.device('cpu')
self.model.to(self._device)
# Initialize generator
sg_args = {
'tokenizer': self.tokenizer,
'beam_size': 5,
'max_len_b': 16,
'min_len': 1,
'no_repeat_ngram_size': 3,
'constraint_range': None
}
if hasattr(kwargs, 'beam_search'):
sg_args.update(kwargs['beam_search'])
self.generator = sg.SequenceGenerator(**sg_args)

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
input = move_to_device(input, self._device)
gen_output = self.generator.generate([self.model], input)
gen = [gen_output[i][0]['tokens'] for i in range(len(gen_output))]
result = self.tokenizer.batch_decode(gen, skip_special_tokens=True)
return {'image_id': '42', OutputKeys.CAPTION: result[0]}

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
# What should we do here ?
return inputs

+2 -1  modelscope/pipelines/cv/__init__.py

@@ -24,6 +24,7 @@ if TYPE_CHECKING:
from .ocr_detection_pipeline import OCRDetectionPipeline
from .video_category_pipeline import VideoCategoryPipeline
from .virtual_tryon_pipeline import VirtualTryonPipeline
from .image_classification_pipeline import ImageClassificationPipeline
else:
_import_structure = {
'action_recognition_pipeline': ['ActionRecognitionPipeline'],
@@ -33,7 +34,7 @@ else:
'face_image_generation_pipeline': ['FaceImageGenerationPipeline'],
'face_recognition_pipeline': ['FaceRecognitionPipeline'],
'image_classification_pipeline':
['GeneralImageClassificationPipeline'],
['GeneralImageClassificationPipeline', 'ImageClassificationPipeline'],
'image_cartoon_pipeline': ['ImageCartoonPipeline'],
'image_denoise_pipeline': ['ImageDenoisePipeline'],
'image_color_enhance_pipeline': ['ImageColorEnhancePipeline'],


+32 -6  modelscope/pipelines/cv/image_classification_pipeline.py

@@ -1,4 +1,5 @@
from typing import Any, Dict
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Union

import cv2
import numpy as np
@@ -7,16 +8,41 @@ import torch

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input
from modelscope.preprocessors import load_image
from modelscope.pipelines.base import Input, Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import OfaPreprocessor, Preprocessor, load_image
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from ..base import Pipeline
from ..builder import PIPELINES

logger = get_logger()


@PIPELINES.register_module(
Tasks.image_classification, module_name=Pipelines.image_classification)
class ImageClassificationPipeline(Pipeline):

def __init__(self,
model: Union[Model, str],
preprocessor: [Preprocessor] = None,
**kwargs):
super().__init__(model=model)
assert isinstance(model, str) or isinstance(model, Model), \
'model must be a single str or OfaForAllTasks'
if isinstance(model, str):
pipe_model = Model.from_pretrained(model)
elif isinstance(model, Model):
pipe_model = model
else:
raise NotImplementedError
pipe_model.model.eval()
if preprocessor is None and pipe_model:
preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir)
super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs


@PIPELINES.register_module(
Tasks.image_classification_imagenet,
module_name=Pipelines.general_image_classification)
@@ -27,7 +53,7 @@ class GeneralImageClassificationPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
"""
use `model` and `preprocessor` to create a kws pipeline for prediction
use `model` and `preprocessor` to create an image classification pipeline for prediction
Args:
model: model id on modelscope hub.
"""


+6 -2  modelscope/pipelines/multi_modal/__init__.py

@@ -5,7 +5,9 @@ from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .generative_multi_modal_embedding_pipeline import GEMMMultiModalEmbeddingPipeline
from .image_captioning_pipeline import ImageCaptionPipeline
from .image_captioning_pipeline import ImageCaptioningPipeline
from .visual_entailment_pipeline import VisualEntailmentPipeline
from .visual_grounding_pipeline import VisualGroundingPipeline
from .multi_modal_embedding_pipeline import MultiModalEmbeddingPipeline
from .text_to_image_synthesis_pipeline import TextToImageSynthesisPipeline
from .video_multi_modal_embedding_pipeline import \
@@ -14,7 +16,9 @@ if TYPE_CHECKING:

else:
_import_structure = {
'image_captioning_pipeline': ['ImageCaptionPipeline'],
'image_captioning_pipeline': ['ImageCaptioningPipeline'],
'visual_entailment_pipeline': ['VisualEntailmentPipeline'],
'visual_grounding_pipeline': ['VisualGroundingPipeline'],
'multi_modal_embedding_pipeline': ['MultiModalEmbeddingPipeline'],
'text_to_image_synthesis_pipeline': ['TextToImageSynthesisPipeline'],
'visual_question_answering_pipeline':


+7 -5  modelscope/pipelines/multi_modal/image_captioning_pipeline.py

@@ -1,9 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Optional, Union

from modelscope.metainfo import Pipelines
from modelscope.pipelines.base import Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import OfaImageCaptionPreprocessor, Preprocessor
from modelscope.preprocessors import OfaPreprocessor, Preprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

@@ -12,28 +13,29 @@ logger = get_logger()

@PIPELINES.register_module(
Tasks.image_captioning, module_name=Pipelines.image_captioning)
class ImageCaptionPipeline(Pipeline):
class ImageCaptioningPipeline(Pipeline):

def __init__(self,
model: Union[Model, str],
preprocessor: Optional[Preprocessor] = None,
**kwargs):
"""
use `model` and `preprocessor` to create a kws pipeline for prediction
use `model` and `preprocessor` to create an image captioning pipeline for prediction
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model)
assert isinstance(model, str) or isinstance(model, Model), \
'model must be a single str or OfaForImageCaptioning'
'model must be a single str or OfaForAllTasks'
if isinstance(model, str):
pipe_model = Model.from_pretrained(model)
elif isinstance(model, Model):
pipe_model = model
else:
raise NotImplementedError
pipe_model.model.eval()
if preprocessor is None and pipe_model:
preprocessor = OfaImageCaptionPreprocessor(model_dir=model)
preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir)
super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:


+42 -0  modelscope/pipelines/multi_modal/visual_entailment_pipeline.py

@@ -0,0 +1,42 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Union

from modelscope.metainfo import Pipelines
from modelscope.pipelines.base import Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import OfaPreprocessor, Preprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
Tasks.visual_entailment, module_name=Pipelines.visual_entailment)
class VisualEntailmentPipeline(Pipeline):

def __init__(self,
model: Union[Model, str],
preprocessor: [Preprocessor] = None,
**kwargs):
"""
use `model` and `preprocessor` to create a visual entailment pipeline for prediction
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model)
assert isinstance(model, str) or isinstance(model, Model), \
'model must be a single str or OfaForAllTasks'
if isinstance(model, str):
pipe_model = Model.from_pretrained(model)
elif isinstance(model, Model):
pipe_model = model
else:
raise NotImplementedError
pipe_model.model.eval()
if preprocessor is None and pipe_model:
preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir)
super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs

+42 -0  modelscope/pipelines/multi_modal/visual_grounding_pipeline.py

@@ -0,0 +1,42 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Union

from modelscope.metainfo import Pipelines
from modelscope.pipelines.base import Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import OfaPreprocessor, Preprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
Tasks.visual_grounding, module_name=Pipelines.visual_grounding)
class VisualGroundingPipeline(Pipeline):

def __init__(self,
model: Union[Model, str],
preprocessor: [Preprocessor] = None,
**kwargs):
"""
use `model` and `preprocessor` to create a visual grounding pipeline for prediction
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model)
assert isinstance(model, str) or isinstance(model, Model), \
'model must be a single str or OfaForAllTasks'
if isinstance(model, str):
pipe_model = Model.from_pretrained(model)
elif isinstance(model, Model):
pipe_model = model
else:
raise NotImplementedError
pipe_model.model.eval()
if preprocessor is None and pipe_model:
preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir)
super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs
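A usage sketch for the grounding pipeline; the model id is a placeholder and the input keys are an assumption here, since the body of modelscope/preprocessors/ofa/visual_grounding.py is not shown in this excerpt.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

grounding = pipeline(Tasks.visual_grounding, model='<ofa_visual-grounding_model_id>')
# 'image'/'text' keys are assumed; adjust to whatever the OFA grounding preprocessor reads
result = grounding({'image': 'data/test/images/visual_grounding.png', 'text': 'the dog on the left'})
print(result)   # boxes mapped back to the original image via the w/h resize ratios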

+11 -5  modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py

@@ -1,3 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Optional, Union

import torch
@@ -30,15 +31,18 @@ class VisualQuestionAnsweringPipeline(Pipeline):
model (MPlugForVisualQuestionAnswering): a model instance
preprocessor (MPlugVisualQuestionAnsweringPreprocessor): a preprocessor instance
"""
model = model if isinstance(
model,
MPlugForVisualQuestionAnswering) else Model.from_pretrained(model)
model = model if isinstance(model,
Model) else Model.from_pretrained(model)
self.tokenizer = None
if preprocessor is None:
preprocessor = MPlugVisualQuestionAnsweringPreprocessor(
model.model_dir)
model.eval()
if isinstance(model, MPlugForVisualQuestionAnswering):
model.eval()
self.tokenizer = model.tokenizer
else:
model.model.eval()
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
self.tokenizer = model.tokenizer

def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
@@ -55,6 +59,8 @@ class VisualQuestionAnsweringPipeline(Pipeline):
Returns:
Dict[str, str]: the prediction results
"""
if self.tokenizer is None:
return inputs
replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))


+4 -0  modelscope/pipelines/nlp/__init__.py

@@ -17,6 +17,8 @@ if TYPE_CHECKING:
from .translation_pipeline import TranslationPipeline
from .word_segmentation_pipeline import WordSegmentationPipeline
from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline
from .summarization_pipeline import SummarizationPipeline
from .text_classification_pipeline import TextClassificationPipeline
from .text_error_correction_pipeline import TextErrorCorrectionPipeline

else:
@@ -38,6 +40,8 @@ else:
'named_entity_recognition_pipeline':
['NamedEntityRecognitionPipeline'],
'translation_pipeline': ['TranslationPipeline'],
'summarization_pipeline': ['SummarizationPipeline'],
'text_classification_pipeline': ['TextClassificationPipeline'],
'text_error_correction_pipeline': ['TextErrorCorrectionPipeline']
}



+42 -0  modelscope/pipelines/nlp/summarization_pipeline.py

@@ -0,0 +1,42 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Union

from modelscope.metainfo import Pipelines
from modelscope.pipelines.base import Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import OfaPreprocessor, Preprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
Tasks.summarization, module_name=Pipelines.text_generation)
class SummarizationPipeline(Pipeline):

def __init__(self,
model: Union[Model, str],
preprocessor: [Preprocessor] = None,
**kwargs):
"""
use `model` and `preprocessor` to create a summarization pipeline for prediction
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model)
assert isinstance(model, str) or isinstance(model, Model), \
'model must be a single str or OfaForAllTasks'
if isinstance(model, str):
pipe_model = Model.from_pretrained(model)
elif isinstance(model, Model):
pipe_model = model
else:
raise NotImplementedError
pipe_model.model.eval()
if preprocessor is None and pipe_model:
preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir)
super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs
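A usage sketch for the summarization pipeline; the model id is a placeholder, and the 'text' key follows OfaSummarizationPreprocessor added later in this commit.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

summarizer = pipeline(Tasks.summarization, model='<ofa_summarization_model_id>')
print(summarizer({'text': 'paste the article body here ...'}))   # -> generated title under OutputKeys.TEXT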

+42 -0  modelscope/pipelines/nlp/text_classification_pipeline.py

@@ -0,0 +1,42 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Union

from modelscope.metainfo import Pipelines
from modelscope.pipelines.base import Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import OfaPreprocessor, Preprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
Tasks.text_classification, module_name=Pipelines.text_classification)
class TextClassificationPipeline(Pipeline):

def __init__(self,
model: Union[Model, str],
preprocessor: [Preprocessor] = None,
**kwargs):
"""
use `model` and `preprocessor` to create a text classification pipeline for prediction
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model)
assert isinstance(model, str) or isinstance(model, Model), \
'model must be a single str or OfaForAllTasks'
if isinstance(model, str):
pipe_model = Model.from_pretrained(model)
elif isinstance(model, Model):
pipe_model = model
else:
raise NotImplementedError
pipe_model.model.eval()
if preprocessor is None and pipe_model:
preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir)
super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs

+3 -5  modelscope/preprocessors/__init__.py

@@ -14,7 +14,7 @@ if TYPE_CHECKING:
ImageInstanceSegmentationPreprocessor,
ImageDenoisePreprocessor)
from .kws import WavToLists
from .multi_modal import (OfaImageCaptionPreprocessor,
from .multi_modal import (OfaPreprocessor,
MPlugVisualQuestionAnsweringPreprocessor)
from .nlp import (Tokenize, SequenceClassificationPreprocessor,
TextGenerationPreprocessor,
@@ -41,10 +41,8 @@ else:
'ImageInstanceSegmentationPreprocessor', 'ImageDenoisePreprocessor'
],
'kws': ['WavToLists'],
'multi_modal': [
'OfaImageCaptionPreprocessor',
'MPlugVisualQuestionAnsweringPreprocessor'
],
'multi_modal':
['OfaPreprocessor', 'MPlugVisualQuestionAnsweringPreprocessor'],
'nlp': [
'Tokenize', 'SequenceClassificationPreprocessor',
'TextGenerationPreprocessor', 'TokenClassificationPreprocessor',


+25 -39  modelscope/preprocessors/multi_modal.py

@@ -4,26 +4,25 @@ from typing import Any, Dict, Union

import torch
from PIL import Image
from torchvision import transforms

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Preprocessors
from modelscope.models.multi_modal.ofa import OFATokenizer
from modelscope.utils.constant import Fields
from modelscope.utils.type_assert import type_assert
from modelscope.utils.config import Config
from modelscope.utils.constant import Fields, ModelFile, Tasks
from .base import Preprocessor
from .builder import PREPROCESSORS
from .image import load_image
from .ofa import * # noqa
from .ofa.utils.collate import collate_fn

__all__ = [
'OfaImageCaptionPreprocessor',
'OfaPreprocessor',
'MPlugVisualQuestionAnsweringPreprocessor',
]


@PREPROCESSORS.register_module(
Fields.multi_modal, module_name=Preprocessors.ofa_image_caption)
class OfaImageCaptionPreprocessor(Preprocessor):
class OfaPreprocessor(Preprocessor):

def __init__(self, model_dir: str, *args, **kwargs):
"""preprocess the data via the vocab.txt from the `model_dir` path
@@ -32,41 +31,28 @@ class OfaImageCaptionPreprocessor(Preprocessor):
model_dir (str): model path
"""
super().__init__(*args, **kwargs)
preprocess_mapping = {
Tasks.image_captioning: OfaImageCaptioningPreprocessor,
Tasks.visual_grounding: OfaVisualGroundingPreprocessor,
Tasks.visual_question_answering:
OfaVisualQuestionAnsweringPreprocessor,
Tasks.visual_entailment: OfaVisualEntailmentPreprocessor,
Tasks.image_classification: OfaImageClassificationPreprocessor,
Tasks.text_classification: OfaTextClassificationPreprocessor,
Tasks.summarization: OfaSummarizationPreprocessor
}
model_dir = model_dir if osp.exists(model_dir) else snapshot_download(
model_dir)
self.tokenizer = OFATokenizer.from_pretrained(model_dir)
self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)])
self.tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)])

# Initialize transform
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]
patch_image_size = 480
self.patch_resize_transform = transforms.Compose([
lambda image: image.convert('RGB'),
transforms.Resize((patch_image_size, patch_image_size),
interpolation=Image.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
])
cfg = Config.from_file(osp.join(model_dir, ModelFile.CONFIGURATION))
self.preprocess = preprocess_mapping[cfg.task](cfg, model_dir)
self.tokenizer = self.preprocess.tokenizer

@type_assert(object, (str, tuple, Image.Image))
def __call__(self, data: Union[str, tuple]) -> Dict[str, Any]:
if isinstance(data, Image.Image):
patch_image = self.patch_resize_transform(data).unsqueeze(0)
else:
patch_image = self.patch_resize_transform(
load_image(data)).unsqueeze(0)
text = ' what does the image describe?'
inputs = self.tokenizer([text], max_length=1024,
return_tensors='pt')['input_ids']
sample = dict()
sample['net_input'] = {
'input_ids': inputs,
'patch_images': patch_image,
'patch_masks': torch.tensor([True])
}
return sample
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
sample = self.preprocess(data)
sample['sample'] = data
return collate_fn([sample],
pad_idx=self.tokenizer.pad_token_id,
eos_idx=self.tokenizer.eos_token_id)


@PREPROCESSORS.register_module(


+8 -0  modelscope/preprocessors/ofa/__init__.py

@@ -0,0 +1,8 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .image_captioning import OfaImageCaptioningPreprocessor
from .image_classification import OfaImageClassificationPreprocessor
from .summarization import OfaSummarizationPreprocessor
from .text_classification import OfaTextClassificationPreprocessor
from .visual_entailment import OfaVisualEntailmentPreprocessor
from .visual_grounding import OfaVisualGroundingPreprocessor
from .visual_question_answering import OfaVisualQuestionAnsweringPreprocessor

+117 -0  modelscope/preprocessors/ofa/base.py

@@ -0,0 +1,117 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import re
from os import path as osp

import json
import numpy as np
import torch

from modelscope.models.multi_modal.ofa import OFATokenizer
from modelscope.utils.trie import Trie
from .utils.random_help import set_torch_seed


class OfaBasePreprocessor:

def __init__(self, cfg, model_dir):
"""preprocess the data via the vocab.txt from the `model_dir` path

Args:
cfg(modelscope.utils.config.ConfigDict) : model config
model_dir (str): model path
"""
self.cfg = cfg
tokenizer = OFATokenizer.from_pretrained(model_dir)
tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)])
tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)])
self.tokenizer = tokenizer
self.bos_item = torch.LongTensor([tokenizer.bos_token_id])
self.pad_item = torch.LongTensor([tokenizer.pad_token_id])
self.eos_item = torch.LongTensor([tokenizer.eos_token_id])
self.tgt_dict = self.src_dict = {
value: key
for key, value in tokenizer.get_vocab().items()
}
self.max_src_length = cfg.model.get('max_src_length', 256)
self.max_image_size = cfg.model.get('max_image_size', 512)
self.language = self.cfg.model.get('language', 'en')
self.prompt_type = self.cfg.model.get('prompt_type', 'none')
seed = self.cfg.model.get('seed', 7)
np.random.seed(seed)
set_torch_seed(seed)
imagenet_default_mean_and_std = self.cfg.model.get(
'imagenet_default_mean_and_std', False)
if imagenet_default_mean_and_std:
self.mean = [0.485, 0.456, 0.406]
self.std = [0.229, 0.224, 0.225]
else:
self.mean = [0.5, 0.5, 0.5]
self.std = [0.5, 0.5, 0.5]
self.patch_image_size = self.cfg.model.get('patch_image_size', 480)
self.constraint_trie = None
self.index2ans = {}
if self.cfg.model.get('answer2label', False):
ans2label_file = osp.join(model_dir, self.cfg.model.answer2label)
ans2label_dict = json.load(open(ans2label_file, 'r'))
self.constraint_trie = Trie(tokenizer.eos_token_id)
for i, answer in enumerate(ans2label_dict.keys()):
answer_item = tokenizer(
' ' + answer,
return_tensors='pt',
add_special_tokens=False).input_ids.squeeze(0)
self.constraint_trie.insert([tokenizer.bos_token_id]
+ answer_item.tolist()
+ [tokenizer.eos_token_id])

def get_inputs(self, text, add_bos=True, add_eos=True):
inputs = self.tokenizer(
text,
max_length=self.max_src_length,
add_special_tokens=False,
return_tensors='pt')['input_ids'].squeeze(0)
if add_bos:
inputs = torch.cat([self.bos_item, inputs])
if add_eos:
inputs = torch.cat([inputs, self.eos_item])
return inputs

@staticmethod
def pre_caption(caption, max_words=None):
caption = caption.lower().lstrip(',.!?*#:;~').replace('-', ' ')\
.replace('/', ' ').replace('<person>', 'person')

caption = re.sub(
r'\s{2,}',
' ',
caption,
)
caption = caption.rstrip('\n')
caption = caption.strip(' ')

# truncate caption
caption_words = caption.split(' ')
if max_words is not None and len(caption_words) > max_words:
caption = ' '.join(caption_words[:max_words])

return caption

@staticmethod
def pre_question(question, max_ques_words):
question = question.lower().lstrip(',.!?*#:;~').replace('-',
' ').replace(
'/', ' ')

question = re.sub(
r'\s{2,}',
' ',
question,
)
question = question.rstrip('\n')
question = question.strip(' ')

# truncate question
question_words = question.split(' ')
if len(question_words) > max_ques_words:
question = ' '.join(question_words[:max_ques_words])

return question
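The normalization helpers are static, so they can be sanity-checked without any model files; a small worked example for pre_caption:

from modelscope.preprocessors.ofa.base import OfaBasePreprocessor

print(OfaBasePreprocessor.pre_caption('A man riding a horse -- on the beach!', max_words=5))
# -> 'a man riding a horse'   (lower-cased, '-' replaced, whitespace collapsed, truncated to 5 words)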

+42 -0  modelscope/preprocessors/ofa/image_captioning.py

@@ -0,0 +1,42 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Union

import torch
from PIL import Image
from torchvision import transforms

from modelscope.preprocessors.image import load_image
from .base import OfaBasePreprocessor


class OfaImageCaptioningPreprocessor(OfaBasePreprocessor):

def __init__(self, cfg, model_dir):
"""preprocess the data via the vocab.txt from the `model_dir` path

Args:
cfg(modelscope.utils.config.ConfigDict) : model config
model_dir (str): model path
"""
super(OfaImageCaptioningPreprocessor, self).__init__(cfg, model_dir)
# Initialize transform
self.patch_resize_transform = transforms.Compose([
lambda image: image.convert('RGB'),
transforms.Resize((self.patch_image_size, self.patch_image_size),
interpolation=Image.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=self.mean, std=self.std),
])

def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = data['image'] if isinstance(
data['image'], Image.Image) else load_image(data['image'])
patch_image = self.patch_resize_transform(image)
prompt = self.cfg.model.get('prompt', ' what does the image describe?')
inputs = self.get_inputs(prompt)
sample = {
'source': inputs,
'patch_image': patch_image,
'patch_mask': torch.tensor([True])
}
return sample

+43 -0  modelscope/preprocessors/ofa/image_classification.py

@@ -0,0 +1,43 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict

import torch
from PIL import Image
from torchvision import transforms

from modelscope.preprocessors.image import load_image
from .base import OfaBasePreprocessor


class OfaImageClassificationPreprocessor(OfaBasePreprocessor):

def __init__(self, cfg, model_dir):
"""preprocess the data via the vocab.txt from the `model_dir` path

Args:
cfg(modelscope.utils.config.ConfigDict) : model config
model_dir (str): model path
"""
super(OfaImageClassificationPreprocessor,
self).__init__(cfg, model_dir)
# Initialize transform
self.patch_resize_transform = transforms.Compose([
lambda image: image.convert('RGB'),
transforms.Resize((self.patch_image_size, self.patch_image_size),
interpolation=Image.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=self.mean, std=self.std),
])

def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = data['image'] if isinstance(
data['image'], Image.Image) else load_image(data['image'])
patch_image = self.patch_resize_transform(image)
prompt = self.cfg.model.get('prompt', ' what does the image describe?')
inputs = self.get_inputs(prompt)
sample = {
'source': inputs,
'patch_image': patch_image,
'patch_mask': torch.tensor([True])
}
return sample

+37 -0  modelscope/preprocessors/ofa/summarization.py

@@ -0,0 +1,37 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict

from .base import OfaBasePreprocessor


class OfaSummarizationPreprocessor(OfaBasePreprocessor):

def __init__(self, cfg, model_dir):
"""preprocess the data via the vocab.txt from the `model_dir` path

Args:
cfg(modelscope.utils.config.ConfigDict) : model config
model_dir (str): model path
"""
super(OfaSummarizationPreprocessor, self).__init__(cfg, model_dir)

def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
source = super().pre_caption(
data['text'], max_words=self.max_src_length)
source = source.strip()[:self.max_src_length]
source = source.replace('[unk]', 'unk').replace('<unk>', 'unk')
prompt = self.cfg.model.get(
'prompt', ' " {} " Summarize the article with a title: ')
text = prompt.format(source)
inputs = self.get_inputs(text)
if self.prompt_type == 'none':
decoder_prompt = self.bos_item
elif self.prompt_type == 'prev_output':
decoder_prompt = inputs[:-1]
else:
raise NotImplementedError
sample = {
'source': inputs,
'decoder_prompt': decoder_prompt,
}
return sample

+38 -0  modelscope/preprocessors/ofa/text_classification.py

@@ -0,0 +1,38 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict

from .base import OfaBasePreprocessor


class OfaTextClassificationPreprocessor(OfaBasePreprocessor):

def __init__(self, cfg, model_dir):
"""preprocess the data via the vocab.txt from the `model_dir` path

Args:
cfg(modelscope.utils.config.ConfigDict) : model config
model_dir (str): model path
"""
super(OfaTextClassificationPreprocessor, self).__init__(cfg, model_dir)

def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
text1 = ' '.join(
data['text'].lower().strip().split()[:self.max_src_length])
text2 = ' '.join(
data['text2'].lower().strip().split()[:self.max_src_length])
prompt = ' can text1 " {} " imply text2 " {} "?'
text = prompt.format(text1, text2)
inputs = self.get_inputs(text)
if self.prompt_type == 'none':
decoder_prompt = self.bos_item
elif self.prompt_type == 'src':
decoder_prompt = inputs
elif self.prompt_type == 'prev_output':
decoder_prompt = inputs[:-1]
else:
raise NotImplementedError
sample = {
'source': inputs,
'decoder_prompt': decoder_prompt,
}
return sample
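For a sentence pair, the prompt assembled above looks like the following (a worked example with made-up inputs; the leading space comes from the template):

prompt = ' can text1 " {} " imply text2 " {} "?'
print(prompt.format('the dog is sleeping', 'an animal is resting'))
# ->  can text1 " the dog is sleeping " imply text2 " an animal is resting "?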

+0 -0  modelscope/preprocessors/ofa/utils/__init__.py


+109 -0  modelscope/preprocessors/ofa/utils/collate.py

@@ -0,0 +1,109 @@
import numpy as np
import torch


def collate_fn(samples, pad_idx, eos_idx):
if len(samples) == 0:
return {}

def merge(key):
return collate_tokens([s[key] for s in samples],
pad_idx,
eos_idx=eos_idx)

src_tokens = merge('source')

batch = {
'nsentences': len(samples),
'net_input': {
'input_ids': src_tokens,
},
}
if samples[0].get('id', None) is not None:
batch['id'] = np.array([s.get('id') for s in samples])
if samples[0].get('target', None) is not None:
batch['target'] = merge('target')
tgt_lengths = torch.LongTensor(
[s['target'].ne(pad_idx).long().sum() for s in samples])
ntokens = tgt_lengths.sum().item()
batch['ntokens'] = ntokens
if samples[0].get('prev_output_tokens', None) is not None:
batch['net_input']['decoder_input_ids'] = merge('prev_output_tokens')
if samples[0].get('patch_image', None) is not None:
batch['net_input']['patch_images'] = torch.stack(
[sample['patch_image'] for sample in samples], dim=0)
if samples[0].get('patch_mask', None) is not None:
batch['net_input']['patch_masks'] = torch.cat(
[sample['patch_mask'] for sample in samples])
# image generation
if samples[0].get('code_mask', None) is not None:
batch['net_input']['code_masks'] = torch.cat(
[sample['code_mask'] for sample in samples])
if samples[0].get('code_image', None) is not None:
batch['code_images'] = torch.cat(
[sample['code_image'] for sample in samples])
# For classification tasks (i.e., VQA, SNLI-VE, GLUE)
if samples[0].get('conf', None) is not None:
batch['conf'] = torch.cat([s['conf'] for s in samples], dim=0)
if samples[0].get('ref_dict', None) is not None:
batch['ref_dict'] = np.array([s['ref_dict'] for s in samples])
if samples[0].get('constraint_mask', None) is not None:
batch['constraint_masks'] = merge('constraint_mask')
if samples[0].get('decoder_prompt', None) is not None:
batch['decoder_prompts'] = np.array(
[s['decoder_prompt'].tolist() for s in samples])
# For detection and visual grounding
if samples[0].get('w_resize_ratio', None) is not None:
batch['w_resize_ratios'] = torch.stack(
[s['w_resize_ratio'] for s in samples], dim=0)
if samples[0].get('h_resize_ratio', None) is not None:
batch['h_resize_ratios'] = torch.stack(
[s['h_resize_ratio'] for s in samples], dim=0)
if samples[0].get('region_coord', None) is not None:
batch['region_coords'] = torch.stack(
[s['region_coord'] for s in samples], dim=0)
if samples[0].get('sample', None) is not None:
batch['samples'] = [s['sample'] for s in samples]
return batch


def collate_tokens(
values,
pad_idx,
eos_idx=None,
left_pad=False,
move_eos_to_beginning=False,
pad_to_length=None,
pad_to_multiple=1,
pad_to_bsz=None,
):
"""Convert a list of 1d tensors into a padded 2d tensor."""
size = max(v.size(0) for v in values)
size = size if pad_to_length is None else max(size, pad_to_length)
if pad_to_multiple != 1 and size % pad_to_multiple != 0:
size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)

def copy_tensor(src, dst):
assert dst.numel() == src.numel()
if move_eos_to_beginning:
if eos_idx is None:
# if no eos_idx is specified, then use the last token in src
dst[0] = src[-1]
else:
dst[0] = eos_idx
dst[1:] = src[:-1]
else:
dst.copy_(src)

if values[0].dim() == 1:
res = values[0].new(len(values), size).fill_(pad_idx)
elif values[0].dim() == 2:
assert move_eos_to_beginning is False
res = values[0].new(len(values), size,
values[0].size(1)).fill_(pad_idx)
else:
raise NotImplementedError

for i, v in enumerate(values):
copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
return res
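A minimal padding example for collate_tokens; pad_idx=1 is only an illustrative value, not necessarily the OFA tokenizer's pad id.

import torch

from modelscope.preprocessors.ofa.utils.collate import collate_tokens

a = torch.LongTensor([0, 5, 6, 2])
b = torch.LongTensor([0, 7, 2])
print(collate_tokens([a, b], pad_idx=1))
# tensor([[0, 5, 6, 2],
#         [0, 7, 2, 1]])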

+42 -0  modelscope/preprocessors/ofa/utils/random_help.py

@@ -0,0 +1,42 @@
import torch

try:
import torch_xla.core.xla_model as xm
except ImportError:
xm = None


def get_rng_state():
state = {'torch_rng_state': torch.get_rng_state()}
if xm is not None:
state['xla_rng_state'] = xm.get_rng_state()
if torch.cuda.is_available():
state['cuda_rng_state'] = torch.cuda.get_rng_state()
return state


def set_rng_state(state):
torch.set_rng_state(state['torch_rng_state'])
if xm is not None:
xm.set_rng_state(state['xla_rng_state'])
if torch.cuda.is_available():
torch.cuda.set_rng_state(state['cuda_rng_state'])


class set_torch_seed(object):

def __init__(self, seed):
assert isinstance(seed, int)
self.rng_state = get_rng_state()

torch.manual_seed(seed)
if xm is not None:
xm.set_rng_state(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)

def __enter__(self):
return self

def __exit__(self, *exc):
set_rng_state(self.rng_state)
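set_torch_seed also works as a context manager; a small sketch of the seed-then-restore behavior:

import torch

from modelscope.preprocessors.ofa.utils.random_help import set_torch_seed

before = torch.get_rng_state()
with set_torch_seed(7):
    x = torch.rand(2)          # deterministic: the block runs with seed 7
assert torch.equal(torch.get_rng_state(), before)   # RNG state restored on exit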

+557 -0  modelscope/preprocessors/ofa/utils/transforms.py

@@ -0,0 +1,557 @@
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import random

import numpy as np
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
from PIL import Image


def crop(image, target, region, delete=True):
cropped_image = F.crop(image, *region)

target = target.copy()
i, j, h, w = region

# should we do something wrt the original size?
target['size'] = torch.tensor([h, w])

fields = ['labels', 'area']

if 'boxes' in target:
boxes = target['boxes']
max_size = torch.as_tensor([w, h], dtype=torch.float32)
cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
cropped_boxes = cropped_boxes.clamp(min=0)
area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
target['boxes'] = cropped_boxes.reshape(-1, 4)
target['area'] = area
fields.append('boxes')

if 'polygons' in target:
polygons = target['polygons']
num_polygons = polygons.shape[0]
max_size = torch.as_tensor([w, h], dtype=torch.float32)
start_coord = torch.cat([
torch.tensor([j, i], dtype=torch.float32)
            for _ in range(polygons.shape[1] // 2)], dim=0) # yapf: disable
cropped_boxes = polygons - start_coord
cropped_boxes = torch.min(
cropped_boxes.reshape(num_polygons, -1, 2), max_size)
cropped_boxes = cropped_boxes.clamp(min=0)
target['polygons'] = cropped_boxes.reshape(num_polygons, -1)
fields.append('polygons')

if 'masks' in target:
# FIXME should we update the area here if there are no boxes?
target['masks'] = target['masks'][:, i:i + h, j:j + w]
fields.append('masks')

    # remove elements for which the boxes or masks have zero area
if delete and ('boxes' in target or 'masks' in target):
# favor boxes selection when defining which elements to keep
# this is compatible with previous implementation
if 'boxes' in target:
cropped_boxes = target['boxes'].reshape(-1, 2, 2)
keep = torch.all(
cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
else:
keep = target['masks'].flatten(1).any(1)

for field in fields:
target[field] = target[field][keep.tolist()]

return cropped_image, target


def hflip(image, target):
flipped_image = F.hflip(image)
w, h = image.size
target = target.copy()
if 'boxes' in target:
boxes = target['boxes']
boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor(
[-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
target['boxes'] = boxes

if 'polygons' in target:
polygons = target['polygons']
num_polygons = polygons.shape[0]
polygons = polygons.reshape(num_polygons, -1, 2) * torch.as_tensor(
[-1, 1]) + torch.as_tensor([w, 0])
target['polygons'] = polygons

if 'masks' in target:
target['masks'] = target['masks'].flip(-1)

return flipped_image, target


def resize(image, target, size, max_size=None):
# size can be min_size (scalar) or (w, h) tuple

def get_size_with_aspect_ratio(image_size, size, max_size=None):
w, h = image_size

if (w <= h and w == size) or (h <= w and h == size):
if max_size is not None:
max_size = int(max_size)
h = min(h, max_size)
w = min(w, max_size)
return (h, w)

if w < h:
ow = size
oh = int(size * h / w)
else:
oh = size
ow = int(size * w / h)

if max_size is not None:
max_size = int(max_size)
oh = min(oh, max_size)
ow = min(ow, max_size)

return (oh, ow)

def get_size(image_size, size, max_size=None):
if isinstance(size, (list, tuple)):
return size[::-1]
else:
return get_size_with_aspect_ratio(image_size, size, max_size)

size = get_size(image.size, size, max_size)
rescaled_image = F.resize(image, size, interpolation=Image.BICUBIC)

if target is None:
return rescaled_image

ratios = tuple(
float(s) / float(s_orig)
for s, s_orig in zip(rescaled_image.size, image.size))
ratio_width, ratio_height = ratios

target = target.copy()
if 'boxes' in target:
boxes = target['boxes']
scaled_boxes = boxes * torch.as_tensor(
[ratio_width, ratio_height, ratio_width, ratio_height])
target['boxes'] = scaled_boxes

if 'polygons' in target:
polygons = target['polygons']
scaled_ratio = torch.cat([
torch.tensor([ratio_width, ratio_height])
for _ in range(polygons.shape[1] // 2)], dim=0) # yapf: disable
scaled_polygons = polygons * scaled_ratio
target['polygons'] = scaled_polygons

if 'area' in target:
area = target['area']
scaled_area = area * (ratio_width * ratio_height)
target['area'] = scaled_area

h, w = size
target['size'] = torch.tensor([h, w])

if 'masks' in target:
assert False

return rescaled_image, target


class CenterCrop(object):

def __init__(self, size):
self.size = size

def __call__(self, img, target):
image_width, image_height = img.size
crop_height, crop_width = self.size
crop_top = int(round((image_height - crop_height) / 2.))
crop_left = int(round((image_width - crop_width) / 2.))
return crop(img, target,
(crop_top, crop_left, crop_height, crop_width))


class ObjectCenterCrop(object):

def __init__(self, size):
self.size = size

def __call__(self, img, target):
image_width, image_height = img.size
crop_height, crop_width = self.size

x0 = float(target['boxes'][0][0])
y0 = float(target['boxes'][0][1])
x1 = float(target['boxes'][0][2])
y1 = float(target['boxes'][0][3])

center_x = (x0 + x1) / 2
center_y = (y0 + y1) / 2
crop_left = max(
center_x - crop_width / 2
+ min(image_width - center_x - crop_width / 2, 0), 0)
crop_top = max(
center_y - crop_height / 2
+ min(image_height - center_y - crop_height / 2, 0), 0)

return crop(
img,
target, (crop_top, crop_left, crop_height, crop_width),
delete=False)


class RandomHorizontalFlip(object):

def __init__(self, p=0.5):
self.p = p

def __call__(self, img, target):
if random.random() < self.p:
return hflip(img, target)
return img, target


class RandomResize(object):

def __init__(self, sizes, max_size=None, equal=False):
assert isinstance(sizes, (list, tuple))
self.sizes = sizes
self.max_size = max_size
self.equal = equal

def __call__(self, img, target=None):
size = random.choice(self.sizes)
if self.equal:
return resize(img, target, size, size)
else:
return resize(img, target, size, self.max_size)


class ToTensor(object):

def __call__(self, img, target):
return F.to_tensor(img), target


class Normalize(object):

def __init__(self, mean, std, max_image_size=512):
self.mean = mean
self.std = std
self.max_image_size = max_image_size

def __call__(self, image, target=None):
image = F.normalize(image, mean=self.mean, std=self.std)
if target is None:
return image, None
target = target.copy()
# h, w = image.shape[-2:]
h, w = target['size'][0], target['size'][1]
if 'boxes' in target:
boxes = target['boxes']
boxes = boxes / self.max_image_size
target['boxes'] = boxes
if 'polygons' in target:
polygons = target['polygons']
scale = torch.cat([
torch.tensor([w, h], dtype=torch.float32)
for _ in range(polygons.shape[1] // 2)], dim=0) # yapf: disable
polygons = polygons / scale
target['polygons'] = polygons
return image, target


class Compose(object):

def __init__(self, transforms):
self.transforms = transforms

def __call__(self, image, target):
for t in self.transforms:
image, target = t(image, target)
return image, target

def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string


class LargeScaleJitter(object):
"""
implementation of large scale jitter from copy_paste
"""

def __init__(self, output_size=512, aug_scale_min=0.3, aug_scale_max=2.0):
self.desired_size = torch.tensor([output_size])
self.aug_scale_min = aug_scale_min
self.aug_scale_max = aug_scale_max

def rescale_target(self, scaled_size, image_size, target):
# compute rescaled targets
image_scale = scaled_size / image_size
ratio_height, ratio_width = image_scale

target = target.copy()
target['size'] = scaled_size

if 'boxes' in target:
boxes = target['boxes']
scaled_boxes = boxes * torch.as_tensor(
[ratio_width, ratio_height, ratio_width, ratio_height])
target['boxes'] = scaled_boxes

if 'area' in target:
area = target['area']
scaled_area = area * (ratio_width * ratio_height)
target['area'] = scaled_area

if 'masks' in target:
assert False
masks = target['masks']
# masks = interpolate(
# masks[:, None].float(), scaled_size, mode="nearest")[:, 0] > 0.5
target['masks'] = masks
return target

def crop_target(self, region, target):
i, j, h, w = region
fields = ['labels', 'area']

target = target.copy()
target['size'] = torch.tensor([h, w])

if 'boxes' in target:
boxes = target['boxes']
max_size = torch.as_tensor([w, h], dtype=torch.float32)
cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
cropped_boxes = torch.min(
cropped_boxes.reshape(-1, 2, 2), max_size)
cropped_boxes = cropped_boxes.clamp(min=0)
area = (cropped_boxes[:, 1, :]
- cropped_boxes[:, 0, :]).prod(dim=1)
target['boxes'] = cropped_boxes.reshape(-1, 4)
target['area'] = area
fields.append('boxes')

if 'masks' in target:
# FIXME should we update the area here if there are no boxes?
target['masks'] = target['masks'][:, i:i + h, j:j + w]
fields.append('masks')

        # remove elements for which the boxes or masks have zero area
if 'boxes' in target or 'masks' in target:
# favor boxes selection when defining which elements to keep
# this is compatible with previous implementation
if 'boxes' in target:
cropped_boxes = target['boxes'].reshape(-1, 2, 2)
keep = torch.all(
cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
else:
keep = target['masks'].flatten(1).any(1)

for field in fields:
target[field] = target[field][keep.tolist()]
return target

def pad_target(self, padding, target):
target = target.copy()
if 'masks' in target:
target['masks'] = torch.nn.functional.pad(
target['masks'], (0, padding[1], 0, padding[0]))
return target

def __call__(self, image, target=None):
image_size = image.size
image_size = torch.tensor(image_size[::-1])

random_scale = torch.rand(1) * (
self.aug_scale_max - self.aug_scale_min) + self.aug_scale_min
scaled_size = (random_scale * self.desired_size).round()

scale = torch.maximum(scaled_size / image_size[0],
scaled_size / image_size[1])
scaled_size = (image_size * scale).round().int()

scaled_image = F.resize(
image, scaled_size.tolist(), interpolation=Image.BICUBIC)

if target is not None:
target = self.rescale_target(scaled_size, image_size, target)

# randomly crop or pad images
if random_scale >= 1:
# Selects non-zero random offset (x, y) if scaled image is larger than desired_size.
max_offset = scaled_size - self.desired_size
offset = (max_offset * torch.rand(2)).floor().int()
region = (offset[0].item(), offset[1].item(),
self.desired_size[0].item(), self.desired_size[0].item())
output_image = F.crop(scaled_image, *region)
if target is not None:
target = self.crop_target(region, target)
else:
assert False
padding = self.desired_size - scaled_size
output_image = F.pad(scaled_image,
[0, 0, padding[1].item(), padding[0].item()])
if target is not None:
target = self.pad_target(padding, target)

return output_image, target


class OriginLargeScaleJitter(object):
"""
implementation of large scale jitter from copy_paste
"""

def __init__(self, output_size=512, aug_scale_min=0.3, aug_scale_max=2.0):
self.desired_size = torch.tensor(output_size)
self.aug_scale_min = aug_scale_min
self.aug_scale_max = aug_scale_max

def rescale_target(self, scaled_size, image_size, target):
# compute rescaled targets
image_scale = scaled_size / image_size
ratio_height, ratio_width = image_scale

target = target.copy()
target['size'] = scaled_size

if 'boxes' in target:
boxes = target['boxes']
scaled_boxes = boxes * torch.as_tensor(
[ratio_width, ratio_height, ratio_width, ratio_height])
target['boxes'] = scaled_boxes

if 'area' in target:
area = target['area']
scaled_area = area * (ratio_width * ratio_height)
target['area'] = scaled_area

if 'masks' in target:
assert False
masks = target['masks']
# masks = interpolate(
# masks[:, None].float(), scaled_size, mode="nearest")[:, 0] > 0.5
target['masks'] = masks
return target

def crop_target(self, region, target):
i, j, h, w = region
fields = ['labels', 'area']

target = target.copy()
target['size'] = torch.tensor([h, w])

if 'boxes' in target:
boxes = target['boxes']
max_size = torch.as_tensor([w, h], dtype=torch.float32)
cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
cropped_boxes = torch.min(
cropped_boxes.reshape(-1, 2, 2), max_size)
cropped_boxes = cropped_boxes.clamp(min=0)
area = (cropped_boxes[:, 1, :]
- cropped_boxes[:, 0, :]).prod(dim=1)
target['boxes'] = cropped_boxes.reshape(-1, 4)
target['area'] = area
fields.append('boxes')

if 'masks' in target:
# FIXME should we update the area here if there are no boxes?
target['masks'] = target['masks'][:, i:i + h, j:j + w]
fields.append('masks')

        # remove elements for which the boxes or masks have zero area
if 'boxes' in target or 'masks' in target:
# favor boxes selection when defining which elements to keep
# this is compatible with previous implementation
if 'boxes' in target:
cropped_boxes = target['boxes'].reshape(-1, 2, 2)
keep = torch.all(
cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
else:
keep = target['masks'].flatten(1).any(1)

for field in fields:
target[field] = target[field][keep.tolist()]
return target

def pad_target(self, padding, target):
target = target.copy()
if 'masks' in target:
target['masks'] = torch.nn.functional.pad(
target['masks'], (0, padding[1], 0, padding[0]))
return target

def __call__(self, image, target=None):
image_size = image.size
image_size = torch.tensor(image_size[::-1])

out_desired_size = (self.desired_size * image_size
/ max(image_size)).round().int()

random_scale = torch.rand(1) * (
self.aug_scale_max - self.aug_scale_min) + self.aug_scale_min
scaled_size = (random_scale * self.desired_size).round()

scale = torch.minimum(scaled_size / image_size[0],
scaled_size / image_size[1])
scaled_size = (image_size * scale).round().int()

scaled_image = F.resize(image, scaled_size.tolist())

if target is not None:
target = self.rescale_target(scaled_size, image_size, target)

# randomly crop or pad images
if random_scale > 1:
# Selects non-zero random offset (x, y) if scaled image is larger than desired_size.
max_offset = scaled_size - out_desired_size
offset = (max_offset * torch.rand(2)).floor().int()
region = (offset[0].item(), offset[1].item(),
out_desired_size[0].item(), out_desired_size[1].item())
output_image = F.crop(scaled_image, *region)
if target is not None:
target = self.crop_target(region, target)
else:
padding = out_desired_size - scaled_size
output_image = F.pad(scaled_image,
[0, 0, padding[1].item(), padding[0].item()])
if target is not None:
target = self.pad_target(padding, target)

return output_image, target


class RandomDistortion(object):
"""
    Distort image w.r.t. brightness, contrast, saturation and hue.
"""

def __init__(self,
brightness=0,
contrast=0,
saturation=0,
hue=0,
prob=0.5):
self.prob = prob
self.tfm = T.ColorJitter(brightness, contrast, saturation, hue)

def __call__(self, img, target=None):
if np.random.random() < self.prob:
return self.tfm(img), target
else:
return img, target
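
A minimal sketch (placeholder image and box values, not part of the diff) of chaining the pair-wise transforms above; every step takes and returns an (image, target) pair so the boxes stay aligned with the pixels:

import torch
from PIL import Image

transform = Compose([
    RandomResize([480], max_size=640),
    ToTensor(),
    Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_image_size=512),
])
img = Image.new('RGB', (640, 480))                    # placeholder image
target = {
    'boxes': torch.tensor([[10., 20., 200., 220.]]),  # xyxy, absolute pixels
    'labels': torch.tensor([1]),
    'area': torch.tensor([38000.]),
    'size': torch.tensor([480, 640]),                 # (h, w)
}
img_t, target_t = transform(img, target)
# img_t is a normalized CHW tensor; target_t['boxes'] has been rescaled by the
# resize ratios and then divided by max_image_size inside Normalize.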

+ 357
- 0
modelscope/preprocessors/ofa/utils/vision_helper.py View File

@@ -0,0 +1,357 @@
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import cv2
import numpy as np


def identity_func(img):
return img


def autocontrast_func(img, cutoff=0):
'''
same output as PIL.ImageOps.autocontrast
'''
n_bins = 256

def tune_channel(ch):
n = ch.size
cut = cutoff * n // 100
if cut == 0:
high, low = ch.max(), ch.min()
else:
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
low = np.argwhere(np.cumsum(hist) > cut)
low = 0 if low.shape[0] == 0 else low[0]
high = np.argwhere(np.cumsum(hist[::-1]) > cut)
high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
if high <= low:
table = np.arange(n_bins)
else:
scale = (n_bins - 1) / (high - low)
offset = -low * scale
table = np.arange(n_bins) * scale + offset
table[table < 0] = 0
table[table > n_bins - 1] = n_bins - 1
table = table.clip(0, 255).astype(np.uint8)
return table[ch]

channels = [tune_channel(ch) for ch in cv2.split(img)]
out = cv2.merge(channels)
return out


def equalize_func(img):
'''
same output as PIL.ImageOps.equalize
PIL's implementation is different from cv2.equalize
'''
n_bins = 256

def tune_channel(ch):
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
non_zero_hist = hist[hist != 0].reshape(-1)
step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
if step == 0:
return ch
n = np.empty_like(hist)
n[0] = step // 2
n[1:] = hist[:-1]
table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
return table[ch]

channels = [tune_channel(ch) for ch in cv2.split(img)]
out = cv2.merge(channels)
return out


def rotate_func(img, degree, fill=(0, 0, 0)):
'''
like PIL, rotate by degree, not radians
'''
H, W = img.shape[0], img.shape[1]
center = W / 2, H / 2
M = cv2.getRotationMatrix2D(center, degree, 1)
out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
return out


def solarize_func(img, thresh=128):
'''
    same output as PIL.ImageOps.solarize
'''
table = np.array([el if el < thresh else 255 - el for el in range(256)])
table = table.clip(0, 255).astype(np.uint8)
out = table[img]
return out


def color_func(img, factor):
# same output as PIL.ImageEnhance.Color
M = (
np.float32([[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587],
[-0.299, -0.299, 0.701]]) * factor
+ np.float32([[0.114], [0.587], [0.299]]))
out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
return out


def contrast_func(img, factor):
"""
same output as PIL.ImageEnhance.Contrast
"""
mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
table = np.array([(el - mean) * factor + mean
for el in range(256)]).clip(0, 255).astype(np.uint8)
out = table[img]
return out


def brightness_func(img, factor):
'''
    same output as PIL.ImageEnhance.Brightness
'''
table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(
np.uint8)
out = table[img]
return out


def sharpness_func(img, factor):
'''
    The differences between this result and PIL are all on the 4 boundaries;
    the center areas are the same.
'''
kernel = np.ones((3, 3), dtype=np.float32)
kernel[1][1] = 5
kernel /= 13
degenerate = cv2.filter2D(img, -1, kernel)
if factor == 0.0:
out = degenerate
elif factor == 1.0:
out = img
else:
out = img.astype(np.float32)
degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
out[1:-1, 1:-1, :] = degenerate + factor * (
out[1:-1, 1:-1, :] - degenerate)
out = out.astype(np.uint8)
return out


def shear_x_func(img, factor, fill=(0, 0, 0)):
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, factor, 0], [0, 1, 0]])
out = cv2.warpAffine(
img, M, (W, H), borderValue=fill,
flags=cv2.INTER_LINEAR).astype(np.uint8)
return out


def translate_x_func(img, offset, fill=(0, 0, 0)):
'''
same output as PIL.Image.transform
'''
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, -offset], [0, 1, 0]])
out = cv2.warpAffine(
img, M, (W, H), borderValue=fill,
flags=cv2.INTER_LINEAR).astype(np.uint8)
return out


def translate_y_func(img, offset, fill=(0, 0, 0)):
'''
same output as PIL.Image.transform
'''
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, 0], [0, 1, -offset]])
out = cv2.warpAffine(
img, M, (W, H), borderValue=fill,
flags=cv2.INTER_LINEAR).astype(np.uint8)
return out


def posterize_func(img, bits):
'''
same output as PIL.ImageOps.posterize
'''
out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
return out


def shear_y_func(img, factor, fill=(0, 0, 0)):
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, 0], [factor, 1, 0]])
out = cv2.warpAffine(
img, M, (W, H), borderValue=fill,
flags=cv2.INTER_LINEAR).astype(np.uint8)
return out


def cutout_func(img, pad_size, replace=(0, 0, 0)):
replace = np.array(replace, dtype=np.uint8)
H, W = img.shape[0], img.shape[1]
rh, rw = np.random.random(2)
pad_size = pad_size // 2
ch, cw = int(rh * H), int(rw * W)
x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
out = img.copy()
out[x1:x2, y1:y2, :] = replace
return out


# level to args
def enhance_level_to_args(MAX_LEVEL):

def level_to_args(level):
return ((level / MAX_LEVEL) * 1.8 + 0.1, )

return level_to_args


def shear_level_to_args(MAX_LEVEL, replace_value):

def level_to_args(level):
level = (level / MAX_LEVEL) * 0.3
if np.random.random() > 0.5:
level = -level
return level, replace_value

return level_to_args


def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):

def level_to_args(level):
level = (level / MAX_LEVEL) * float(translate_const)
if np.random.random() > 0.5:
level = -level
return (level, replace_value)

return level_to_args


def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):

def level_to_args(level):
level = int((level / MAX_LEVEL) * cutout_const)
return (level, replace_value)

return level_to_args


def solarize_level_to_args(MAX_LEVEL):

def level_to_args(level):
level = int((level / MAX_LEVEL) * 256)
return (level, )

return level_to_args


def none_level_to_args(level):
return ()


def posterize_level_to_args(MAX_LEVEL):

def level_to_args(level):
level = int((level / MAX_LEVEL) * 4)
return (level, )

return level_to_args


def rotate_level_to_args(MAX_LEVEL, replace_value):

def level_to_args(level):
level = (level / MAX_LEVEL) * 30
if np.random.random() < 0.5:
level = -level
return (level, replace_value)

return level_to_args


func_dict = {
'Identity': identity_func,
'AutoContrast': autocontrast_func,
'Equalize': equalize_func,
'Rotate': rotate_func,
'Solarize': solarize_func,
'Color': color_func,
'Contrast': contrast_func,
'Brightness': brightness_func,
'Sharpness': sharpness_func,
'ShearX': shear_x_func,
'TranslateX': translate_x_func,
'TranslateY': translate_y_func,
'Posterize': posterize_func,
'ShearY': shear_y_func,
}

translate_const = 10
MAX_LEVEL = 10
replace_value = (128, 128, 128)
arg_dict = {
'Identity':
none_level_to_args,
'AutoContrast':
none_level_to_args,
'Equalize':
none_level_to_args,
'Rotate':
rotate_level_to_args(MAX_LEVEL, replace_value),
'Solarize':
solarize_level_to_args(MAX_LEVEL),
'Color':
enhance_level_to_args(MAX_LEVEL),
'Contrast':
enhance_level_to_args(MAX_LEVEL),
'Brightness':
enhance_level_to_args(MAX_LEVEL),
'Sharpness':
enhance_level_to_args(MAX_LEVEL),
'ShearX':
shear_level_to_args(MAX_LEVEL, replace_value),
'TranslateX':
translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
'TranslateY':
translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
'Posterize':
posterize_level_to_args(MAX_LEVEL),
'ShearY':
shear_level_to_args(MAX_LEVEL, replace_value),
}


class RandomAugment(object):

def __init__(self, N=2, M=10, isPIL=False, augs=[]):
self.N = N
self.M = M
self.isPIL = isPIL
if augs:
self.augs = augs
else:
self.augs = list(arg_dict.keys())

def get_random_ops(self):
sampled_ops = np.random.choice(self.augs, self.N)
return [(op, 0.5, self.M) for op in sampled_ops]

def __call__(self, img):
if self.isPIL:
img = np.array(img)
ops = self.get_random_ops()
for name, prob, level in ops:
if np.random.random() > prob:
continue
args = arg_dict[name](level)
img = func_dict[name](img, *args)
return img
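
A small usage sketch: sample two of the whitelisted ops at magnitude 7 and apply them to a placeholder PIL image (converted to a numpy array internally because isPIL=True).

from PIL import Image

augment = RandomAugment(
    N=2, M=7, isPIL=True,
    augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness'])
img = Image.new('RGB', (256, 256))   # placeholder image
out = augment(img)                   # uint8 numpy array of shape (256, 256, 3)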

+ 62
- 0
modelscope/preprocessors/ofa/visual_entailment.py View File

@@ -0,0 +1,62 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict

import torch
from PIL import Image
from torchvision import transforms

from modelscope.preprocessors.image import load_image
from .base import OfaBasePreprocessor


class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):

def __init__(self, cfg, model_dir):
"""preprocess the data via the vocab.txt from the `model_dir` path

Args:
cfg(modelscope.utils.config.ConfigDict) : model config
model_dir (str): model path
"""
super(OfaVisualEntailmentPreprocessor, self).__init__(cfg, model_dir)
# Initialize transform
self.patch_resize_transform = transforms.Compose([
lambda image: image.convert('RGB'),
transforms.Resize((self.patch_image_size, self.patch_image_size),
interpolation=Image.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=self.mean, std=self.std),
])

def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = data['image'] if isinstance(
data['image'], Image.Image) else load_image(data['image'])
patch_image = self.patch_resize_transform(image)
if 'text2' not in data:
hypothesis = self.pre_caption(data['text'], self.max_src_length)
prompt = self.cfg.model.get('prompt',
' does the image describe " {} "?')
text = prompt.format(hypothesis)
else:
assert 'text' in data, f'text must be in the input {data.keys()}'
caption = self.pre_caption(data['text2'], self.max_src_length)
hypothesis = self.pre_caption(data['text'], self.max_src_length)
prompt = self.cfg.model.get(
'prompt', ' can image and text1 " {} " imply text2 " {} "?')
text = prompt.format(caption, hypothesis)
inputs = self.get_inputs(text)
if self.prompt_type == 'none':
decoder_prompt = self.bos_item
elif self.prompt_type == 'src':
decoder_prompt = inputs
elif self.prompt_type == 'prev_output':
decoder_prompt = inputs[:-1]
else:
raise NotImplementedError
sample = {
'source': inputs,
'patch_image': patch_image,
'patch_mask': torch.tensor([True]),
'decoder_prompt': decoder_prompt,
}
return sample
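
For reference, a hedged illustration of the prompt built above in the single-text case (the two-text branch fills the two slots with text2 and text respectively):

hypothesis = 'there are two birds.'
prompt = ' does the image describe " {} "?'
text = prompt.format(hypothesis)
# text == ' does the image describe " there are two birds. "?'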

+ 50
- 0
modelscope/preprocessors/ofa/visual_grounding.py View File

@@ -0,0 +1,50 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict

import torch
from PIL import Image
from torchvision import transforms

from modelscope.preprocessors.image import load_image
from .base import OfaBasePreprocessor


class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):

def __init__(self, cfg, model_dir):
"""preprocess the data via the vocab.txt from the `model_dir` path

Args:
cfg(modelscope.utils.config.ConfigDict) : model config
model_dir (str): model path
"""
super(OfaVisualGroundingPreprocessor, self).__init__(cfg, model_dir)
# Initialize transform
self.patch_resize_transform = transforms.Compose([
lambda image: image.convert('RGB'),
transforms.Resize((self.patch_image_size, self.patch_image_size),
interpolation=Image.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=self.mean, std=self.std),
])

def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = data['image'] if isinstance(
data['image'], Image.Image) else load_image(data['image'])
w, h = image.size
patch_image = self.patch_resize_transform(image)
w_resize_ratio = torch.tensor(self.patch_image_size / w)
h_resize_ratio = torch.tensor(self.patch_image_size / h)
src_caption = self.pre_caption(data['text'], self.max_src_length)
prompt = self.cfg.model.get(
'prompt', ' which region does the text " {} " describe?')
text = prompt.format(src_caption)
src_item = self.get_inputs(text)
sample = {
'source': src_item,
'patch_image': patch_image,
'patch_mask': torch.tensor([True]),
'w_resize_ratio': w_resize_ratio,
'h_resize_ratio': h_resize_ratio,
}
return sample

+ 52
- 0
modelscope/preprocessors/ofa/visual_question_answering.py View File

@@ -0,0 +1,52 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict

import torch
from PIL import Image
from torchvision import transforms

from modelscope.preprocessors.image import load_image
from .base import OfaBasePreprocessor


class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):

def __init__(self, cfg, model_dir):
"""preprocess the data via the vocab.txt from the `model_dir` path

Args:
cfg(modelscope.utils.config.ConfigDict) : model config
model_dir (str): model path
"""
super(OfaVisualQuestionAnsweringPreprocessor,
self).__init__(cfg, model_dir)
# Initialize transform
self.patch_resize_transform = transforms.Compose([
lambda image: image.convert('RGB'),
transforms.Resize((self.patch_image_size, self.patch_image_size),
interpolation=Image.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=self.mean, std=self.std),
])

def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = data['image'] if isinstance(
data['image'], Image.Image) else load_image(data['image'])
patch_image = self.patch_resize_transform(image)
text = ' {}'.format(data['text'])
inputs = self.get_inputs(text)
if self.prompt_type == 'none':
decoder_prompt = self.bos_item
elif self.prompt_type == 'src':
decoder_prompt = inputs
elif self.prompt_type == 'prev_output':
decoder_prompt = inputs[:-1]
else:
raise NotImplementedError
sample = {
'source': inputs,
'patch_image': patch_image,
'patch_mask': torch.tensor([True]),
'decoder_prompt': decoder_prompt,
}
return sample

+ 1
- 0
modelscope/utils/constant.py View File

@@ -85,6 +85,7 @@ class MultiModalTasks(object):
multi_modal_embedding = 'multi-modal-embedding'
generative_multi_modal_embedding = 'generative-multi-modal-embedding'
visual_question_answering = 'visual-question-answering'
visual_entailment = 'visual-entailment'
video_multi_modal_embedding = 'video-multi-modal-embedding'




+ 29
- 0
modelscope/utils/trie.py View File

@@ -0,0 +1,29 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from collections import defaultdict


class TreeNode:

def __init__(self):
self.child = defaultdict(TreeNode)


class Trie:

def __init__(self, eos):
self.root = TreeNode()
self.eos = eos

def insert(self, word):
cur = self.root
for c in word:
cur = cur.child[c]

def get_next_layer(self, word):
cur = self.root
for c in word:
cur = cur.child.get(c)
if cur is None:
return [self.eos]
return list(cur.child.keys())
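
A minimal usage sketch of the trie with made-up token ids: insert a few sequences and query which ids may follow a prefix; unknown prefixes fall back to [eos].

trie = Trie(eos=2)
trie.insert([5, 6, 7])
trie.insert([5, 6, 8])
trie.get_next_layer([5, 6])   # -> [7, 8]
trie.get_next_layer([5])      # -> [6]
trie.get_next_layer([1, 4])   # -> [2], the eos fallback for unseen prefixes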

+ 0
- 23
tests/pipelines/test_image_captioning.py View File

@@ -1,23 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import unittest

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class ImageCaptionTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run(self):
img_captioning = pipeline(
Tasks.image_captioning,
model='damo/ofa_image-caption_coco_distilled_en')
result = img_captioning('data/test/images/image_captioning.png')
print(result[OutputKeys.CAPTION])


if __name__ == '__main__':
unittest.main()

+ 179
- 0
tests/pipelines/test_ofa_tasks.py View File

@@ -0,0 +1,179 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class OfaTasksTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_image_captioning_with_model(self):
model = Model.from_pretrained(
'damo/ofa_image-caption_coco_distilled_en')
img_captioning = pipeline(
task=Tasks.image_captioning,
model=model,
)
result = img_captioning(
{'image': 'data/test/images/image_captioning.png'})
print(result[OutputKeys.CAPTION])

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_image_captioning_with_name(self):
img_captioning = pipeline(
Tasks.image_captioning,
model='damo/ofa_image-caption_coco_distilled_en')
result = img_captioning(
{'image': 'data/test/images/image_captioning.png'})
print(result[OutputKeys.CAPTION])

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_image_classification_with_model(self):
model = Model.from_pretrained(
'damo/ofa_image-classification_imagenet_large_en')
ofa_pipe = pipeline(Tasks.image_classification, model=model)
image = 'data/test/images/image_classification.png'
input = {'image': image}
result = ofa_pipe(input)
print(result)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_image_classification_with_name(self):
ofa_pipe = pipeline(
Tasks.image_classification,
model='damo/ofa_image-classification_imagenet_large_en')
image = 'data/test/images/image_classification.png'
input = {'image': image}
result = ofa_pipe(input)
print(result)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_summarization_with_model(self):
model = Model.from_pretrained(
'damo/ofa_summarization_gigaword_large_en')
ofa_pipe = pipeline(Tasks.summarization, model=model)
        text = 'five-time world champion michelle kwan withdrew ' + \
'from the #### us figure skating championships on wednesday ,' + \
' but will petition us skating officials for the chance to ' + \
'compete at the #### turin olympics .'
input = {'text': text}
result = ofa_pipe(input)
print(result)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_summarization_with_name(self):
ofa_pipe = pipeline(
Tasks.summarization,
model='damo/ofa_summarization_gigaword_large_en')
        text = 'five-time world champion michelle kwan withdrew ' + \
            'from the #### us figure skating championships on wednesday ,' + \
            ' but will petition us skating officials for the chance to ' + \
'compete at the #### turin olympics .'
input = {'text': text}
result = ofa_pipe(input)
print(result)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_text_classification_with_model(self):
model = Model.from_pretrained(
'damo/ofa_text-classification_mnli_large_en')
ofa_pipe = pipeline(Tasks.text_classification, model=model)
text = 'One of our number will carry out your instructions minutely.'
text2 = 'A member of my team will execute your orders with immense precision.'
input = {'text': text, 'text2': text2}
result = ofa_pipe(input)
print(result)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_text_classification_with_name(self):
ofa_pipe = pipeline(
Tasks.text_classification,
model='damo/ofa_text-classification_mnli_large_en')
text = 'One of our number will carry out your instructions minutely.'
text2 = 'A member of my team will execute your orders with immense precision.'
input = {'text': text, 'text2': text2}
result = ofa_pipe(input)
print(result)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_visual_entailment_with_model(self):
model = Model.from_pretrained(
'damo/ofa_visual-entailment_snli-ve_large_en')
ofa_pipe = pipeline(Tasks.visual_entailment, model=model)
image = 'data/test/images/dogs.jpg'
text = 'there are two birds.'
input = {'image': image, 'text': text}
result = ofa_pipe(input)
print(result)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_visual_entailment_with_name(self):
ofa_pipe = pipeline(
Tasks.visual_entailment,
model='damo/ofa_visual-entailment_snli-ve_large_en')
image = 'data/test/images/dogs.jpg'
text = 'there are two birds.'
input = {'image': image, 'text': text}
result = ofa_pipe(input)
print(result)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_visual_grounding_with_model(self):
model = Model.from_pretrained(
'damo/ofa_visual-grounding_refcoco_large_en')
ofa_pipe = pipeline(Tasks.visual_grounding, model=model)
image = 'data/test/images/visual_grounding.png'
text = 'a blue turtle-like pokemon with round head'
input = {'image': image, 'text': text}
result = ofa_pipe(input)
print(result)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_visual_grounding_with_name(self):
ofa_pipe = pipeline(
Tasks.visual_grounding,
model='damo/ofa_visual-grounding_refcoco_large_en')
image = 'data/test/images/visual_grounding.png'
text = 'a blue turtle-like pokemon with round head'
input = {'image': image, 'text': text}
result = ofa_pipe(input)
print(result)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_visual_question_answering_with_model(self):
from modelscope.preprocessors.multi_modal import OfaPreprocessor
model = Model.from_pretrained(
'damo/ofa_visual-question-answering_pretrain_large_en')
preprocessor = OfaPreprocessor(model_dir=model.model_dir)
ofa_pipe = pipeline(
Tasks.visual_question_answering,
model=model,
preprocessor=preprocessor)
image = 'data/test/images/visual_question_answering.png'
text = 'what is grown on the plant?'
input = {'image': image, 'text': text}
result = ofa_pipe(input)
print(result)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_visual_question_answering_with_name(self):
from modelscope.preprocessors.multi_modal import OfaPreprocessor
model = 'damo/ofa_visual-question-answering_pretrain_large_en'
preprocessor = OfaPreprocessor(model_dir=model)
ofa_pipe = pipeline(
Tasks.visual_question_answering,
model=model,
preprocessor=preprocessor)
image = 'data/test/images/visual_question_answering.png'
text = 'what is grown on the plant?'
input = {'image': image, 'text': text}
result = ofa_pipe(input)
print(result)


if __name__ == '__main__':
unittest.main()
