Browse Source

[to #42322933] add init and make demo compatible link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9606656

add init and make demo compatible
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9606656

    * add init and make demo compatible

* make demo compatible

* fix comments

* add distilled ut

* Merge remote-tracking branch 'origin/master' into ofa/bug_fix
master
yichang.zyc 3 years ago
parent
commit
21de1e7db0
4 changed files with 67 additions and 7 deletions
  1. +0
    -0
      modelscope/models/multi_modal/ofa/utils/__init__.py
  2. +1
    -1
      modelscope/pipelines/cv/__init__.py
  3. +31
    -4
      modelscope/preprocessors/multi_modal.py
  4. +35
    -2
      tests/pipelines/test_ofa_tasks.py

+ 0
- 0
modelscope/models/multi_modal/ofa/utils/__init__.py View File


+ 1
- 1
modelscope/pipelines/cv/__init__.py View File

@@ -21,7 +21,7 @@ if TYPE_CHECKING:
from .image_matting_pipeline import ImageMattingPipeline
from .image_style_transfer_pipeline import ImageStyleTransferPipeline
from .image_super_resolution_pipeline import ImageSuperResolutionPipeline
from .image_to_image_generation_pipeline import Image2ImageGenerationePipeline
from .image_to_image_generate_pipeline import Image2ImageGenerationePipeline
from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline
from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline
from .live_category_pipeline import LiveCategoryPipeline


+ 31
- 4
modelscope/preprocessors/multi_modal.py View File

@@ -1,12 +1,13 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from typing import Any, Dict, Union
from typing import Any, Dict, List, Union

import torch
from PIL import Image

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Preprocessors
from modelscope.pipelines.base import Input
from modelscope.utils.config import Config
from modelscope.utils.constant import Fields, ModelFile, Tasks
from .base import Preprocessor
@@ -41,13 +42,39 @@ class OfaPreprocessor(Preprocessor):
Tasks.text_classification: OfaTextClassificationPreprocessor,
Tasks.summarization: OfaSummarizationPreprocessor
}
input_key_mapping = {
Tasks.image_captioning: ['image'],
Tasks.image_classification: ['image'],
Tasks.summarization: ['text'],
Tasks.text_classification: ['text', 'text2'],
Tasks.visual_grounding: ['image', 'text'],
Tasks.visual_question_answering: ['image', 'text'],
Tasks.visual_entailment: ['image', 'text', 'text2'],
}
model_dir = model_dir if osp.exists(model_dir) else snapshot_download(
model_dir)
cfg = Config.from_file(osp.join(model_dir, ModelFile.CONFIGURATION))
self.preprocess = preprocess_mapping[cfg.task](cfg, model_dir)
self.cfg = Config.from_file(
osp.join(model_dir, ModelFile.CONFIGURATION))
self.preprocess = preprocess_mapping[self.cfg.task](self.cfg,
model_dir)
self.keys = input_key_mapping[self.cfg.task]
self.tokenizer = self.preprocess.tokenizer

def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
# just for modelscope demo
def _build_dict(self, input: Union[Input, List[Input]]) -> Dict[str, Any]:
data = dict()
if not isinstance(input, tuple) and not isinstance(input, list):
input = (input, )
for key, item in zip(self.keys, input):
data[key] = item
return data

def __call__(self, input: Union[str, tuple, Dict[str, Any]], *args,
**kwargs) -> Dict[str, Any]:
if isinstance(input, dict):
data = input
else:
data = self._build_dict(input)
sample = self.preprocess(data)
sample['sample'] = data
return collate_fn([sample],


+ 35
- 2
tests/pipelines/test_ofa_tasks.py View File

@@ -12,8 +12,7 @@ class OfaTasksTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_image_captioning_with_model(self):
model = Model.from_pretrained(
'damo/ofa_image-caption_coco_distilled_en')
model = Model.from_pretrained('damo/ofa_image-caption_coco_large_en')
img_captioning = pipeline(
task=Tasks.image_captioning,
model=model,
@@ -174,6 +173,40 @@ class OfaTasksTest(unittest.TestCase):
result = ofa_pipe(input)
print(result)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_image_captioning_distilled_with_model(self):
model = Model.from_pretrained(
'damo/ofa_image-caption_coco_distilled_en')
img_captioning = pipeline(
task=Tasks.image_captioning,
model=model,
)
result = img_captioning(
{'image': 'data/test/images/image_captioning.png'})
print(result[OutputKeys.CAPTION])

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_visual_entailment_distilled_model_with_name(self):
ofa_pipe = pipeline(
Tasks.visual_entailment,
model='damo/ofa_visual-entailment_snli-ve_distilled_v2_en')
image = 'data/test/images/dogs.jpg'
text = 'there are two birds.'
input = {'image': image, 'text': text}
result = ofa_pipe(input)
print(result)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_visual_grounding_distilled_model_with_model(self):
model = Model.from_pretrained(
'damo/ofa_visual-grounding_refcoco_distilled_en')
ofa_pipe = pipeline(Tasks.visual_grounding, model=model)
image = 'data/test/images/visual_grounding.png'
text = 'a blue turtle-like pokemon with round head'
input = {'image': image, 'text': text}
result = ofa_pipe(input)
print(result)


if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save