Add an OFA task for text recognition in everyday scenes (scene OCR), mainly including:
1. New pipeline and task name definitions;
2. Core class and method logic for the new pipeline, task, model, and preprocessor;
3. Other small detail fixes synced along the way.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10471089master
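A quick usage sketch of the change as a whole, mirroring the unit test added at the bottom of this diff (the model id and sample image path are taken from that test):

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the new OFA OCR pipeline by task name and run it on the test image.
ocr_recognize = pipeline(
    Tasks.ofa_ocr_recognition,
    model='damo/ofa_ocr-recognition_scene_base_zh')
result = ocr_recognize('data/test/images/image_ocr_recognition.jpg')
print(result[OutputKeys.TEXT])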
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:772b19f76c98044e39330853928624f10e085106a4292b4dd19f865531080747
+size 959
@@ -263,6 +263,7 @@ class Pipelines(object):
     text_to_image_synthesis = 'text-to-image-synthesis'
     video_multi_modal_embedding = 'video-multi-modal-embedding'
     image_text_retrieval = 'image-text-retrieval'
+    ofa_ocr_recognition = 'ofa-ocr-recognition'


 class Trainers(object):
@@ -3,6 +3,7 @@ from modelscope.outputs import OutputKeys
 from modelscope.utils.constant import Tasks

 OFA_TASK_KEY_MAPPING = {
+    Tasks.ofa_ocr_recognition: OutputKeys.TEXT,
     Tasks.image_captioning: OutputKeys.CAPTION,
     Tasks.summarization: OutputKeys.TEXT,
     Tasks.visual_question_answering: OutputKeys.TEXT,
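OfaForAllTasks uses OFA_TASK_KEY_MAPPING to decide which key the generated string is returned under, so with this entry OCR results surface under OutputKeys.TEXT. A minimal check (the import path of the mapping is an assumption here):

from modelscope.models.multi_modal.ofa_for_all_tasks import OFA_TASK_KEY_MAPPING  # assumed path
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks

# The OCR task shares the plain-text output key with summarization and VQA.
assert OFA_TASK_KEY_MAPPING[Tasks.ofa_ocr_recognition] == OutputKeys.TEXT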
@@ -27,6 +27,7 @@ __all__ = ['OfaForAllTasks']


 @MODELS.register_module(Tasks.image_captioning, module_name=Models.ofa)
+@MODELS.register_module(Tasks.ofa_ocr_recognition, module_name=Models.ofa)
 @MODELS.register_module(Tasks.visual_grounding, module_name=Models.ofa)
 @MODELS.register_module(
     Tasks.visual_question_answering, module_name=Models.ofa)
@@ -96,6 +97,7 @@ class OfaForAllTasks(TorchModel):
             'traverse': self._traverse_inference,
         }
         self.task_inference_mapping = {
+            Tasks.ofa_ocr_recognition: self._text_gen_inference,
             Tasks.image_captioning: self._text_gen_inference,
             Tasks.summarization: self._text_gen_inference,
             Tasks.visual_grounding: self._visual_grounding_inference,
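Mapping the task to self._text_gen_inference means OCR reuses OFA's existing text-generation decode path rather than adding new inference logic; roughly, the dispatch works like this (a simplified sketch with a hypothetical helper, not the exact OfaForAllTasks code):

# Simplified view of the task-based dispatch: the task name selects an inference
# function, and ofa_ocr_recognition lands on the same text-generation branch
# already used by captioning and summarization.
def dispatch(model, task, batch):
    inference_fn = model.task_inference_mapping[task]
    return inference_fn(batch)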
@@ -661,6 +661,7 @@ TASK_OUTPUTS = {
     # "caption": "this is an image caption text."
     # }
     Tasks.image_captioning: [OutputKeys.CAPTION],
+    Tasks.ofa_ocr_recognition: [OutputKeys.TEXT],

     # visual grounding result for single sample
     # {
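Following the comment convention used for neighboring entries in TASK_OUTPUTS, the new entry corresponds to a result dict of this shape (the string is only an example):

# ocr recognition result for single sample
# {
#     "text": "this is the recognized text."
# }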
@@ -0,0 +1,52 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Any, Dict, Optional, Union
+
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.models.multi_modal import OfaForAllTasks
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Model, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import OfaPreprocessor, Preprocessor
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+    Tasks.ofa_ocr_recognition, module_name=Pipelines.ofa_ocr_recognition)
+class OcrRecognitionPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[Model, str],
+                 preprocessor: Optional[Preprocessor] = None,
+                 **kwargs):
+        """
+        use `model` and `preprocessor` to create an OCR recognition pipeline for prediction
+        Args:
+            model: model id on modelscope hub.
+        """
+        super().__init__(model=model)
+        assert isinstance(model, str) or isinstance(model, Model), \
+            'model must be a single str or OfaForAllTasks'
+        if isinstance(model, str):
+            pipe_model = Model.from_pretrained(model)
+        elif isinstance(model, Model):
+            pipe_model = model
+        else:
+            raise NotImplementedError
+        pipe_model.model.eval()
+        if preprocessor is None:
+            if isinstance(pipe_model, OfaForAllTasks):
+                preprocessor = OfaPreprocessor(pipe_model.model_dir)
+        super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)
+
+    def forward(self, inputs: Dict[str, Any],
+                **forward_params) -> Dict[str, Any]:
+        with torch.no_grad():
+            return super().forward(inputs, **forward_params)
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        return inputs
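For reference, the new pipeline class can also be constructed from a preloaded Model instance instead of a hub id; when no preprocessor is given and the model is an OfaForAllTasks, an OfaPreprocessor is built from its model_dir. A sketch (the OcrRecognitionPipeline import path is an assumption):

from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.multi_modal import OcrRecognitionPipeline  # assumed import path

# Load the OFA OCR model once and hand it to the pipeline; the preprocessor
# is created automatically from the model's model_dir.
model = Model.from_pretrained('damo/ofa_ocr-recognition_scene_base_zh')
ocr_pipe = OcrRecognitionPipeline(model)
print(ocr_pipe('data/test/images/image_ocr_recognition.jpg')[OutputKeys.TEXT])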
@@ -34,6 +34,7 @@ class OfaPreprocessor(Preprocessor):
         """
         super().__init__(*args, **kwargs)
         preprocess_mapping = {
+            Tasks.ofa_ocr_recognition: OfaOcrRecognitionPreprocessor,
             Tasks.image_captioning: OfaImageCaptioningPreprocessor,
             Tasks.visual_grounding: OfaVisualGroundingPreprocessor,
             Tasks.visual_question_answering:
@@ -45,6 +46,7 @@ class OfaPreprocessor(Preprocessor):
             Tasks.text_to_image_synthesis: OfaTextToImageSynthesisPreprocessor
         }
         input_key_mapping = {
+            Tasks.ofa_ocr_recognition: ['image'],
             Tasks.image_captioning: ['image'],
             Tasks.image_classification: ['image'],
             Tasks.summarization: ['text'],
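The input_key_mapping entry is what lets the pipeline be called with a bare image path or PIL image: OfaPreprocessor wraps non-dict input into a dict keyed by 'image' before handing it to OfaOcrRecognitionPreprocessor. A rough sketch of that wrapping (simplified; the real logic lives in OfaPreprocessor.__call__):

# Simplified illustration of how a positional input is keyed for the OCR task.
def wrap_input(task, data, input_key_mapping):
    if isinstance(data, dict):
        return data
    values = data if isinstance(data, (list, tuple)) else [data]
    # e.g. Tasks.ofa_ocr_recognition: ['image'] -> {'image': 'path/to/img.jpg'}
    return dict(zip(input_key_mapping[task], values))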
@@ -1,6 +1,7 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
from .image_captioning import OfaImageCaptioningPreprocessor | from .image_captioning import OfaImageCaptioningPreprocessor | ||||
from .image_classification import OfaImageClassificationPreprocessor | from .image_classification import OfaImageClassificationPreprocessor | ||||
from .ocr_recognition import OfaOcrRecognitionPreprocessor | |||||
from .summarization import OfaSummarizationPreprocessor | from .summarization import OfaSummarizationPreprocessor | ||||
from .text_classification import OfaTextClassificationPreprocessor | from .text_classification import OfaTextClassificationPreprocessor | ||||
from .text_to_image_synthesis import OfaTextToImageSynthesisPreprocessor | from .text_to_image_synthesis import OfaTextToImageSynthesisPreprocessor | ||||
@@ -0,0 +1,99 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import random
+import unicodedata
+from typing import Any, Dict, Union
+
+import torch
+from PIL import Image
+from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+from torchvision.transforms import functional as F
+
+from modelscope.preprocessors.image import load_image
+from .base import OfaBasePreprocessor
+
+IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+
+
+def ocr_resize(img, patch_image_size, is_document=False):
+    img = img.convert('RGB')
+    width, height = img.size
+
+    if is_document:
+        new_height, new_width = 64, 1920
+    else:
+        if width >= height:
+            new_width = max(64, patch_image_size)
+            new_height = max(64, int(patch_image_size * (height / width)))
+            top = (patch_image_size - new_height) // 2
+            bottom = patch_image_size - new_height - top
+            left, right = 0, 0
+        else:
+            new_height = max(64, patch_image_size)
+            new_width = max(64, int(patch_image_size * (width / height)))
+            left = (patch_image_size - new_width) // 2
+            right = patch_image_size - new_width - left
+            top, bottom = 0, 0
+
+    img_new = F.resize(
+        img,
+        (new_height, new_width),
+        interpolation=InterpolationMode.BICUBIC,
+    )
+
+    if is_document:
+        img_split = transforms.ToTensor()(img_new).chunk(4, dim=-1)
+        img_new = transforms.ToPILImage()(torch.cat(img_split, dim=-2))
+        new_width, new_height = img_new.size
+        top = (patch_image_size - new_height) // 2
+        bottom = patch_image_size - new_height - top
+        left, right = 0, 0
+
+    img_new = F.pad(
+        img_new, padding=[left, top, right, bottom], padding_mode='edge')
+    assert img_new.size == (patch_image_size, patch_image_size)
+
+    return img_new
+
+
+class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
+
+    def __init__(self, cfg, model_dir):
+        """preprocess the data
+
+        Args:
+            cfg(modelscope.utils.config.ConfigDict) : model config
+            model_dir (str): model path
+        """
+        super(OfaOcrRecognitionPreprocessor, self).__init__(cfg, model_dir)
+
+        # Initialize transform
+        if self.cfg.model.imagenet_default_mean_and_std:
+            mean = IMAGENET_DEFAULT_MEAN
+            std = IMAGENET_DEFAULT_STD
+        else:
+            mean = [0.5, 0.5, 0.5]
+            std = [0.5, 0.5, 0.5]
+
+        self.patch_resize_transform = transforms.Compose([
+            lambda image: ocr_resize(
+                image,
+                self.cfg.model.patch_image_size,
+                is_document=self.cfg.model.is_document),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std),
+        ])
+
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        image = data['image'] if isinstance(
+            data['image'], Image.Image) else load_image(data['image'])
+        patch_image = self.patch_resize_transform(image)
+        prompt = self.cfg.model.get('prompt', '图片上的文字是什么?')
+        inputs = self.get_inputs(prompt)
+        sample = {
+            'source': inputs,
+            'patch_image': patch_image,
+            'patch_mask': torch.tensor([True])
+        }
+        return sample
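A small illustration of the non-document branch of ocr_resize: the image is scaled so its longer side matches patch_image_size and then edge-padded to a square (the patch_image_size of 480 below is an assumption; the real value comes from the model config, as does the import path):

from PIL import Image

from modelscope.preprocessors.ofa.ocr_recognition import ocr_resize  # assumed import path

# A wide 640x128 text-line crop is resized to 480x96, then edge-padded
# top and bottom into a 480x480 square patch.
img = Image.new('RGB', (640, 128), color='white')
patch = ocr_resize(img, 480, is_document=False)
print(patch.size)  # (480, 480)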
@@ -151,6 +151,7 @@ class MultiModalTasks(object):
     visual_entailment = 'visual-entailment'
     video_multi_modal_embedding = 'video-multi-modal-embedding'
     image_text_retrieval = 'image-text-retrieval'
+    ofa_ocr_recognition = 'ofa-ocr-recognition'


 class TasksIODescriptions(object):
@@ -45,6 +45,14 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
         result = img_captioning('data/test/images/image_captioning.png')
         print(result[OutputKeys.CAPTION])

+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_ocr_recognize_with_name(self):
+        ocr_recognize = pipeline(
+            Tasks.ofa_ocr_recognition,
+            model='damo/ofa_ocr-recognition_scene_base_zh')
+        result = ocr_recognize('data/test/images/image_ocr_recognition.jpg')
+        print(result[OutputKeys.TEXT])
+
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_image_classification_with_model(self):
         model = Model.from_pretrained(