yichang.zyc (authored), yingda.chen (committed) 3 years ago
commit 32c2adb650
12 changed files with 40 additions and 32 deletions
 1. modelscope/metainfo.py (+2, -3)
 2. modelscope/models/multi_modal/__init__.py (+1, -0)
 3. modelscope/models/multi_modal/ofa_for_all_tasks.py (+8, -0)
 4. modelscope/pipelines/cv/image_classification_pipeline.py (+2, -1)
 5. modelscope/pipelines/multi_modal/image_captioning_pipeline.py (+2, -1)
 6. modelscope/pipelines/multi_modal/visual_entailment_pipeline.py (+2, -1)
 7. modelscope/pipelines/multi_modal/visual_grounding_pipeline.py (+2, -1)
 8. modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py (+9, -4)
 9. modelscope/pipelines/nlp/summarization_pipeline.py (+2, -1)
10. modelscope/pipelines/nlp/text_classification_pipeline.py (+2, -1)
11. modelscope/preprocessors/multi_modal.py (+1, -3)
12. tests/pipelines/test_ofa_tasks.py (+7, -16)

modelscope/metainfo.py (+2, -3)

@@ -201,9 +201,8 @@ class Preprocessors(object):
     wav_to_lists = 'wav-to-lists'
     wav_to_scp = 'wav-to-scp'

-    # multi-modal
-    ofa_image_caption = 'ofa-image-caption'
-    ofa_text_to_image_synthesis = 'ofa-text-to-image-synthesis'
+    # multi-modal preprocessor
+    ofa_tasks_preprocessor = 'ofa-tasks-preprocessor'
     mplug_visual_question_answering = 'mplug-visual-question-answering'

modelscope/models/multi_modal/__init__.py (+1, -0)

@@ -11,6 +11,7 @@ if TYPE_CHECKING:
     from .mmr import VideoCLIPForMultiModalEmbedding
     from .mplug_for_visual_question_answering import \
         MPlugForVisualQuestionAnswering
+    from .ofa_for_all_tasks import OfaForAllTasks

 else:
     _import_structure = {

modelscope/models/multi_modal/ofa_for_all_tasks.py (+8, -0)

@@ -1,5 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import math
+import string
 from os import path as osp
 from typing import Any, Dict

@@ -58,6 +59,9 @@ class OfaForAllTasks(TorchModel):
         self.max_image_size = self.cfg.model.get('max_image_size', 512)
         self.val_batch_size = self.cfg.model.get('valid_batch_size',
                                                  self.batch_size)
+        self.transtab = str.maketrans(
+            {key: None
+             for key in string.punctuation})
         self.gen_type = self.cfg.model.get('gen_type', 'generation')
         assert self.gen_type in ['generation', 'traverse'], \
             'model.gen_type must be in ["generation", "traverse"]'

@@ -116,6 +120,10 @@ class OfaForAllTasks(TorchModel):

     def postprocess(self, input: Dict[str, Tensor],
                     **kwargs) -> Dict[str, Tensor]:
+        if self.cfg.task == Tasks.image_captioning:
+            caption = input[OutputKeys.CAPTION]
+            caption = caption.translate(self.transtab).strip()
+            input[OutputKeys.CAPTION] = caption
         return input

     def _text_gen_inference(self, input):
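
Note: `str.maketrans` with a dict that maps single characters to None builds a translation table that deletes those characters, so the new postprocess step strips ASCII punctuation from generated captions. A small self-contained illustration of what the table does (the sample caption is invented):

    import string

    # Map every character in string.punctuation to None, i.e. delete it.
    transtab = str.maketrans({key: None for key in string.punctuation})

    caption = 'a dog, sitting on a bench.'
    print(caption.translate(transtab).strip())
    # prints: a dog sitting on a bench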


modelscope/pipelines/cv/image_classification_pipeline.py (+2, -1)

@@ -7,6 +7,7 @@ import PIL
 import torch

 from modelscope.metainfo import Pipelines
+from modelscope.models.multi_modal import OfaForAllTasks
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Model, Pipeline
 from modelscope.pipelines.builder import PIPELINES

@@ -35,7 +36,7 @@ class ImageClassificationPipeline(Pipeline):
         else:
             raise NotImplementedError
         pipe_model.model.eval()
-        if preprocessor is None and pipe_model:
+        if preprocessor is None and isinstance(pipe_model, OfaForAllTasks):
             preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir)
         super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)
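
Note: the old guard `if preprocessor is None and pipe_model:` is truthy for any successfully loaded model, so non-OFA models could silently be handed an OfaPreprocessor by default. The isinstance check limits the default to OFA models; the same one-line fix recurs in the captioning, entailment, grounding, summarization, and text-classification pipelines below. A standalone sketch of the pattern, using a hypothetical helper name (the imports are the real ones added by this commit):

    from modelscope.models.multi_modal import OfaForAllTasks
    from modelscope.preprocessors import OfaPreprocessor

    def default_preprocessor_for(pipe_model):
        # Hypothetical helper: only OFA models get a default preprocessor
        # built from their own model_dir; anything else stays None, so the
        # caller must pass a preprocessor explicitly.
        if isinstance(pipe_model, OfaForAllTasks):
            return OfaPreprocessor(model_dir=pipe_model.model_dir)
        return None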




modelscope/pipelines/multi_modal/image_captioning_pipeline.py (+2, -1)

@@ -2,6 +2,7 @@
 from typing import Any, Dict, Optional, Union

 from modelscope.metainfo import Pipelines
+from modelscope.models.multi_modal import OfaForAllTasks
 from modelscope.pipelines.base import Model, Pipeline
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.preprocessors import OfaPreprocessor, Preprocessor

@@ -34,7 +35,7 @@ class ImageCaptioningPipeline(Pipeline):
         else:
             raise NotImplementedError
         pipe_model.model.eval()
-        if preprocessor is None and pipe_model:
+        if preprocessor is None and isinstance(pipe_model, OfaForAllTasks):
             preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir)
         super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)




modelscope/pipelines/multi_modal/visual_entailment_pipeline.py (+2, -1)

@@ -2,6 +2,7 @@
 from typing import Any, Dict, Union

 from modelscope.metainfo import Pipelines
+from modelscope.models.multi_modal import OfaForAllTasks
 from modelscope.pipelines.base import Model, Pipeline
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.preprocessors import OfaPreprocessor, Preprocessor

@@ -34,7 +35,7 @@ class VisualEntailmentPipeline(Pipeline):
         else:
             raise NotImplementedError
         pipe_model.model.eval()
-        if preprocessor is None and pipe_model:
+        if preprocessor is None and isinstance(pipe_model, OfaForAllTasks):
             preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir)
         super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)




modelscope/pipelines/multi_modal/visual_grounding_pipeline.py (+2, -1)

@@ -2,6 +2,7 @@
 from typing import Any, Dict, Union

 from modelscope.metainfo import Pipelines
+from modelscope.models.multi_modal import OfaForAllTasks
 from modelscope.pipelines.base import Model, Pipeline
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.preprocessors import OfaPreprocessor, Preprocessor

@@ -34,7 +35,7 @@ class VisualGroundingPipeline(Pipeline):
         else:
             raise NotImplementedError
         pipe_model.model.eval()
-        if preprocessor is None and pipe_model:
+        if preprocessor is None and isinstance(pipe_model, OfaForAllTasks):
             preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir)
         super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)




modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py (+9, -4)

@@ -5,11 +5,13 @@ import torch

 from modelscope.metainfo import Pipelines
 from modelscope.models import Model
-from modelscope.models.multi_modal import MPlugForVisualQuestionAnswering
+from modelscope.models.multi_modal import (MPlugForVisualQuestionAnswering,
+                                           OfaForAllTasks)
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Pipeline, Tensor
 from modelscope.pipelines.builder import PIPELINES
-from modelscope.preprocessors import MPlugVisualQuestionAnsweringPreprocessor
+from modelscope.preprocessors import (MPlugVisualQuestionAnsweringPreprocessor,
+                                      OfaPreprocessor)
 from modelscope.utils.constant import Tasks

 __all__ = ['VisualQuestionAnsweringPipeline']

@@ -35,8 +37,11 @@ class VisualQuestionAnsweringPipeline(Pipeline):
             Model) else Model.from_pretrained(model)
         self.tokenizer = None
         if preprocessor is None:
-            preprocessor = MPlugVisualQuestionAnsweringPreprocessor(
-                model.model_dir)
+            if isinstance(model, OfaForAllTasks):
+                preprocessor = OfaPreprocessor(model.model_dir)
+            elif isinstance(model, MPlugForVisualQuestionAnswering):
+                preprocessor = MPlugVisualQuestionAnsweringPreprocessor(
+                    model.model_dir)
         if isinstance(model, MPlugForVisualQuestionAnswering):
             model.eval()
             self.tokenizer = model.tokenizer
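
Note: previously the MPlug preprocessor was built whenever no preprocessor was supplied, which could not work for OFA checkpoints routed through this task. The default is now dispatched on the model class; condensed here into a hypothetical helper for clarity (the imports are the ones added above):

    from modelscope.models.multi_modal import (MPlugForVisualQuestionAnswering,
                                               OfaForAllTasks)
    from modelscope.preprocessors import (
        MPlugVisualQuestionAnsweringPreprocessor, OfaPreprocessor)

    def vqa_default_preprocessor(model):
        # Hypothetical helper mirroring the new branching logic.
        if isinstance(model, OfaForAllTasks):
            return OfaPreprocessor(model.model_dir)
        if isinstance(model, MPlugForVisualQuestionAnswering):
            return MPlugVisualQuestionAnsweringPreprocessor(model.model_dir)
        return None  # neither family: a preprocessor must be supplied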


modelscope/pipelines/nlp/summarization_pipeline.py (+2, -1)

@@ -2,6 +2,7 @@
 from typing import Any, Dict, Union

 from modelscope.metainfo import Pipelines
+from modelscope.models.multi_modal import OfaForAllTasks
 from modelscope.pipelines.base import Model, Pipeline
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.preprocessors import OfaPreprocessor, Preprocessor

@@ -34,7 +35,7 @@ class SummarizationPipeline(Pipeline):
         else:
             raise NotImplementedError
         pipe_model.model.eval()
-        if preprocessor is None and pipe_model:
+        if preprocessor is None and isinstance(pipe_model, OfaForAllTasks):
             preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir)
         super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)




modelscope/pipelines/nlp/text_classification_pipeline.py (+2, -1)

@@ -2,6 +2,7 @@
 from typing import Any, Dict, Union

 from modelscope.metainfo import Pipelines
+from modelscope.models.multi_modal import OfaForAllTasks
 from modelscope.pipelines.base import Model, Pipeline
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.preprocessors import OfaPreprocessor, Preprocessor

@@ -34,7 +35,7 @@ class TextClassificationPipeline(Pipeline):
         else:
             raise NotImplementedError
         pipe_model.model.eval()
-        if preprocessor is None and pipe_model:
+        if preprocessor is None and isinstance(pipe_model, OfaForAllTasks):
             preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir)
         super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)




modelscope/preprocessors/multi_modal.py (+1, -3)

@@ -22,9 +22,7 @@ __all__ = [

 @PREPROCESSORS.register_module(
-    Fields.multi_modal, module_name=Preprocessors.ofa_image_caption)
-@PREPROCESSORS.register_module(
-    Fields.multi_modal, module_name=Preprocessors.ofa_text_to_image_synthesis)
+    Fields.multi_modal, module_name=Preprocessors.ofa_tasks_preprocessor)
 class OfaPreprocessor(Preprocessor):

     def __init__(self, model_dir: str, *args, **kwargs):
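
Note: together with the metainfo.py change above, the two task-specific registrations ('ofa-image-caption' and 'ofa-text-to-image-synthesis') collapse into one entry under 'ofa-tasks-preprocessor', reflecting that OfaPreprocessor now serves every OFA task. A minimal sketch of how a decorator registry of this shape resolves names to classes (a generic illustration, not ModelScope's actual Registry implementation):

    class Registry:
        def __init__(self):
            self._modules = {}  # (field, module_name) -> class

        def register_module(self, field, module_name):
            def decorator(cls):
                self._modules[(field, module_name)] = cls
                return cls
            return decorator

        def build(self, field, module_name, **kwargs):
            return self._modules[(field, module_name)](**kwargs)

    PREPROCESSORS = Registry()

    @PREPROCESSORS.register_module('multi-modal', 'ofa-tasks-preprocessor')
    class OfaPreprocessor:
        def __init__(self, model_dir):
            self.model_dir = model_dir

    # Configs naming 'ofa-tasks-preprocessor' now all resolve to the same
    # class, regardless of which OFA task is being run.
    p = PREPROCESSORS.build('multi-modal', 'ofa-tasks-preprocessor',
                            model_dir='/path/to/model')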


tests/pipelines/test_ofa_tasks.py (+7, -16)

@@ -35,15 +35,15 @@ class OfaTasksTest(unittest.TestCase):
             task=Tasks.image_captioning,
             model=model,
         )
-        result = img_captioning(
-            {'image': 'data/test/images/image_captioning.png'})
+        image = 'data/test/images/image_captioning.png'
+        result = img_captioning({'image': image})
         print(result[OutputKeys.CAPTION])

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_image_captioning_with_name(self):
         img_captioning = pipeline(
             Tasks.image_captioning,
-            model='damo/ofa_image-caption_coco_distilled_en')
+            model='damo/ofa_image-caption_coco_large_en')
         result = img_captioning(
             {'image': 'data/test/images/image_captioning.png'})
         print(result[OutputKeys.CAPTION])

@@ -181,14 +181,9 @@ class OfaTasksTest(unittest.TestCase):

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_visual_question_answering_with_model(self):
-        from modelscope.preprocessors.multi_modal import OfaPreprocessor
         model = Model.from_pretrained(
             'damo/ofa_visual-question-answering_pretrain_large_en')
-        preprocessor = OfaPreprocessor(model_dir=model.model_dir)
-        ofa_pipe = pipeline(
-            Tasks.visual_question_answering,
-            model=model,
-            preprocessor=preprocessor)
+        ofa_pipe = pipeline(Tasks.visual_question_answering, model=model)
         image = 'data/test/images/visual_question_answering.png'
         text = 'what is grown on the plant?'
         input = {'image': image, 'text': text}

@@ -197,13 +192,8 @@ class OfaTasksTest(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_visual_question_answering_with_name(self):
-        from modelscope.preprocessors.multi_modal import OfaPreprocessor
         model = 'damo/ofa_visual-question-answering_pretrain_large_en'
-        preprocessor = OfaPreprocessor(model_dir=model)
-        ofa_pipe = pipeline(
-            Tasks.visual_question_answering,
-            model=model,
-            preprocessor=preprocessor)
+        ofa_pipe = pipeline(Tasks.visual_question_answering, model=model)
         image = 'data/test/images/visual_question_answering.png'
         text = 'what is grown on the plant?'
         input = {'image': image, 'text': text}

@@ -218,7 +208,8 @@ class OfaTasksTest(unittest.TestCase):
             task=Tasks.image_captioning,
             model=model,
         )
-        image = Image.open('data/test/images/image_captioning.png')
+        image_path = 'data/test/images/image_captioning.png'
+        image = Image.open(image_path)
         result = img_captioning(image)
         print(result[OutputKeys.CAPTION])
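
Note: the net effect on callers is visible in these tests: with OfaForAllTasks exported from modelscope.models.multi_modal and each pipeline building its own OfaPreprocessor, the manual preprocessor wiring disappears. A minimal usage sketch, assuming the model ID and test image from the tests above are available locally:

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # No OfaPreprocessor is built by hand; the pipeline detects the OFA model
    # and constructs one itself. The caption also comes back with punctuation
    # stripped, courtesy of the new OfaForAllTasks.postprocess step.
    img_captioning = pipeline(
        Tasks.image_captioning, model='damo/ofa_image-caption_coco_large_en')
    result = img_captioning({'image': 'data/test/images/image_captioning.png'})
    print(result[OutputKeys.CAPTION])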



