diff --git a/modelscope/models/multi_modal/ofa/utils/__init__.py b/modelscope/models/multi_modal/ofa/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index e675fe81..38593a65 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -21,7 +21,7 @@ if TYPE_CHECKING: from .image_matting_pipeline import ImageMattingPipeline from .image_style_transfer_pipeline import ImageStyleTransferPipeline from .image_super_resolution_pipeline import ImageSuperResolutionPipeline - from .image_to_image_generation_pipeline import Image2ImageGenerationePipeline + from .image_to_image_generate_pipeline import Image2ImageGenerationePipeline from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline from .live_category_pipeline import LiveCategoryPipeline diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py index 055c4efb..a3411a73 100644 --- a/modelscope/preprocessors/multi_modal.py +++ b/modelscope/preprocessors/multi_modal.py @@ -1,12 +1,13 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os.path as osp -from typing import Any, Dict, Union +from typing import Any, Dict, List, Union import torch from PIL import Image from modelscope.hub.snapshot_download import snapshot_download from modelscope.metainfo import Preprocessors +from modelscope.pipelines.base import Input from modelscope.utils.config import Config from modelscope.utils.constant import Fields, ModelFile, Tasks from .base import Preprocessor @@ -41,13 +42,39 @@ class OfaPreprocessor(Preprocessor): Tasks.text_classification: OfaTextClassificationPreprocessor, Tasks.summarization: OfaSummarizationPreprocessor } + input_key_mapping = { + Tasks.image_captioning: ['image'], + Tasks.image_classification: ['image'], + Tasks.summarization: ['text'], + Tasks.text_classification: ['text', 'text2'], + Tasks.visual_grounding: ['image', 'text'], + Tasks.visual_question_answering: ['image', 'text'], + Tasks.visual_entailment: ['image', 'text', 'text2'], + } model_dir = model_dir if osp.exists(model_dir) else snapshot_download( model_dir) - cfg = Config.from_file(osp.join(model_dir, ModelFile.CONFIGURATION)) - self.preprocess = preprocess_mapping[cfg.task](cfg, model_dir) + self.cfg = Config.from_file( + osp.join(model_dir, ModelFile.CONFIGURATION)) + self.preprocess = preprocess_mapping[self.cfg.task](self.cfg, + model_dir) + self.keys = input_key_mapping[self.cfg.task] self.tokenizer = self.preprocess.tokenizer - def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + # just for modelscope demo + def _build_dict(self, input: Union[Input, List[Input]]) -> Dict[str, Any]: + data = dict() + if not isinstance(input, tuple) and not isinstance(input, list): + input = (input, ) + for key, item in zip(self.keys, input): + data[key] = item + return data + + def __call__(self, input: Union[str, tuple, Dict[str, Any]], *args, + **kwargs) -> Dict[str, Any]: + if isinstance(input, dict): + data = input + else: + data = self._build_dict(input) sample = self.preprocess(data) sample['sample'] = data return collate_fn([sample], diff --git a/tests/pipelines/test_ofa_tasks.py b/tests/pipelines/test_ofa_tasks.py index 2c494e40..63efa334 100644 --- a/tests/pipelines/test_ofa_tasks.py +++ b/tests/pipelines/test_ofa_tasks.py @@ -12,8 +12,7 @@ class OfaTasksTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_image_captioning_with_model(self): - model = Model.from_pretrained( - 'damo/ofa_image-caption_coco_distilled_en') + model = Model.from_pretrained('damo/ofa_image-caption_coco_large_en') img_captioning = pipeline( task=Tasks.image_captioning, model=model, @@ -174,6 +173,40 @@ class OfaTasksTest(unittest.TestCase): result = ofa_pipe(input) print(result) + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_image_captioning_distilled_with_model(self): + model = Model.from_pretrained( + 'damo/ofa_image-caption_coco_distilled_en') + img_captioning = pipeline( + task=Tasks.image_captioning, + model=model, + ) + result = img_captioning( + {'image': 'data/test/images/image_captioning.png'}) + print(result[OutputKeys.CAPTION]) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_visual_entailment_distilled_model_with_name(self): + ofa_pipe = pipeline( + Tasks.visual_entailment, + model='damo/ofa_visual-entailment_snli-ve_distilled_v2_en') + image = 'data/test/images/dogs.jpg' + text = 'there are two birds.' + input = {'image': image, 'text': text} + result = ofa_pipe(input) + print(result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_visual_grounding_distilled_model_with_model(self): + model = Model.from_pretrained( + 'damo/ofa_visual-grounding_refcoco_distilled_en') + ofa_pipe = pipeline(Tasks.visual_grounding, model=model) + image = 'data/test/images/visual_grounding.png' + text = 'a blue turtle-like pokemon with round head' + input = {'image': image, 'text': text} + result = ofa_pipe(input) + print(result) + if __name__ == '__main__': unittest.main()