add init and make demo compatible
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9606656
* add init and make demo compatible
* make demo compatible
* fix comments
* add distilled ut
* Merge remote-tracking branch 'origin/master' into ofa/bug_fix
@@ -21,7 +21,7 @@ if TYPE_CHECKING:
     from .image_matting_pipeline import ImageMattingPipeline
     from .image_style_transfer_pipeline import ImageStyleTransferPipeline
     from .image_super_resolution_pipeline import ImageSuperResolutionPipeline
-    from .image_to_image_generation_pipeline import Image2ImageGenerationePipeline
+    from .image_to_image_generate_pipeline import Image2ImageGenerationePipeline
     from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline
     from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline
     from .live_category_pipeline import LiveCategoryPipeline
@@ -1,12 +1,13 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os.path as osp
-from typing import Any, Dict, Union
+from typing import Any, Dict, List, Union
 
 import torch
 from PIL import Image
 
 from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.metainfo import Preprocessors
+from modelscope.pipelines.base import Input
 from modelscope.utils.config import Config
 from modelscope.utils.constant import Fields, ModelFile, Tasks
 from .base import Preprocessor
@@ -41,13 +42,39 @@ class OfaPreprocessor(Preprocessor):
             Tasks.text_classification: OfaTextClassificationPreprocessor,
             Tasks.summarization: OfaSummarizationPreprocessor
         }
+        input_key_mapping = {
+            Tasks.image_captioning: ['image'],
+            Tasks.image_classification: ['image'],
+            Tasks.summarization: ['text'],
+            Tasks.text_classification: ['text', 'text2'],
+            Tasks.visual_grounding: ['image', 'text'],
+            Tasks.visual_question_answering: ['image', 'text'],
+            Tasks.visual_entailment: ['image', 'text', 'text2'],
+        }
         model_dir = model_dir if osp.exists(model_dir) else snapshot_download(
             model_dir)
-        cfg = Config.from_file(osp.join(model_dir, ModelFile.CONFIGURATION))
-        self.preprocess = preprocess_mapping[cfg.task](cfg, model_dir)
+        self.cfg = Config.from_file(
+            osp.join(model_dir, ModelFile.CONFIGURATION))
+        self.preprocess = preprocess_mapping[self.cfg.task](self.cfg,
+                                                            model_dir)
+        self.keys = input_key_mapping[self.cfg.task]
         self.tokenizer = self.preprocess.tokenizer
 
-    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+    # just for modelscope demo
+    def _build_dict(self, input: Union[Input, List[Input]]) -> Dict[str, Any]:
+        data = dict()
+        if not isinstance(input, tuple) and not isinstance(input, list):
+            input = (input, )
+        for key, item in zip(self.keys, input):
+            data[key] = item
+        return data
+
+    def __call__(self, input: Union[str, tuple, Dict[str, Any]], *args,
+                 **kwargs) -> Dict[str, Any]:
+        if isinstance(input, dict):
+            data = input
+        else:
+            data = self._build_dict(input)
         sample = self.preprocess(data)
         sample['sample'] = data
         return collate_fn([sample],
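
For reviewers skimming the hunk: the demo-compatibility change boils down to _build_dict, which takes the bare values the ModelScope demo page sends, wraps a single value into a tuple, and zips the values against the task's entry in input_key_mapping to rebuild the dict that the task-specific OFA preprocessor expects; dict inputs pass through __call__ untouched. A self-contained sketch of that zipping behavior, with the visual-entailment key list copied from the mapping above and hypothetical file names (plain Python, no ModelScope imports):

from typing import Any, Dict, List, Tuple, Union

# Key order for Tasks.visual_entailment, copied from input_key_mapping above.
KEYS: List[str] = ['image', 'text', 'text2']

def build_dict(inp: Union[Any, Tuple[Any, ...], List[Any]]) -> Dict[str, Any]:
    # Mirrors OfaPreprocessor._build_dict: a single value becomes a 1-tuple,
    # then each positional value is paired with its task key in order.
    if not isinstance(inp, (tuple, list)):
        inp = (inp, )
    return {key: item for key, item in zip(KEYS, inp)}

# Demo-style positional input ...
print(build_dict(('dogs.jpg', 'there are two birds.')))
# ... yields the same dict the unit tests pass explicitly:
# {'image': 'dogs.jpg', 'text': 'there are two birds.'}
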
@@ -12,8 +12,7 @@ class OfaTasksTest(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_image_captioning_with_model(self):
-        model = Model.from_pretrained(
-            'damo/ofa_image-caption_coco_distilled_en')
+        model = Model.from_pretrained('damo/ofa_image-caption_coco_large_en')
         img_captioning = pipeline(
             task=Tasks.image_captioning,
             model=model,
         )
@@ -174,6 +173,40 @@ class OfaTasksTest(unittest.TestCase):
         result = ofa_pipe(input)
         print(result)
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_image_captioning_distilled_with_model(self):
+        model = Model.from_pretrained(
+            'damo/ofa_image-caption_coco_distilled_en')
+        img_captioning = pipeline(
+            task=Tasks.image_captioning,
+            model=model,
+        )
+        result = img_captioning(
+            {'image': 'data/test/images/image_captioning.png'})
+        print(result[OutputKeys.CAPTION])
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_visual_entailment_distilled_model_with_name(self):
+        ofa_pipe = pipeline(
+            Tasks.visual_entailment,
+            model='damo/ofa_visual-entailment_snli-ve_distilled_v2_en')
+        image = 'data/test/images/dogs.jpg'
+        text = 'there are two birds.'
+        input = {'image': image, 'text': text}
+        result = ofa_pipe(input)
+        print(result)
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_visual_grounding_distilled_model_with_model(self):
+        model = Model.from_pretrained(
+            'damo/ofa_visual-grounding_refcoco_distilled_en')
+        ofa_pipe = pipeline(Tasks.visual_grounding, model=model)
+        image = 'data/test/images/visual_grounding.png'
+        text = 'a blue turtle-like pokemon with round head'
+        input = {'image': image, 'text': text}
+        result = ofa_pipe(input)
+        print(result)
 
 
 if __name__ == '__main__':
     unittest.main()
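
The new tests still feed dict inputs; the demo path that motivated this PR passes bare values instead. A minimal sketch of that call style, assuming the pipeline hands a non-dict input straight through to OfaPreprocessor (the model name is reused from the tests above):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Single-key task: input_key_mapping maps image_captioning to ['image'],
# so a bare image path is enough and _build_dict wraps it into {'image': ...}.
img_captioning = pipeline(
    task=Tasks.image_captioning,
    model='damo/ofa_image-caption_coco_distilled_en')
result = img_captioning('data/test/images/image_captioning.png')
print(result)
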