* Format pipeline output and check it
* Fix unit tests (UT)
* Add docstrings to clarify the difference between model.postprocess and pipeline.postprocess

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9051405 (master)
@@ -6,7 +6,8 @@ DOCKER_FULL_NAME = $(DOCKER_REGISTRY)/$(DOCKER_ORG)/$(DOCKER_IMAGE) | |||
# CUDA_VERSION = 11.3 | |||
# CUDNN_VERSION = 8 | |||
BASE_RUNTIME = reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 | |||
BASE_DEVEL = reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 | |||
# BASE_DEVEL = reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 | |||
BASE_DEVEL = pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel | |||
MODELSCOPE_VERSION = $(shell git describe --tags --always) | |||
@@ -8,13 +8,29 @@ | |||
# For reference: | |||
# https://docs.docker.com/develop/develop-images/build_enhancements/ | |||
#ARG BASE_IMAGE=reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 | |||
#FROM ${BASE_IMAGE} as dev-base | |||
# ARG BASE_IMAGE=reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 | |||
# FROM ${BASE_IMAGE} as dev-base | |||
FROM reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 as dev-base | |||
# FROM reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 as dev-base | |||
FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel | |||
# FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-runtime | |||
# config pip source | |||
RUN mkdir /root/.pip | |||
COPY docker/rcfiles/pip.conf.tsinghua /root/.pip/pip.conf | |||
COPY docker/rcfiles/sources.list.aliyun /etc/apt/sources.list | |||
# Install essential Ubuntu packages | |||
RUN apt-get update &&\ | |||
apt-get install -y software-properties-common \ | |||
build-essential \ | |||
git \ | |||
wget \ | |||
vim \ | |||
curl \ | |||
zip \ | |||
zlib1g-dev \ | |||
unzip \ | |||
pkg-config | |||
# install modelscope and its python env | |||
WORKDIR /opt/modelscope | |||
@@ -20,16 +20,24 @@ class Model(ABC): | |||
self.model_dir = model_dir | |||
def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
    """Run ``forward`` on ``input`` and pass the result to ``postprocess``."""
    return self.postprocess(self.forward(input))

@abstractmethod
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
    """Compute the model outputs for ``input``; subclasses implement this."""
    pass

def postprocess(self, input: Dict[str, Tensor],
                **kwargs) -> Dict[str, Tensor]:
    """ Model specific postprocess and convert model output to
    standard model outputs.

    Implementation is optional: this default returns ``input`` unchanged.
    It is invoked from ``__call__`` on the result of ``forward``.

    Args:
        input: input data
        **kwargs: extra keyword arguments; ignored by this default.

    Return:
        dict of results: a dict containing outputs of model, each
            output should have the standard output name.
    """
    return input
@classmethod | |||
@@ -1,5 +1,7 @@ | |||
import os | |||
from typing import Any, Dict | |||
import json | |||
import numpy as np | |||
from modelscope.utils.constant import Tasks | |||
@@ -34,6 +36,11 @@ class BertForSequenceClassification(Model): | |||
('token_type_ids', torch.LongTensor)], | |||
output_keys=['predictions', 'probabilities', 'logits']) | |||
self.label_path = os.path.join(self.model_dir, 'label_mapping.json') | |||
with open(self.label_path) as f: | |||
self.label_mapping = json.load(f) | |||
self.id2label = {idx: name for name, idx in self.label_mapping.items()} | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: | |||
"""return the result by the model | |||
@@ -50,3 +57,13 @@ class BertForSequenceClassification(Model): | |||
} | |||
""" | |||
return self.model.predict(input) | |||
def postprocess(self, inputs: Dict[str, np.ndarray],
                **kwargs) -> Dict[str, np.ndarray]:
    """Keep only the class probabilities from the model outputs.

    Args:
        inputs: model outputs; must contain the key ``'probabilities'``.
        **kwargs: accepted for interface compatibility and ignored.

    Returns:
        A dict with the single key ``'probs'`` holding
        ``inputs['probabilities']`` unchanged.
    """
    # N x num_classes
    return {'probs': inputs['probabilities']}
@@ -12,6 +12,7 @@ from modelscope.pydatasets import PyDataset | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.hub import get_model_cache_dir | |||
from modelscope.utils.logger import get_logger | |||
from .outputs import TASK_OUTPUTS | |||
from .util import is_model_name | |||
Tensor = Union['torch.Tensor', 'tf.Tensor'] | |||
@@ -106,8 +107,25 @@ class Pipeline(ABC): | |||
out = self.preprocess(input) | |||
out = self.forward(out) | |||
out = self.postprocess(out, **post_kwargs) | |||
self._check_output(out) | |||
return out | |||
def _check_output(self, input):
    """Check that a pipeline result has every key expected for its task.

    The expected keys are looked up in ``TASK_OUTPUTS`` by task name. When
    the task has no entry there, a warning is logged and nothing is checked.

    Args:
        input: the pipeline result to validate; keys are tested with ``in``.

    Raises:
        ValueError: if any expected output key is absent from ``input``.
    """
    # ``group_key`` is attached dynamically by the registry when the class
    # is registered under a task name. A class that was never registered
    # has no task to validate against, so fall through to the warning
    # instead of failing with AttributeError.
    task_name = getattr(self, 'group_key', None)
    if task_name not in TASK_OUTPUTS:
        logger.warning(f'task {task_name} output keys are missing')
        return

    output_keys = TASK_OUTPUTS[task_name]
    missing_keys = [k for k in output_keys if k not in input]
    if missing_keys:
        raise ValueError(f'expected output keys are {output_keys}, '
                         f'those {missing_keys} are missing')
