Browse Source

[to #42322933] add init and make demo compatible link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9606656

add init and make demo compatible
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9606656

    * add init and make demo compatible

* make demo compatible

* fix comments

* add distilled ut

* Merge remote-tracking branch 'origin/master' into ofa/bug_fix
master
yichang.zyc 3 years ago
parent
commit
21de1e7db0
4 changed files with 67 additions and 7 deletions
  1. +0
    -0
      modelscope/models/multi_modal/ofa/utils/__init__.py
  2. +1
    -1
      modelscope/pipelines/cv/__init__.py
  3. +31
    -4
      modelscope/preprocessors/multi_modal.py
  4. +35
    -2
      tests/pipelines/test_ofa_tasks.py

+ 0
- 0
modelscope/models/multi_modal/ofa/utils/__init__.py View File


+ 1
- 1
modelscope/pipelines/cv/__init__.py View File

@@ -21,7 +21,7 @@ if TYPE_CHECKING:
from .image_matting_pipeline import ImageMattingPipeline from .image_matting_pipeline import ImageMattingPipeline
from .image_style_transfer_pipeline import ImageStyleTransferPipeline from .image_style_transfer_pipeline import ImageStyleTransferPipeline
from .image_super_resolution_pipeline import ImageSuperResolutionPipeline from .image_super_resolution_pipeline import ImageSuperResolutionPipeline
from .image_to_image_generation_pipeline import Image2ImageGenerationePipeline
from .image_to_image_generate_pipeline import Image2ImageGenerationePipeline
from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline
from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline
from .live_category_pipeline import LiveCategoryPipeline from .live_category_pipeline import LiveCategoryPipeline


+ 31
- 4
modelscope/preprocessors/multi_modal.py View File

@@ -1,12 +1,13 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp import os.path as osp
from typing import Any, Dict, Union
from typing import Any, Dict, List, Union


import torch import torch
from PIL import Image from PIL import Image


from modelscope.hub.snapshot_download import snapshot_download from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Preprocessors from modelscope.metainfo import Preprocessors
from modelscope.pipelines.base import Input
from modelscope.utils.config import Config from modelscope.utils.config import Config
from modelscope.utils.constant import Fields, ModelFile, Tasks from modelscope.utils.constant import Fields, ModelFile, Tasks
from .base import Preprocessor from .base import Preprocessor
@@ -41,13 +42,39 @@ class OfaPreprocessor(Preprocessor):
Tasks.text_classification: OfaTextClassificationPreprocessor, Tasks.text_classification: OfaTextClassificationPreprocessor,
Tasks.summarization: OfaSummarizationPreprocessor Tasks.summarization: OfaSummarizationPreprocessor
} }
input_key_mapping = {
Tasks.image_captioning: ['image'],
Tasks.image_classification: ['image'],
Tasks.summarization: ['text'],
Tasks.text_classification: ['text', 'text2'],
Tasks.visual_grounding: ['image', 'text'],
Tasks.visual_question_answering: ['image', 'text'],
Tasks.visual_entailment: ['image', 'text', 'text2'],
}
model_dir = model_dir if osp.exists(model_dir) else snapshot_download( model_dir = model_dir if osp.exists(model_dir) else snapshot_download(
model_dir) model_dir)
cfg = Config.from_file(osp.join(model_dir, ModelFile.CONFIGURATION))
self.preprocess = preprocess_mapping[cfg.task](cfg, model_dir)
self.cfg = Config.from_file(
osp.join(model_dir, ModelFile.CONFIGURATION))
self.preprocess = preprocess_mapping[self.cfg.task](self.cfg,
model_dir)
self.keys = input_key_mapping[self.cfg.task]
self.tokenizer = self.preprocess.tokenizer self.tokenizer = self.preprocess.tokenizer


def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
# just for modelscope demo
def _build_dict(self, input: Union[Input, List[Input]]) -> Dict[str, Any]:
data = dict()
if not isinstance(input, tuple) and not isinstance(input, list):
input = (input, )
for key, item in zip(self.keys, input):
data[key] = item
return data

def __call__(self, input: Union[str, tuple, Dict[str, Any]], *args,
**kwargs) -> Dict[str, Any]:
if isinstance(input, dict):
data = input
else:
data = self._build_dict(input)
sample = self.preprocess(data) sample = self.preprocess(data)
sample['sample'] = data sample['sample'] = data
return collate_fn([sample], return collate_fn([sample],


+ 35
- 2
tests/pipelines/test_ofa_tasks.py View File

@@ -12,8 +12,7 @@ class OfaTasksTest(unittest.TestCase):


@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_image_captioning_with_model(self): def test_run_with_image_captioning_with_model(self):
model = Model.from_pretrained(
'damo/ofa_image-caption_coco_distilled_en')
model = Model.from_pretrained('damo/ofa_image-caption_coco_large_en')
img_captioning = pipeline( img_captioning = pipeline(
task=Tasks.image_captioning, task=Tasks.image_captioning,
model=model, model=model,
@@ -174,6 +173,40 @@ class OfaTasksTest(unittest.TestCase):
result = ofa_pipe(input) result = ofa_pipe(input)
print(result) print(result)


@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_image_captioning_distilled_with_model(self):
model = Model.from_pretrained(
'damo/ofa_image-caption_coco_distilled_en')
img_captioning = pipeline(
task=Tasks.image_captioning,
model=model,
)
result = img_captioning(
{'image': 'data/test/images/image_captioning.png'})
print(result[OutputKeys.CAPTION])

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_visual_entailment_distilled_model_with_name(self):
ofa_pipe = pipeline(
Tasks.visual_entailment,
model='damo/ofa_visual-entailment_snli-ve_distilled_v2_en')
image = 'data/test/images/dogs.jpg'
text = 'there are two birds.'
input = {'image': image, 'text': text}
result = ofa_pipe(input)
print(result)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_visual_grounding_distilled_model_with_model(self):
model = Model.from_pretrained(
'damo/ofa_visual-grounding_refcoco_distilled_en')
ofa_pipe = pipeline(Tasks.visual_grounding, model=model)
image = 'data/test/images/visual_grounding.png'
text = 'a blue turtle-like pokemon with round head'
input = {'image': image, 'text': text}
result = ofa_pipe(input)
print(result)



if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

Loading…
Cancel
Save