diff --git a/data/test/images/image_ocr_recognition.jpg b/data/test/images/image_ocr_recognition.jpg
new file mode 100644
index 00000000..b41287cd
--- /dev/null
+++ b/data/test/images/image_ocr_recognition.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:772b19f76c98044e39330853928624f10e085106a4292b4dd19f865531080747
+size 959
diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 31bef3b8..732f8ffa 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -263,6 +263,7 @@ class Pipelines(object):
     text_to_image_synthesis = 'text-to-image-synthesis'
     video_multi_modal_embedding = 'video-multi-modal-embedding'
     image_text_retrieval = 'image-text-retrieval'
+    ofa_ocr_recognition = 'ofa-ocr-recognition'
 
 
 class Trainers(object):
diff --git a/modelscope/models/multi_modal/ofa/utils/constant.py b/modelscope/models/multi_modal/ofa/utils/constant.py
index 984da443..eec2cc6c 100644
--- a/modelscope/models/multi_modal/ofa/utils/constant.py
+++ b/modelscope/models/multi_modal/ofa/utils/constant.py
@@ -3,6 +3,7 @@ from modelscope.outputs import OutputKeys
 from modelscope.utils.constant import Tasks
 
 OFA_TASK_KEY_MAPPING = {
+    Tasks.ofa_ocr_recognition: OutputKeys.TEXT,
     Tasks.image_captioning: OutputKeys.CAPTION,
     Tasks.summarization: OutputKeys.TEXT,
     Tasks.visual_question_answering: OutputKeys.TEXT,
diff --git a/modelscope/models/multi_modal/ofa_for_all_tasks.py b/modelscope/models/multi_modal/ofa_for_all_tasks.py
index 45bafde9..20cab6a6 100644
--- a/modelscope/models/multi_modal/ofa_for_all_tasks.py
+++ b/modelscope/models/multi_modal/ofa_for_all_tasks.py
@@ -27,6 +27,7 @@ __all__ = ['OfaForAllTasks']
 
 
 @MODELS.register_module(Tasks.image_captioning, module_name=Models.ofa)
+@MODELS.register_module(Tasks.ofa_ocr_recognition, module_name=Models.ofa)
 @MODELS.register_module(Tasks.visual_grounding, module_name=Models.ofa)
 @MODELS.register_module(
     Tasks.visual_question_answering, module_name=Models.ofa)
@@ -96,6 +97,7 @@ class OfaForAllTasks(TorchModel):
             'traverse': self._traverse_inference,
         }
         self.task_inference_mapping = {
+            Tasks.ofa_ocr_recognition: self._text_gen_inference,
             Tasks.image_captioning: self._text_gen_inference,
             Tasks.summarization: self._text_gen_inference,
             Tasks.visual_grounding: self._visual_grounding_inference,
diff --git a/modelscope/outputs.py b/modelscope/outputs.py
index fbe15646..365e2bf9 100644
--- a/modelscope/outputs.py
+++ b/modelscope/outputs.py
@@ -661,6 +661,7 @@ TASK_OUTPUTS = {
     # "caption": "this is an image caption text."
     # }
     Tasks.image_captioning: [OutputKeys.CAPTION],
+    Tasks.ofa_ocr_recognition: [OutputKeys.TEXT],
 
     # visual grounding result for single sample
     # {
diff --git a/modelscope/pipelines/multi_modal/ocr_recognition_pipeline.py b/modelscope/pipelines/multi_modal/ocr_recognition_pipeline.py
new file mode 100644
index 00000000..9cd63b6c
--- /dev/null
+++ b/modelscope/pipelines/multi_modal/ocr_recognition_pipeline.py
@@ -0,0 +1,52 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Any, Dict, Optional, Union
+
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.models.multi_modal import OfaForAllTasks
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Model, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import OfaPreprocessor, Preprocessor
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+    Tasks.ofa_ocr_recognition, module_name=Pipelines.ofa_ocr_recognition)
+class OcrRecognitionPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[Model, str],
+                 preprocessor: Optional[Preprocessor] = None,
+                 **kwargs):
+        """
+        use `model` and `preprocessor` to create an OCR recognition pipeline for prediction
+        Args:
+            model: model id on modelscope hub.
+        """
+        super().__init__(model=model)
+        assert isinstance(model, str) or isinstance(model, Model), \
+            'model must be a single str or OfaForAllTasks'
+        if isinstance(model, str):
+            pipe_model = Model.from_pretrained(model)
+        elif isinstance(model, Model):
+            pipe_model = model
+        else:
+            raise NotImplementedError
+        pipe_model.model.eval()
+        if preprocessor is None:
+            if isinstance(pipe_model, OfaForAllTasks):
+                preprocessor = OfaPreprocessor(pipe_model.model_dir)
+        super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)
+
+    def forward(self, inputs: Dict[str, Any],
+                **forward_params) -> Dict[str, Any]:
+        with torch.no_grad():
+            return super().forward(inputs, **forward_params)
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        return inputs
diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py
index f38ff8ae..6f3245c3 100644
--- a/modelscope/preprocessors/multi_modal.py
+++ b/modelscope/preprocessors/multi_modal.py
@@ -34,6 +34,7 @@ class OfaPreprocessor(Preprocessor):
         """
         super().__init__(*args, **kwargs)
         preprocess_mapping = {
+            Tasks.ofa_ocr_recognition: OfaOcrRecognitionPreprocessor,
             Tasks.image_captioning: OfaImageCaptioningPreprocessor,
             Tasks.visual_grounding: OfaVisualGroundingPreprocessor,
             Tasks.visual_question_answering:
@@ -45,6 +46,7 @@ class OfaPreprocessor(Preprocessor):
             Tasks.text_to_image_synthesis: OfaTextToImageSynthesisPreprocessor
         }
         input_key_mapping = {
+            Tasks.ofa_ocr_recognition: ['image'],
             Tasks.image_captioning: ['image'],
             Tasks.image_classification: ['image'],
             Tasks.summarization: ['text'],
diff --git a/modelscope/preprocessors/ofa/__init__.py b/modelscope/preprocessors/ofa/__init__.py
index 95d72fe1..59b94b2b 100644
--- a/modelscope/preprocessors/ofa/__init__.py
+++ b/modelscope/preprocessors/ofa/__init__.py
@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from .image_captioning import OfaImageCaptioningPreprocessor
 from .image_classification import OfaImageClassificationPreprocessor
+from .ocr_recognition import OfaOcrRecognitionPreprocessor
 from .summarization import OfaSummarizationPreprocessor
 from .text_classification import OfaTextClassificationPreprocessor
 from .text_to_image_synthesis import OfaTextToImageSynthesisPreprocessor
diff --git a/modelscope/preprocessors/ofa/ocr_recognition.py b/modelscope/preprocessors/ofa/ocr_recognition.py
new file mode 100644
index 00000000..1d30e572
--- /dev/null
+++ b/modelscope/preprocessors/ofa/ocr_recognition.py
@@ -0,0 +1,99 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import random
+import unicodedata
+from typing import Any, Dict, Union
+
+import torch
+from PIL import Image
+from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+from torchvision.transforms import functional as F
+
+from modelscope.preprocessors.image import load_image
+from .base import OfaBasePreprocessor
+
+IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+
+
+def ocr_resize(img, patch_image_size, is_document=False):
+    img = img.convert('RGB')
+    width, height = img.size
+
+    if is_document:
+        new_height, new_width = 64, 1920
+    else:
+        if width >= height:
+            new_width = max(64, patch_image_size)
+            new_height = max(64, int(patch_image_size * (height / width)))
+            top = (patch_image_size - new_height) // 2
+            bottom = patch_image_size - new_height - top
+            left, right = 0, 0
+        else:
+            new_height = max(64, patch_image_size)
+            new_width = max(64, int(patch_image_size * (width / height)))
+            left = (patch_image_size - new_width) // 2
+            right = patch_image_size - new_width - left
+            top, bottom = 0, 0
+
+    img_new = F.resize(
+        img,
+        (new_height, new_width),
+        interpolation=InterpolationMode.BICUBIC,
+    )
+
+    if is_document:
+        img_split = transforms.ToTensor()(img_new).chunk(4, dim=-1)
+        img_new = transforms.ToPILImage()(torch.cat(img_split, dim=-2))
+        new_width, new_height = img_new.size
+        top = (patch_image_size - new_height) // 2
+        bottom = patch_image_size - new_height - top
+        left, right = 0, 0
+
+    img_new = F.pad(
+        img_new, padding=[left, top, right, bottom], padding_mode='edge')
+    assert img_new.size == (patch_image_size, patch_image_size)
+
+    return img_new
+
+
+class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
+
+    def __init__(self, cfg, model_dir):
+        """preprocess the data
+
+        Args:
+            cfg(modelscope.utils.config.ConfigDict) : model config
+            model_dir (str): model path
+        """
+        super(OfaOcrRecognitionPreprocessor, self).__init__(cfg, model_dir)
+        # Initialize transform
+        if self.cfg.model.imagenet_default_mean_and_std:
+            mean = IMAGENET_DEFAULT_MEAN
+            std = IMAGENET_DEFAULT_STD
+        else:
+            mean = [0.5, 0.5, 0.5]
+            std = [0.5, 0.5, 0.5]
+
+        self.patch_resize_transform = transforms.Compose([
+            lambda image: ocr_resize(
+                image,
+                self.cfg.model.patch_image_size,
+                is_document=self.cfg.model.is_document),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std),
+        ])
+
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        image = data['image'] if isinstance(
+            data['image'], Image.Image) else load_image(data['image'])
+        patch_image = self.patch_resize_transform(image)
+        prompt = self.cfg.model.get('prompt', '图片上的文字是什么?')  # i.e. 'What is the text in the image?'
+        inputs = self.get_inputs(prompt)
+
+        sample = {
+            'source': inputs,
+            'patch_image': patch_image,
+            'patch_mask': torch.tensor([True])
+        }
+        return sample
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index 6c0f3e98..865e1d4f 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -151,6 +151,7 @@ class MultiModalTasks(object):
     visual_entailment = 'visual-entailment'
     video_multi_modal_embedding = 'video-multi-modal-embedding'
     image_text_retrieval = 'image-text-retrieval'
+    ofa_ocr_recognition = 'ofa-ocr-recognition'
 
 
 class TasksIODescriptions(object):
diff --git a/tests/pipelines/test_ofa_tasks.py b/tests/pipelines/test_ofa_tasks.py
index f8366508..05ecc719 100644
--- a/tests/pipelines/test_ofa_tasks.py
+++ b/tests/pipelines/test_ofa_tasks.py
@@ -45,6 +45,14 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
         result = img_captioning('data/test/images/image_captioning.png')
         print(result[OutputKeys.CAPTION])
 
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_ocr_recognize_with_name(self):
+        ocr_recognize = pipeline(
+            Tasks.ofa_ocr_recognition,
+            model='damo/ofa_ocr-recognition_scene_base_zh')
+        result = ocr_recognize('data/test/images/image_ocr_recognition.jpg')
+        print(result[OutputKeys.TEXT])
+
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_image_classification_with_model(self):
         model = Model.from_pretrained(
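
Usage note: a minimal sketch of how the pipeline added by this diff can be invoked, mirroring the new unit test above; it assumes the 'damo/ofa_ocr-recognition_scene_base_zh' model is available on the ModelScope hub and that the bundled test image exists locally.

# Minimal usage sketch for the new OFA OCR recognition pipeline (mirrors the unit test above).
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the pipeline for the newly registered 'ofa-ocr-recognition' task.
ocr_recognize = pipeline(
    Tasks.ofa_ocr_recognition,
    model='damo/ofa_ocr-recognition_scene_base_zh')

# Run it on the test image shipped with this diff; the recognized text is
# returned under OutputKeys.TEXT, as declared in TASK_OUTPUTS.
result = ocr_recognize('data/test/images/image_ocr_recognition.jpg')
print(result[OutputKeys.TEXT])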