def preprocess(self, inputs: Input) -> Dict[str, Any]: | |||
""" Provide default implementation based on preprocess_cfg and user can reimplement it | |||
""" | |||
@@ -125,4 +143,14 @@ class Pipeline(ABC): | |||
@abstractmethod
def postprocess(self, inputs: Dict[str, Any],
                **post_params) -> Dict[str, Any]:
    """ If the current pipeline supports model reuse, common postprocess
    code should be written here.

    Args:
        inputs: input data
        **post_params: optional keyword arguments forwarded by the caller;
            concrete pipelines may declare the ones they understand.

    Return:
        dict of results: a dict containing outputs of model, each
            output should have the standard output name.
    """
    raise NotImplementedError('postprocess')
@@ -41,50 +41,29 @@ class SequenceClassificationPipeline(Pipeline): | |||
second_sequence=None) | |||
super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs) | |||
from easynlp.utils import io | |||
self.label_path = os.path.join(sc_model.model_dir, | |||
'label_mapping.json') | |||
with io.open(self.label_path) as f: | |||
self.label_mapping = json.load(f) | |||
self.label_id_to_name = { | |||
idx: name | |||
for name, idx in self.label_mapping.items() | |||
} | |||
assert hasattr(self.model, 'id2label'), \ | |||
'id2label map should be initalizaed in init function.' | |||
def postprocess(self,
                inputs: Dict[str, Any],
                topk: int = 5) -> Dict[str, Any]:
    """process the prediction results

    Args:
        inputs (Dict[str, Any]): input data dict; ``inputs['probs']`` is
            indexed as an NxC array and only its first row is reported.
        topk (int): return topk classification result.

    Returns:
        Dict[str, Any]: the prediction results, with ``'scores'`` (list of
            probabilities) and ``'labels'`` (list of class names), both
            ordered from the highest probability to the lowest.
    """
    # NxC np.ndarray
    probs = inputs['probs'][0]
    num_classes = probs.shape[0]
    topk = min(topk, num_classes)
    # argpartition picks the topk largest entries but leaves them unordered
    top_indices = np.argpartition(probs, -topk)[-topk:]
    # negate before argsort so the best class comes first
    cls_ids = top_indices[np.argsort(-probs[top_indices])]
    probs = probs[cls_ids].tolist()

    cls_names = [self.model.id2label[cid] for cid in cls_ids]

    return {'scores': probs, 'labels': cls_names}
