|
|
|
@@ -1,12 +1,13 @@ |
|
|
|
# Copyright (c) Alibaba, Inc. and its affiliates. |
|
|
|
import os.path as osp |
|
|
|
from typing import Any, Dict, Union |
|
|
|
from typing import Any, Dict, List, Union |
|
|
|
|
|
|
|
import torch |
|
|
|
from PIL import Image |
|
|
|
|
|
|
|
from modelscope.hub.snapshot_download import snapshot_download |
|
|
|
from modelscope.metainfo import Preprocessors |
|
|
|
from modelscope.pipelines.base import Input |
|
|
|
from modelscope.utils.config import Config |
|
|
|
from modelscope.utils.constant import Fields, ModelFile, Tasks |
|
|
|
from .base import Preprocessor |
|
|
|
@@ -41,13 +42,39 @@ class OfaPreprocessor(Preprocessor): |
|
|
|
Tasks.text_classification: OfaTextClassificationPreprocessor, |
|
|
|
Tasks.summarization: OfaSummarizationPreprocessor |
|
|
|
} |
|
|
|
input_key_mapping = { |
|
|
|
Tasks.image_captioning: ['image'], |
|
|
|
Tasks.image_classification: ['image'], |
|
|
|
Tasks.summarization: ['text'], |
|
|
|
Tasks.text_classification: ['text', 'text2'], |
|
|
|
Tasks.visual_grounding: ['image', 'text'], |
|
|
|
Tasks.visual_question_answering: ['image', 'text'], |
|
|
|
Tasks.visual_entailment: ['image', 'text', 'text2'], |
|
|
|
} |
|
|
|
model_dir = model_dir if osp.exists(model_dir) else snapshot_download( |
|
|
|
model_dir) |
|
|
|
cfg = Config.from_file(osp.join(model_dir, ModelFile.CONFIGURATION)) |
|
|
|
self.preprocess = preprocess_mapping[cfg.task](cfg, model_dir) |
|
|
|
self.cfg = Config.from_file( |
|
|
|
osp.join(model_dir, ModelFile.CONFIGURATION)) |
|
|
|
self.preprocess = preprocess_mapping[self.cfg.task](self.cfg, |
|
|
|
model_dir) |
|
|
|
self.keys = input_key_mapping[self.cfg.task] |
|
|
|
self.tokenizer = self.preprocess.tokenizer |
|
|
|
|
|
|
|
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: |
|
|
|
# just for modelscope demo |
|
|
|
def _build_dict(self, input: Union[Input, List[Input]]) -> Dict[str, Any]: |
|
|
|
data = dict() |
|
|
|
if not isinstance(input, tuple) and not isinstance(input, list): |
|
|
|
input = (input, ) |
|
|
|
for key, item in zip(self.keys, input): |
|
|
|
data[key] = item |
|
|
|
return data |
|
|
|
|
|
|
|
def __call__(self, input: Union[str, tuple, Dict[str, Any]], *args, |
|
|
|
**kwargs) -> Dict[str, Any]: |
|
|
|
if isinstance(input, dict): |
|
|
|
data = input |
|
|
|
else: |
|
|
|
data = self._build_dict(input) |
|
|
|
sample = self.preprocess(data) |
|
|
|
sample['sample'] = data |
|
|
|
return collate_fn([sample], |
|
|
|
|