* format pipeline output and check it * fix UT * add docstr to clarify the difference between model.postprocess and pipeline.postprocess Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9051405 master
| @@ -6,7 +6,8 @@ DOCKER_FULL_NAME = $(DOCKER_REGISTRY)/$(DOCKER_ORG)/$(DOCKER_IMAGE) | |||||
| # CUDA_VERSION = 11.3 | # CUDA_VERSION = 11.3 | ||||
| # CUDNN_VERSION = 8 | # CUDNN_VERSION = 8 | ||||
| BASE_RUNTIME = reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 | BASE_RUNTIME = reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 | ||||
| BASE_DEVEL = reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 | |||||
| # BASE_DEVEL = reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 | |||||
| BASE_DEVEL = pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel | |||||
| MODELSCOPE_VERSION = $(shell git describe --tags --always) | MODELSCOPE_VERSION = $(shell git describe --tags --always) | ||||
| @@ -8,13 +8,29 @@ | |||||
| # For reference: | # For reference: | ||||
| # https://docs.docker.com/develop/develop-images/build_enhancements/ | # https://docs.docker.com/develop/develop-images/build_enhancements/ | ||||
| #ARG BASE_IMAGE=reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 | |||||
| #FROM ${BASE_IMAGE} as dev-base | |||||
| # ARG BASE_IMAGE=reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 | |||||
| # FROM ${BASE_IMAGE} as dev-base | |||||
| FROM reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 as dev-base | |||||
| # FROM reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 as dev-base | |||||
| FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel | |||||
| # FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-runtime | |||||
| # config pip source | # config pip source | ||||
| RUN mkdir /root/.pip | RUN mkdir /root/.pip | ||||
| COPY docker/rcfiles/pip.conf.tsinghua /root/.pip/pip.conf | COPY docker/rcfiles/pip.conf.tsinghua /root/.pip/pip.conf | ||||
| COPY docker/rcfiles/sources.list.aliyun /etc/apt/sources.list | |||||
| # Install essential Ubuntu packages | |||||
| RUN apt-get update &&\ | |||||
| apt-get install -y software-properties-common \ | |||||
| build-essential \ | |||||
| git \ | |||||
| wget \ | |||||
| vim \ | |||||
| curl \ | |||||
| zip \ | |||||
| zlib1g-dev \ | |||||
| unzip \ | |||||
| pkg-config | |||||
| # install modelscope and its python env | # install modelscope and its python env | ||||
| WORKDIR /opt/modelscope | WORKDIR /opt/modelscope | ||||
| @@ -20,16 +20,24 @@ class Model(ABC): | |||||
| self.model_dir = model_dir | self.model_dir = model_dir | ||||
| def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | ||||
| return self.post_process(self.forward(input)) | |||||
| return self.postprocess(self.forward(input)) | |||||
| @abstractmethod | @abstractmethod | ||||
| def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | ||||
| pass | pass | ||||
| def post_process(self, input: Dict[str, Tensor], | |||||
| **kwargs) -> Dict[str, Tensor]: | |||||
| # model specific postprocess, implementation is optional | |||||
| # will be called in Pipeline and evaluation loop(in the future) | |||||
| def postprocess(self, input: Dict[str, Tensor], | |||||
| **kwargs) -> Dict[str, Tensor]: | |||||
| """ Model specific postprocess and convert model output to | |||||
| standard model outputs. | |||||
| Args: | |||||
| input: input data | |||||
| Return: | |||||
| dict of results: a dict containing outputs of model, each | |||||
| output should have the standard output name. | |||||
| """ | |||||
| return input | return input | ||||
| @classmethod | @classmethod | ||||
| @@ -1,5 +1,7 @@ | |||||
| import os | |||||
| from typing import Any, Dict | from typing import Any, Dict | ||||
| import json | |||||
| import numpy as np | import numpy as np | ||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| @@ -34,6 +36,11 @@ class BertForSequenceClassification(Model): | |||||
| ('token_type_ids', torch.LongTensor)], | ('token_type_ids', torch.LongTensor)], | ||||
| output_keys=['predictions', 'probabilities', 'logits']) | output_keys=['predictions', 'probabilities', 'logits']) | ||||
| self.label_path = os.path.join(self.model_dir, 'label_mapping.json') | |||||
| with open(self.label_path) as f: | |||||
| self.label_mapping = json.load(f) | |||||
| self.id2label = {idx: name for name, idx in self.label_mapping.items()} | |||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: | def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: | ||||
| """return the result by the model | """return the result by the model | ||||
| @@ -50,3 +57,13 @@ class BertForSequenceClassification(Model): | |||||
| } | } | ||||
| """ | """ | ||||
| return self.model.predict(input) | return self.model.predict(input) | ||||
| def postprocess(self, inputs: Dict[str, np.ndarray], | |||||
| **kwargs) -> Dict[str, np.ndarray]: | |||||
| # N x num_classes | |||||
| probs = inputs['probabilities'] | |||||
| result = { | |||||
| 'probs': probs, | |||||
| } | |||||
| return result | |||||
| @@ -12,6 +12,7 @@ from modelscope.pydatasets import PyDataset | |||||
| from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
| from modelscope.utils.hub import get_model_cache_dir | from modelscope.utils.hub import get_model_cache_dir | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .outputs import TASK_OUTPUTS | |||||
| from .util import is_model_name | from .util import is_model_name | ||||
| Tensor = Union['torch.Tensor', 'tf.Tensor'] | Tensor = Union['torch.Tensor', 'tf.Tensor'] | ||||
| @@ -106,8 +107,25 @@ class Pipeline(ABC): | |||||
| out = self.preprocess(input) | out = self.preprocess(input) | ||||
| out = self.forward(out) | out = self.forward(out) | ||||
| out = self.postprocess(out, **post_kwargs) | out = self.postprocess(out, **post_kwargs) | ||||
| self._check_output(out) | |||||
| return out | return out | ||||
| def _check_output(self, input): | |||||
| # this attribute is dynamically attached by registry | |||||
| # when cls is registered in registry using task name | |||||
| task_name = self.group_key | |||||
| if task_name not in TASK_OUTPUTS: | |||||
| logger.warning(f'task {task_name} output keys are missing') | |||||
| return | |||||
| output_keys = TASK_OUTPUTS[task_name] | |||||
| missing_keys = [] | |||||
| for k in output_keys: | |||||
| if k not in input: | |||||
| missing_keys.append(k) | |||||
| if len(missing_keys) > 0: | |||||
| raise ValueError(f'expected output keys are {output_keys}, ' | |||||
| f'those {missing_keys} are missing') | |||||
| def preprocess(self, inputs: Input) -> Dict[str, Any]: | def preprocess(self, inputs: Input) -> Dict[str, Any]: | ||||
| """ Provide default implementation based on preprocess_cfg and user can reimplement it | """ Provide default implementation based on preprocess_cfg and user can reimplement it | ||||
| """ | """ | ||||
| @@ -125,4 +143,14 @@ class Pipeline(ABC): | |||||
| @abstractmethod | @abstractmethod | ||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | ||||
| """ If the current pipeline supports model reuse, common postprocess | |||||
| code should be written here. | |||||
| Args: | |||||
| inputs: input data | |||||
| Return: | |||||
| dict of results: a dict containing outputs of model, each | |||||
| output should have the standard output name. | |||||
| """ | |||||
| raise NotImplementedError('postprocess') | raise NotImplementedError('postprocess') | ||||
| @@ -41,50 +41,29 @@ class SequenceClassificationPipeline(Pipeline): | |||||
| second_sequence=None) | second_sequence=None) | ||||
| super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs) | super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs) | ||||
| from easynlp.utils import io | |||||
| self.label_path = os.path.join(sc_model.model_dir, | |||||
| 'label_mapping.json') | |||||
| with io.open(self.label_path) as f: | |||||
| self.label_mapping = json.load(f) | |||||
| self.label_id_to_name = { | |||||
| idx: name | |||||
| for name, idx in self.label_mapping.items() | |||||
| } | |||||
| assert hasattr(self.model, 'id2label'), \ | |||||
| 'id2label map should be initalizaed in init function.' | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: | |||||
| def postprocess(self, | |||||
| inputs: Dict[str, Any], | |||||
| topk: int = 5) -> Dict[str, str]: | |||||
| """process the prediction results | """process the prediction results | ||||
| Args: | Args: | ||||
| inputs (Dict[str, Any]): _description_ | |||||
| inputs (Dict[str, Any]): input data dict | |||||
| topk (int): return topk classification result. | |||||
| Returns: | Returns: | ||||
| Dict[str, str]: the prediction results | Dict[str, str]: the prediction results | ||||
| """ | """ | ||||
| # NxC np.ndarray | |||||
| probs = inputs['probs'][0] | |||||
| num_classes = probs.shape[0] | |||||
| topk = min(topk, num_classes) | |||||
| top_indices = np.argpartition(probs, -topk)[-topk:] | |||||
| cls_ids = top_indices[np.argsort(probs[top_indices])] | |||||
| probs = probs[cls_ids].tolist() | |||||
| probs = inputs['probabilities'] | |||||
| logits = inputs['logits'] | |||||
| predictions = np.argsort(-probs, axis=-1) | |||||
| preds = predictions[0] | |||||
| b = 0 | |||||
| new_result = list() | |||||
| for pred in preds: | |||||
| new_result.append({ | |||||
| 'pred': self.label_id_to_name[pred], | |||||
| 'prob': float(probs[b][pred]), | |||||
| 'logit': float(logits[b][pred]) | |||||
| }) | |||||
| new_results = list() | |||||
| new_results.append({ | |||||
| 'id': | |||||
| inputs['id'][b] if 'id' in inputs else str(uuid.uuid4()), | |||||
| 'output': | |||||
| new_result, | |||||
| 'predictions': | |||||
| new_result[0]['pred'], | |||||
| 'probabilities': | |||||
| ','.join([str(t) for t in inputs['probabilities'][b]]), | |||||
| 'logits': | |||||
| ','.join([str(t) for t in inputs['logits'][b]]) | |||||
| }) | |||||
| cls_names = [self.model.id2label[cid] for cid in cls_ids] | |||||
| return new_results[0] | |||||
| return {'scores': probs, 'labels': cls_names} | |||||
| @@ -56,4 +56,4 @@ class TextGenerationPipeline(Pipeline): | |||||
| '').split('[SEP]')[0].replace('[CLS]', | '').split('[SEP]')[0].replace('[CLS]', | ||||
| '').replace('[SEP]', | '').replace('[SEP]', | ||||
| '').replace('[UNK]', '') | '').replace('[UNK]', '') | ||||
| return {'pred_string': pred_string} | |||||
| return {'text': pred_string} | |||||
| @@ -0,0 +1,98 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from modelscope.utils.constant import Tasks | |||||
| TASK_OUTPUTS = { | |||||
| # ============ vision tasks =================== | |||||
| # image classification result for single sample | |||||
| # { | |||||
| # "labels": ["dog", "horse", "cow", "cat"], | |||||
| # "scores": [0.9, 0.1, 0.05, 0.05] | |||||
| # } | |||||
| Tasks.image_classification: ['scores', 'labels'], | |||||
| Tasks.image_tagging: ['scores', 'labels'], | |||||
| # object detection result for single sample | |||||
| # { | |||||
| # "boxes": [ | |||||
| # [x1, y1, x2, y2], | |||||
| # [x1, y1, x2, y2], | |||||
| # [x1, y1, x2, y2], | |||||
| # ], | |||||
| # "labels": ["dog", "horse", "cow", "cat"], | |||||
| # "scores": [0.9, 0.1, 0.05, 0.05] | |||||
| # } | |||||
| Tasks.object_detection: ['scores', 'labels', 'boxes'], | |||||
| # instance segmentation result for single sample | |||||
| # { | |||||
| # "masks": [ | |||||
| # np.array in bgr channel order | |||||
| # ], | |||||
| # "labels": ["dog", "horse", "cow", "cat"], | |||||
| # "scores": [0.9, 0.1, 0.05, 0.05] | |||||
| # } | |||||
| Tasks.image_segmentation: ['scores', 'labels', 'boxes'], | |||||
| # image generation/editing/matting result for single sample | |||||
| # { | |||||
| # "output_png": np.array with shape(h, w, 4) | |||||
| # for matting or (h, w, 3) for general purpose | |||||
| # } | |||||
| Tasks.image_editing: ['output_png'], | |||||
| Tasks.image_matting: ['output_png'], | |||||
| Tasks.image_generation: ['output_png'], | |||||
| # pose estimation result for single sample | |||||
| # { | |||||
| # "poses": np.array with shape [num_pose, num_keypoint, 3], | |||||
| # each keypoint is an array [x, y, score] | |||||
| # "boxes": np.array with shape [num_pose, 4], each box is | |||||
| # [x1, y1, x2, y2] | |||||
| # } | |||||
| Tasks.pose_estimation: ['poses', 'boxes'], | |||||
| # ============ nlp tasks =================== | |||||
| # text classification result for single sample | |||||
| # { | |||||
| # "labels": ["happy", "sad", "calm", "angry"], | |||||
| # "scores": [0.9, 0.1, 0.05, 0.05] | |||||
| # } | |||||
| Tasks.text_classification: ['scores', 'labels'], | |||||
| # text generation result for single sample | |||||
| # { | |||||
| # "text": "this is text generated by a model." | |||||
| # } | |||||
| Tasks.text_generation: ['text'], | |||||
| # ============ audio tasks =================== | |||||
| # ============ multi-modal tasks =================== | |||||
| # image caption result for single sample | |||||
| # { | |||||
| # "caption": "this is an image caption text." | |||||
| # } | |||||
| Tasks.image_captioning: ['caption'], | |||||
| # visual grounding result for single sample | |||||
| # { | |||||
| # "boxes": [ | |||||
| # [x1, y1, x2, y2], | |||||
| # [x1, y1, x2, y2], | |||||
| # [x1, y1, x2, y2], | |||||
| # ], | |||||
| # "scores": [0.9, 0.1, 0.05, 0.05] | |||||
| # } | |||||
| Tasks.visual_grounding: ['boxes', 'scores'], | |||||
| # text_to_image result for a single sample | |||||
| # { | |||||
| # "image": np.ndarray with shape [height, width, 3] | |||||
| # } | |||||
| Tasks.text_to_image_synthesis: ['image'] | |||||
| } | |||||
| @@ -51,7 +51,7 @@ class Tasks(object): | |||||
| text_to_speech = 'text-to-speech' | text_to_speech = 'text-to-speech' | ||||
| speech_signal_process = 'speech-signal-process' | speech_signal_process = 'speech-signal-process' | ||||
| # multi-media | |||||
| # multi-modal tasks | |||||
| image_captioning = 'image-captioning' | image_captioning = 'image-captioning' | ||||
| visual_grounding = 'visual-grounding' | visual_grounding = 'visual-grounding' | ||||
| text_to_image_synthesis = 'text-to-image-synthesis' | text_to_image_synthesis = 'text-to-image-synthesis' | ||||
| @@ -69,6 +69,7 @@ class Registry(object): | |||||
| f'{self._name}[{group_key}]') | f'{self._name}[{group_key}]') | ||||
| self._modules[group_key][module_name] = module_cls | self._modules[group_key][module_name] = module_cls | ||||
| module_cls.group_key = group_key | |||||
| if module_name in self._modules[default_group]: | if module_name in self._modules[default_group]: | ||||
| if id(self._modules[default_group][module_name]) == id(module_cls): | if id(self._modules[default_group][module_name]) == id(module_cls): | ||||
| @@ -35,9 +35,10 @@ class CustomPipelineTest(unittest.TestCase): | |||||
| CustomPipeline1() | CustomPipeline1() | ||||
| def test_custom(self): | def test_custom(self): | ||||
| dummy_task = 'dummy-task' | |||||
| @PIPELINES.register_module( | @PIPELINES.register_module( | ||||
| group_key=Tasks.image_tagging, module_name='custom-image') | |||||
| group_key=dummy_task, module_name='custom-image') | |||||
| class CustomImagePipeline(Pipeline): | class CustomImagePipeline(Pipeline): | ||||
| def __init__(self, | def __init__(self, | ||||
| @@ -67,32 +68,29 @@ class CustomPipelineTest(unittest.TestCase): | |||||
| outputs['filename'] = inputs['url'] | outputs['filename'] = inputs['url'] | ||||
| img = inputs['img'] | img = inputs['img'] | ||||
| new_image = img.resize((img.width // 2, img.height // 2)) | new_image = img.resize((img.width // 2, img.height // 2)) | ||||
| outputs['resize_image'] = np.array(new_image) | |||||
| outputs['dummy_result'] = 'dummy_result' | |||||
| outputs['output_png'] = np.array(new_image) | |||||
| return outputs | return outputs | ||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | ||||
| return inputs | return inputs | ||||
| self.assertTrue('custom-image' in PIPELINES.modules[default_group]) | self.assertTrue('custom-image' in PIPELINES.modules[default_group]) | ||||
| add_default_pipeline_info(Tasks.image_tagging, 'custom-image') | |||||
| add_default_pipeline_info(dummy_task, 'custom-image', overwrite=True) | |||||
| pipe = pipeline(pipeline_name='custom-image') | pipe = pipeline(pipeline_name='custom-image') | ||||
| pipe2 = pipeline(Tasks.image_tagging) | |||||
| pipe2 = pipeline(dummy_task) | |||||
| self.assertTrue(type(pipe) is type(pipe2)) | self.assertTrue(type(pipe) is type(pipe2)) | ||||
| img_url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.' \ | img_url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.' \ | ||||
| 'aliyuncs.com/data/test/images/image1.jpg' | 'aliyuncs.com/data/test/images/image1.jpg' | ||||
| output = pipe(img_url) | output = pipe(img_url) | ||||
| self.assertEqual(output['filename'], img_url) | self.assertEqual(output['filename'], img_url) | ||||
| self.assertEqual(output['resize_image'].shape, (318, 512, 3)) | |||||
| self.assertEqual(output['dummy_result'], 'dummy_result') | |||||
| self.assertEqual(output['output_png'].shape, (318, 512, 3)) | |||||
| outputs = pipe([img_url for i in range(4)]) | outputs = pipe([img_url for i in range(4)]) | ||||
| self.assertEqual(len(outputs), 4) | self.assertEqual(len(outputs), 4) | ||||
| for out in outputs: | for out in outputs: | ||||
| self.assertEqual(out['filename'], img_url) | self.assertEqual(out['filename'], img_url) | ||||
| self.assertEqual(out['resize_image'].shape, (318, 512, 3)) | |||||
| self.assertEqual(out['dummy_result'], 'dummy_result') | |||||
| self.assertEqual(out['output_png'].shape, (318, 512, 3)) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||