@@ -56,4 +56,4 @@ class TextGenerationPipeline(Pipeline): | |||
'').split('[SEP]')[0].replace('[CLS]', | |||
'').replace('[SEP]', | |||
'').replace('[UNK]', '') | |||
return {'pred_string': pred_string} | |||
return {'text': pred_string} |
@@ -0,0 +1,98 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from modelscope.utils.constant import Tasks | |||
# Maps each task name to the list of keys that a pipeline result for that
# task is expected to contain.
TASK_OUTPUTS = {

    # ============ vision tasks ===================

    # image classification result for single sample
    #   {
    #       "labels": ["dog", "horse", "cow", "cat"],
    #       "scores": [0.9, 0.1, 0.05, 0.05]
    #   }
    Tasks.image_classification: ['scores', 'labels'],
    Tasks.image_tagging: ['scores', 'labels'],

    # object detection result for single sample
    #   {
    #       "boxes": [
    #           [x1, y1, x2, y2],
    #           [x1, y1, x2, y2],
    #           [x1, y1, x2, y2],
    #       ],
    #       "labels": ["dog", "horse", "cow", "cat"],
    #       "scores": [0.9, 0.1, 0.05, 0.05]
    #   }
    Tasks.object_detection: ['scores', 'labels', 'boxes'],

    # instance segmentation result for single sample
    #   {
    #       "masks": [
    #           np.array in bgr channel order
    #       ],
    #       "labels": ["dog", "horse", "cow", "cat"],
    #       "scores": [0.9, 0.1, 0.05, 0.05]
    #   }
    # The documented result carries masks, not boxes, so masks is the key
    # that must be present.
    Tasks.image_segmentation: ['scores', 'labels', 'masks'],

    # image generation/editing/matting result for single sample
    #   {
    #       "output_png": np.array with shape(h, w, 4)
    #                     for matting or (h, w, 3) for general purpose
    #   }
    Tasks.image_editing: ['output_png'],
    Tasks.image_matting: ['output_png'],
    Tasks.image_generation: ['output_png'],

    # pose estimation result for single sample
    #   {
    #       "poses": np.array with shape [num_pose, num_keypoint, 3],
    #           each keypoint is a array [x, y, score]
    #       "boxes": np.array with shape [num_pose, 4], each box is
    #           [x1, y1, x2, y2]
    #   }
    Tasks.pose_estimation: ['poses', 'boxes'],

    # ============ nlp tasks ===================

    # text classification result for single sample
    #   {
    #       "labels": ["happy", "sad", "calm", "angry"],
    #       "scores": [0.9, 0.1, 0.05, 0.05]
    #   }
    Tasks.text_classification: ['scores', 'labels'],

    # text generation result for single sample
    #   {
    #       "text": "this is text generated by a model."
    #   }
    Tasks.text_generation: ['text'],

    # ============ audio tasks ===================

    # ============ multi-modal tasks ===================

    # image caption result for single sample
    #   {
    #       "caption": "this is an image caption text."
    #   }
    Tasks.image_captioning: ['caption'],

    # visual grounding result for single sample
    #   {
    #       "boxes": [
    #           [x1, y1, x2, y2],
    #           [x1, y1, x2, y2],
    #           [x1, y1, x2, y2],
    #       ],
    #       "scores": [0.9, 0.1, 0.05, 0.05]
    #   }
    Tasks.visual_grounding: ['boxes', 'scores'],

    # text_to_image result for a single sample
    #   {
    #       "image": np.ndarray with shape [height, width, 3]
    #   }
    Tasks.text_to_image_synthesis: ['image'],
}
@@ -51,7 +51,7 @@ class Tasks(object): | |||
text_to_speech = 'text-to-speech' | |||
speech_signal_process = 'speech-signal-process' | |||
# multi-media | |||
# multi-modal tasks | |||
image_captioning = 'image-captioning' | |||
visual_grounding = 'visual-grounding' | |||
text_to_image_synthesis = 'text-to-image-synthesis' | |||
@@ -69,6 +69,7 @@ class Registry(object): | |||
f'{self._name}[{group_key}]') | |||
self._modules[group_key][module_name] = module_cls | |||
module_cls.group_key = group_key | |||
if module_name in self._modules[default_group]: | |||
if id(self._modules[default_group][module_name]) == id(module_cls): | |||
@@ -35,9 +35,10 @@ class CustomPipelineTest(unittest.TestCase): | |||
CustomPipeline1() | |||
def test_custom(self): | |||
dummy_task = 'dummy-task' | |||
@PIPELINES.register_module( | |||
group_key=Tasks.image_tagging, module_name='custom-image') | |||
group_key=dummy_task, module_name='custom-image') | |||
class CustomImagePipeline(Pipeline): | |||
def __init__(self, | |||
@@ -67,32 +68,29 @@ class CustomPipelineTest(unittest.TestCase): | |||
outputs['filename'] = inputs['url'] | |||
img = inputs['img'] | |||
new_image = img.resize((img.width // 2, img.height // 2)) | |||
outputs['resize_image'] = np.array(new_image) | |||
outputs['dummy_result'] = 'dummy_result' | |||
outputs['output_png'] = np.array(new_image) | |||
return outputs | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
return inputs | |||
self.assertTrue('custom-image' in PIPELINES.modules[default_group]) | |||
add_default_pipeline_info(Tasks.image_tagging, 'custom-image') | |||
add_default_pipeline_info(dummy_task, 'custom-image', overwrite=True) | |||
pipe = pipeline(pipeline_name='custom-image') | |||
pipe2 = pipeline(Tasks.image_tagging) | |||
pipe2 = pipeline(dummy_task) | |||
self.assertTrue(type(pipe) is type(pipe2)) | |||
img_url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.' \ | |||
'aliyuncs.com/data/test/images/image1.jpg' | |||
output = pipe(img_url) | |||
self.assertEqual(output['filename'], img_url) | |||
self.assertEqual(output['resize_image'].shape, (318, 512, 3)) | |||
self.assertEqual(output['dummy_result'], 'dummy_result') | |||
self.assertEqual(output['output_png'].shape, (318, 512, 3)) | |||
outputs = pipe([img_url for i in range(4)]) | |||
self.assertEqual(len(outputs), 4) | |||
for out in outputs: | |||
self.assertEqual(out['filename'], img_url) | |||
self.assertEqual(out['resize_image'].shape, (318, 512, 3)) | |||
self.assertEqual(out['dummy_result'], 'dummy_result') | |||
self.assertEqual(out['output_png'].shape, (318, 512, 3)) | |||
if __name__ == '__main__': | |||