From cd577f8ba07c26dffc8520ef577b5490fa7a4ce2 Mon Sep 17 00:00:00 2001
From: "menrui.mr"
Date: Thu, 4 Aug 2022 15:33:31 +0800
Subject: [PATCH] [to #42322933] Add ofa-text-to-image-synthesis to maas lib

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9590940
---
 modelscope/metainfo.py                             |  1 +
 modelscope/models/multi_modal/__init__.py          |  4 +-
 .../ofa_for_text_to_image_synthesis_model.py       | 90 +++++++++++++++++++
 .../text_to_image_synthesis_pipeline.py            | 26 ++++--
 modelscope/preprocessors/multi_modal.py            |  6 +-
 modelscope/preprocessors/ofa/__init__.py           |  1 +
 .../ofa/text_to_image_synthesis.py                 | 31 +++++++
 requirements/multi-modal.txt                       |  1 +
 tests/pipelines/test_ofa_tasks.py                  | 19 ++++
 9 files changed, 170 insertions(+), 9 deletions(-)
 create mode 100644 modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py
 create mode 100644 modelscope/preprocessors/ofa/text_to_image_synthesis.py

diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index da0cb0e8..9d9b255a 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -203,6 +203,7 @@ class Preprocessors(object):
 
     # multi-modal
     ofa_image_caption = 'ofa-image-caption'
+    ofa_text_to_image_synthesis = 'ofa-text-to-image-synthesis'
     mplug_visual_question_answering = 'mplug-visual-question-answering'
 
 
diff --git a/modelscope/models/multi_modal/__init__.py b/modelscope/models/multi_modal/__init__.py
index cd368739..0f9c9e85 100644
--- a/modelscope/models/multi_modal/__init__.py
+++ b/modelscope/models/multi_modal/__init__.py
@@ -20,7 +20,9 @@ else:
         'mmr': ['VideoCLIPForMultiModalEmbedding'],
         'mplug_for_visual_question_answering':
         ['MPlugForVisualQuestionAnswering'],
-        'ofa_for_all_tasks': ['OfaForAllTasks']
+        'ofa_for_all_tasks': ['OfaForAllTasks'],
+        'ofa_for_text_to_image_synthesis_model':
+        ['OfaForTextToImageSynthesis']
     }
 
     import sys
diff --git a/modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py b/modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py
new file mode 100644
index 00000000..5cdc9668
--- /dev/null
+++ b/modelscope/models/multi_modal/ofa_for_text_to_image_synthesis_model.py
@@ -0,0 +1,90 @@
+import os
+from typing import Any, Dict
+
+import json
+import numpy as np
+import torch
+import torch.cuda
+from PIL import Image
+from taming.models.vqgan import GumbelVQ, VQModel
+
+from modelscope.metainfo import Models
+from modelscope.models.base import Model
+from modelscope.models.builder import MODELS
+from modelscope.models.multi_modal.ofa import OFAModel, OFATokenizer
+from modelscope.models.multi_modal.ofa.generate import sequence_generator as sg
+from modelscope.models.multi_modal.ofa.generate.search import Sampling
+from modelscope.models.multi_modal.ofa.generate.utils import move_to_device
+from modelscope.utils.constant import Tasks
+
+__all__ = ['OfaForTextToImageSynthesis']
+
+
+def custom_to_pil(x):
+    x = x.detach().cpu()
+    x = torch.clamp(x, -1., 1.)
+    x = (x + 1.) / 2.
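+    # convert the CHW float tensor in [0, 1] to an HWC uint8 array for PIL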
+    x = x.permute(1, 2, 0).numpy()
+    x = (255 * x).astype(np.uint8)
+    x = Image.fromarray(x)
+    if not x.mode == 'RGB':
+        x = x.convert('RGB')
+    return x
+
+
+def load_vqgan(config, ckpt_path=None, is_gumbel=False):
+    if is_gumbel:
+        model = GumbelVQ(**config['model']['params'])
+    else:
+        model = VQModel(**config['model']['params'])
+    if ckpt_path is not None:
+        sd = torch.load(ckpt_path, map_location='cpu')['state_dict']
+        missing, unexpected = model.load_state_dict(sd, strict=False)
+    return model.eval()
+
+
+@MODELS.register_module(Tasks.text_to_image_synthesis, module_name=Models.ofa)
+class OfaForTextToImageSynthesis(Model):
+
+    def __init__(self, model_dir, *args, **kwargs):
+        super().__init__(model_dir=model_dir, *args, **kwargs)
+        # Initialize ofa
+        model = OFAModel.from_pretrained(model_dir)
+        self.model = model.module if hasattr(model, 'module') else model
+        self.tokenizer = OFATokenizer.from_pretrained(model_dir)
+        self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)])
+        self.tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)])
+        self._device = torch.device('cuda') if torch.cuda.is_available() \
+            else torch.device('cpu')
+        self.model.to(self._device)
+
+        # Initialize vqgan
+        vqgan_config = json.load(
+            open(os.path.join(model_dir, 'vqgan_config.json')))
+        self.vqgan_model = load_vqgan(
+            vqgan_config,
+            ckpt_path=os.path.join(model_dir, 'vqgan_model.ckpt'),
+            is_gumbel=True).to(self._device)
+        # Initialize generator
+        sampling = Sampling(self.tokenizer, sampling_topp=0.9)
+        sg_args = {
+            'tokenizer': self.tokenizer,
+            'beam_size': 1,
+            'max_len_b': 1024,
+            'min_len': 1024,
+            'search_strategy': sampling,
+            'gen_code': True,
+            'constraint_range': '50265,58457'
+        }
+        self.generator = sg.SequenceGenerator(**sg_args)
+
+    def forward(self, input: Dict[str, Any]):
+        input = move_to_device(input, self._device)
+        gen_output = self.generator.generate([self.model], input)
+        gen_tokens = gen_output[0][0]['tokens'][:-1]
+        codes = gen_tokens.view(1, 32, 32) - 50265
+        quant_b = self.vqgan_model.quantize.get_codebook_entry(
+            codes.view(-1),
+            list(codes.size()) + [self.vqgan_model.quantize.embedding_dim])
+        dec = self.vqgan_model.decode(quant_b)[0]
+        return custom_to_pil(dec)
diff --git a/modelscope/pipelines/multi_modal/text_to_image_synthesis_pipeline.py b/modelscope/pipelines/multi_modal/text_to_image_synthesis_pipeline.py
index 44625b0a..406538cf 100644
--- a/modelscope/pipelines/multi_modal/text_to_image_synthesis_pipeline.py
+++ b/modelscope/pipelines/multi_modal/text_to_image_synthesis_pipeline.py
@@ -1,11 +1,13 @@
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 import torch
 
 from modelscope.metainfo import Pipelines
+from modelscope.models.multi_modal import OfaForTextToImageSynthesis
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Model, Pipeline
 from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import OfaPreprocessor, Preprocessor
 from modelscope.utils.constant import Tasks
 from modelscope.utils.logger import get_logger
 
@@ -17,7 +19,10 @@ logger = get_logger()
     module_name=Pipelines.text_to_image_synthesis)
 class TextToImageSynthesisPipeline(Pipeline):
 
-    def __init__(self, model: str, **kwargs):
+    def __init__(self,
+                 model: str,
+                 preprocessor: Optional[Preprocessor] = None,
+                 **kwargs):
         """
         use `model` and `preprocessor` to create a kws pipeline for prediction
         Args:
@@ -31,13 +36,20 @@ class TextToImageSynthesisPipeline(Pipeline):
         else:
             raise NotImplementedError(
                 f'expecting a Model instance or str, but get {type(model)}.')
-
-        super().__init__(model=pipe_model, **kwargs)
-
-    def preprocess(self, input: Input) -> Dict[str, Any]:
-        return input
+        if preprocessor is None and isinstance(pipe_model,
+                                               OfaForTextToImageSynthesis):
+            preprocessor = OfaPreprocessor(pipe_model.model_dir)
+        super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)
+
+    def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]:
+        if self.preprocessor is not None:
+            return self.preprocessor(input, **preprocess_params)
+        else:
+            return input
 
     def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        if isinstance(self.model, OfaForTextToImageSynthesis):
+            return self.model(input)
         return self.model.generate(input)
 
     def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py
index 2f62c6af..f3bba772 100644
--- a/modelscope/preprocessors/multi_modal.py
+++ b/modelscope/preprocessors/multi_modal.py
@@ -23,6 +23,8 @@ __all__ = [
 
 @PREPROCESSORS.register_module(
     Fields.multi_modal, module_name=Preprocessors.ofa_image_caption)
+@PREPROCESSORS.register_module(
+    Fields.multi_modal, module_name=Preprocessors.ofa_text_to_image_synthesis)
 class OfaPreprocessor(Preprocessor):
 
     def __init__(self, model_dir: str, *args, **kwargs):
@@ -40,7 +42,8 @@ class OfaPreprocessor(Preprocessor):
             Tasks.visual_entailment: OfaVisualEntailmentPreprocessor,
             Tasks.image_classification: OfaImageClassificationPreprocessor,
             Tasks.text_classification: OfaTextClassificationPreprocessor,
-            Tasks.summarization: OfaSummarizationPreprocessor
+            Tasks.summarization: OfaSummarizationPreprocessor,
+            Tasks.text_to_image_synthesis: OfaTextToImageSynthesisPreprocessor
         }
         input_key_mapping = {
             Tasks.image_captioning: ['image'],
@@ -50,6 +53,7 @@
             Tasks.visual_grounding: ['image', 'text'],
             Tasks.visual_question_answering: ['image', 'text'],
             Tasks.visual_entailment: ['image', 'text', 'text2'],
+            Tasks.text_to_image_synthesis: ['text']
         }
         model_dir = model_dir if osp.exists(model_dir) else snapshot_download(
             model_dir)
diff --git a/modelscope/preprocessors/ofa/__init__.py b/modelscope/preprocessors/ofa/__init__.py
index 44954668..95d72fe1 100644
--- a/modelscope/preprocessors/ofa/__init__.py
+++ b/modelscope/preprocessors/ofa/__init__.py
@@ -3,6 +3,7 @@ from .image_captioning import OfaImageCaptioningPreprocessor
 from .image_classification import OfaImageClassificationPreprocessor
 from .summarization import OfaSummarizationPreprocessor
 from .text_classification import OfaTextClassificationPreprocessor
+from .text_to_image_synthesis import OfaTextToImageSynthesisPreprocessor
 from .visual_entailment import OfaVisualEntailmentPreprocessor
 from .visual_grounding import OfaVisualGroundingPreprocessor
 from .visual_question_answering import OfaVisualQuestionAnsweringPreprocessor
diff --git a/modelscope/preprocessors/ofa/text_to_image_synthesis.py b/modelscope/preprocessors/ofa/text_to_image_synthesis.py
new file mode 100644
index 00000000..9dbba921
--- /dev/null
+++ b/modelscope/preprocessors/ofa/text_to_image_synthesis.py
@@ -0,0 +1,31 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
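+# Turns a text prompt into OFA generator inputs for text-to-image synthesis.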
+from typing import Any, Dict
+
+import torch
+
+from .base import OfaBasePreprocessor
+
+
+class OfaTextToImageSynthesisPreprocessor(OfaBasePreprocessor):
+
+    def __init__(self, cfg, model_dir):
+        """preprocess the data via the vocab.txt from the `model_dir` path
+
+        Args:
+            model_dir (str): model path
+        """
+        super(OfaTextToImageSynthesisPreprocessor,
+              self).__init__(cfg, model_dir)
+        self.max_src_length = 64
+
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        source = ' '.join(
+            data['text'].lower().strip().split()[:self.max_src_length])
+        source = 'what is the complete image? caption: {}'.format(source)
+        inputs = self.get_inputs(source)
+        sample = {
+            'source': inputs,
+            'patch_images': None,
+            'patch_masks': torch.tensor([False]),
+            'code_masks': torch.tensor([False])
+        }
+        return sample
diff --git a/requirements/multi-modal.txt b/requirements/multi-modal.txt
index 5bc7abd5..ef5d4341 100644
--- a/requirements/multi-modal.txt
+++ b/requirements/multi-modal.txt
@@ -6,6 +6,7 @@ pycocotools>=2.0.4
 # rough-score was just recently updated from 0.0.4 to 0.0.7
 # which introduced compatability issues that are being investigated
 rouge_score<=0.0.4
+taming-transformers-rom1504
 timm
 tokenizers
 torchvision
diff --git a/tests/pipelines/test_ofa_tasks.py b/tests/pipelines/test_ofa_tasks.py
index 5cba86b1..a2b23e48 100644
--- a/tests/pipelines/test_ofa_tasks.py
+++ b/tests/pipelines/test_ofa_tasks.py
@@ -244,6 +244,25 @@ class OfaTasksTest(unittest.TestCase):
         result = ofa_pipe(input)
         print(result)
 
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_text_to_image_synthesis_with_name(self):
+        model = 'damo/ofa_text-to-image-synthesis_coco_large_en'
+        ofa_pipe = pipeline(Tasks.text_to_image_synthesis, model=model)
+        example = {'text': 'a bear in the water.'}
+        result = ofa_pipe(example)
+        result[OutputKeys.OUTPUT_IMG].save('result.png')
+        print(f'Output written to {osp.abspath("result.png")}')
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_text_to_image_synthesis_with_model(self):
+        model = Model.from_pretrained(
+            'damo/ofa_text-to-image-synthesis_coco_large_en')
+        ofa_pipe = pipeline(Tasks.text_to_image_synthesis, model=model)
+        example = {'text': 'a bear in the water.'}
+        result = ofa_pipe(example)
+        result[OutputKeys.OUTPUT_IMG].save('result.png')
+        print(f'Output written to {osp.abspath("result.png")}')
+
 
 if __name__ == '__main__':
     unittest.main()
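
Usage note (not part of the patch): a minimal sketch of how the new pipeline is expected to be called, based on the tests above. The model id comes from the test file; the output file name 'bear.png' is illustrative.

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Build the text-to-image pipeline; when no preprocessor is passed,
    # TextToImageSynthesisPipeline attaches an OfaPreprocessor automatically.
    t2i = pipeline(
        Tasks.text_to_image_synthesis,
        model='damo/ofa_text-to-image-synthesis_coco_large_en')

    # The result is a dict; OutputKeys.OUTPUT_IMG holds a PIL image decoded
    # from the generated VQGAN codes.
    result = t2i({'text': 'a bear in the water.'})
    result[OutputKeys.OUTPUT_IMG].save('bear.png')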