Browse Source

[to #42362853] formalize the output of pipeline and make pipeline reusable

* format pipeline output and check it
* fix UT
* add docstr to clarify the difference between model.postprocess and pipeline.postprocess

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9051405
master
wenmeng.zwm 3 years ago
parent
commit
4f7928bb6e
11 changed files with 203 additions and 57 deletions
  1. +2
    -1
      Makefile.docker
  2. +19
    -3
      docker/pytorch.dockerfile
  3. +13
    -5
      modelscope/models/base.py
  4. +17
    -0
      modelscope/models/nlp/sequence_classification_model.py
  5. +28
    -0
      modelscope/pipelines/base.py
  6. +16
    -37
      modelscope/pipelines/nlp/sequence_classification_pipeline.py
  7. +1
    -1
      modelscope/pipelines/nlp/text_generation_pipeline.py
  8. +98
    -0
      modelscope/pipelines/outputs.py
  9. +1
    -1
      modelscope/utils/constant.py
  10. +1
    -0
      modelscope/utils/registry.py
  11. +7
    -9
      tests/pipelines/test_base.py

+ 2
- 1
Makefile.docker View File

@@ -6,7 +6,8 @@ DOCKER_FULL_NAME = $(DOCKER_REGISTRY)/$(DOCKER_ORG)/$(DOCKER_IMAGE)
# CUDA_VERSION = 11.3 # CUDA_VERSION = 11.3
# CUDNN_VERSION = 8 # CUDNN_VERSION = 8
BASE_RUNTIME = reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 BASE_RUNTIME = reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04
BASE_DEVEL = reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04
# BASE_DEVEL = reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04
BASE_DEVEL = pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel




MODELSCOPE_VERSION = $(shell git describe --tags --always) MODELSCOPE_VERSION = $(shell git describe --tags --always)


+ 19
- 3
docker/pytorch.dockerfile View File

@@ -8,13 +8,29 @@
# For reference: # For reference:
# https://docs.docker.com/develop/develop-images/build_enhancements/ # https://docs.docker.com/develop/develop-images/build_enhancements/


#ARG BASE_IMAGE=reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04
#FROM ${BASE_IMAGE} as dev-base
# ARG BASE_IMAGE=reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04
# FROM ${BASE_IMAGE} as dev-base


FROM reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 as dev-base
# FROM reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 as dev-base
FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel
# FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-runtime
# config pip source # config pip source
RUN mkdir /root/.pip RUN mkdir /root/.pip
COPY docker/rcfiles/pip.conf.tsinghua /root/.pip/pip.conf COPY docker/rcfiles/pip.conf.tsinghua /root/.pip/pip.conf
COPY docker/rcfiles/sources.list.aliyun /etc/apt/sources.list

# Install essential Ubuntu packages
RUN apt-get update &&\
apt-get install -y software-properties-common \
build-essential \
git \
wget \
vim \
curl \
zip \
zlib1g-dev \
unzip \
pkg-config


# install modelscope and its python env # install modelscope and its python env
WORKDIR /opt/modelscope WORKDIR /opt/modelscope


+ 13
- 5
modelscope/models/base.py View File

@@ -20,16 +20,24 @@ class Model(ABC):
self.model_dir = model_dir self.model_dir = model_dir


def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
return self.post_process(self.forward(input))
return self.postprocess(self.forward(input))


@abstractmethod @abstractmethod
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
pass pass


def post_process(self, input: Dict[str, Tensor],
**kwargs) -> Dict[str, Tensor]:
# model specific postprocess, implementation is optional
# will be called in Pipeline and evaluation loop(in the future)
def postprocess(self, input: Dict[str, Tensor],
**kwargs) -> Dict[str, Tensor]:
""" Model specific postprocess and convert model output to
standard model outputs.

Args:
inputs: input data

Return:
dict of results: a dict containing outputs of model, each
output should have the standard output name.
"""
return input return input


@classmethod @classmethod


+ 17
- 0
modelscope/models/nlp/sequence_classification_model.py View File

@@ -1,5 +1,7 @@
import os
from typing import Any, Dict from typing import Any, Dict


import json
import numpy as np import numpy as np


from modelscope.utils.constant import Tasks from modelscope.utils.constant import Tasks
@@ -34,6 +36,11 @@ class BertForSequenceClassification(Model):
('token_type_ids', torch.LongTensor)], ('token_type_ids', torch.LongTensor)],
output_keys=['predictions', 'probabilities', 'logits']) output_keys=['predictions', 'probabilities', 'logits'])


self.label_path = os.path.join(self.model_dir, 'label_mapping.json')
with open(self.label_path) as f:
self.label_mapping = json.load(f)
self.id2label = {idx: name for name, idx in self.label_mapping.items()}

def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
"""return the result by the model """return the result by the model


@@ -50,3 +57,13 @@ class BertForSequenceClassification(Model):
} }
""" """
return self.model.predict(input) return self.model.predict(input)

def postprocess(self, inputs: Dict[str, np.ndarray],
**kwargs) -> Dict[str, np.ndarray]:
# N x num_classes
probs = inputs['probabilities']
result = {
'probs': probs,
}

return result

+ 28
- 0
modelscope/pipelines/base.py View File

@@ -12,6 +12,7 @@ from modelscope.pydatasets import PyDataset
from modelscope.utils.config import Config from modelscope.utils.config import Config
from modelscope.utils.hub import get_model_cache_dir from modelscope.utils.hub import get_model_cache_dir
from modelscope.utils.logger import get_logger from modelscope.utils.logger import get_logger
from .outputs import TASK_OUTPUTS
from .util import is_model_name from .util import is_model_name


Tensor = Union['torch.Tensor', 'tf.Tensor'] Tensor = Union['torch.Tensor', 'tf.Tensor']
@@ -106,8 +107,25 @@ class Pipeline(ABC):
out = self.preprocess(input) out = self.preprocess(input)
out = self.forward(out) out = self.forward(out)
out = self.postprocess(out, **post_kwargs) out = self.postprocess(out, **post_kwargs)
self._check_output(out)
return out return out


def _check_output(self, input):
# this attribute is dynamically attached by registry
# when cls is registered in registry using task name
task_name = self.group_key
if task_name not in TASK_OUTPUTS:
logger.warning(f'task {task_name} output keys are missing')
return
output_keys = TASK_OUTPUTS[task_name]
missing_keys = []
for k in output_keys:
if k not in input:
missing_keys.append(k)
if len(missing_keys) > 0:
raise ValueError(f'expected output keys are {output_keys}, '
f'those {missing_keys} are missing')

def preprocess(self, inputs: Input) -> Dict[str, Any]: def preprocess(self, inputs: Input) -> Dict[str, Any]:
""" Provide default implementation based on preprocess_cfg and user can reimplement it """ Provide default implementation based on preprocess_cfg and user can reimplement it
""" """
@@ -125,4 +143,14 @@ class Pipeline(ABC):


@abstractmethod @abstractmethod
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
""" If current pipeline support model reuse, common postprocess
code should be write here.

Args:
inputs: input data

Return:
dict of results: a dict containing outputs of model, each
output should have the standard output name.
"""
raise NotImplementedError('postprocess') raise NotImplementedError('postprocess')

+ 16
- 37
modelscope/pipelines/nlp/sequence_classification_pipeline.py View File

@@ -41,50 +41,29 @@ class SequenceClassificationPipeline(Pipeline):
second_sequence=None) second_sequence=None)
super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs) super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)


from easynlp.utils import io
self.label_path = os.path.join(sc_model.model_dir,
'label_mapping.json')
with io.open(self.label_path) as f:
self.label_mapping = json.load(f)
self.label_id_to_name = {
idx: name
for name, idx in self.label_mapping.items()
}
assert hasattr(self.model, 'id2label'), \
'id2label map should be initalizaed in init function.'


def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
def postprocess(self,
inputs: Dict[str, Any],
topk: int = 5) -> Dict[str, str]:
"""process the prediction results """process the prediction results


Args: Args:
inputs (Dict[str, Any]): _description_
inputs (Dict[str, Any]): input data dict
topk (int): return topk classification result.


Returns: Returns:
Dict[str, str]: the prediction results Dict[str, str]: the prediction results
""" """
# NxC np.ndarray
probs = inputs['probs'][0]
num_classes = probs.shape[0]
topk = min(topk, num_classes)
top_indices = np.argpartition(probs, -topk)[-topk:]
cls_ids = top_indices[np.argsort(probs[top_indices])]
probs = probs[cls_ids].tolist()


probs = inputs['probabilities']
logits = inputs['logits']
predictions = np.argsort(-probs, axis=-1)
preds = predictions[0]
b = 0
new_result = list()
for pred in preds:
new_result.append({
'pred': self.label_id_to_name[pred],
'prob': float(probs[b][pred]),
'logit': float(logits[b][pred])
})
new_results = list()
new_results.append({
'id':
inputs['id'][b] if 'id' in inputs else str(uuid.uuid4()),
'output':
new_result,
'predictions':
new_result[0]['pred'],
'probabilities':
','.join([str(t) for t in inputs['probabilities'][b]]),
'logits':
','.join([str(t) for t in inputs['logits'][b]])
})
cls_names = [self.model.id2label[cid] for cid in cls_ids]


return new_results[0]
return {'scores': probs, 'labels': cls_names}

+ 1
- 1
modelscope/pipelines/nlp/text_generation_pipeline.py View File

@@ -56,4 +56,4 @@ class TextGenerationPipeline(Pipeline):
'').split('[SEP]')[0].replace('[CLS]', '').split('[SEP]')[0].replace('[CLS]',
'').replace('[SEP]', '').replace('[SEP]',
'').replace('[UNK]', '') '').replace('[UNK]', '')
return {'pred_string': pred_string}
return {'text': pred_string}

+ 98
- 0
modelscope/pipelines/outputs.py View File

@@ -0,0 +1,98 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from modelscope.utils.constant import Tasks

TASK_OUTPUTS = {

# ============ vision tasks ===================

# image classification result for single sample
# {
# "labels": ["dog", "horse", "cow", "cat"],
# "scores": [0.9, 0.1, 0.05, 0.05]
# }
Tasks.image_classification: ['scores', 'labels'],
Tasks.image_tagging: ['scores', 'labels'],

# object detection result for single sample
# {
# "boxes": [
# [x1, y1, x2, y2],
# [x1, y1, x2, y2],
# [x1, y1, x2, y2],
# ],
# "labels": ["dog", "horse", "cow", "cat"],
# "scores": [0.9, 0.1, 0.05, 0.05]
# }
Tasks.object_detection: ['scores', 'labels', 'boxes'],

# instance segmentation result for single sample
# {
# "masks": [
# np.array in bgr channel order
# ],
# "labels": ["dog", "horse", "cow", "cat"],
# "scores": [0.9, 0.1, 0.05, 0.05]
# }
Tasks.image_segmentation: ['scores', 'labels', 'boxes'],

# image generation/editing/matting result for single sample
# {
# "output_png": np.array with shape(h, w, 4)
# for matting or (h, w, 3) for general purpose
# }
Tasks.image_editing: ['output_png'],
Tasks.image_matting: ['output_png'],
Tasks.image_generation: ['output_png'],

# pose estimation result for single sample
# {
# "poses": np.array with shape [num_pose, num_keypoint, 3],
# each keypoint is a array [x, y, score]
# "boxes": np.array with shape [num_pose, 4], each box is
# [x1, y1, x2, y2]
# }
Tasks.pose_estimation: ['poses', 'boxes'],

# ============ nlp tasks ===================

# text classification result for single sample
# {
# "labels": ["happy", "sad", "calm", "angry"],
# "scores": [0.9, 0.1, 0.05, 0.05]
# }
Tasks.text_classification: ['scores', 'labels'],

# text generation result for single sample
# {
# "text": "this is text generated by a model."
# }
Tasks.text_generation: ['text'],

# ============ audio tasks ===================

# ============ multi-modal tasks ===================

# image caption result for single sample
# {
# "caption": "this is an image caption text."
# }
Tasks.image_captioning: ['caption'],

# visual grounding result for single sample
# {
# "boxes": [
# [x1, y1, x2, y2],
# [x1, y1, x2, y2],
# [x1, y1, x2, y2],
# ],
# "scores": [0.9, 0.1, 0.05, 0.05]
# }
Tasks.visual_grounding: ['boxes', 'scores'],

# text_to_image result for a single sample
# {
# "image": np.ndarray with shape [height, width, 3]
# }
Tasks.text_to_image_synthesis: ['image']
}

+ 1
- 1
modelscope/utils/constant.py View File

@@ -51,7 +51,7 @@ class Tasks(object):
text_to_speech = 'text-to-speech' text_to_speech = 'text-to-speech'
speech_signal_process = 'speech-signal-process' speech_signal_process = 'speech-signal-process'


# multi-media
# multi-modal tasks
image_captioning = 'image-captioning' image_captioning = 'image-captioning'
visual_grounding = 'visual-grounding' visual_grounding = 'visual-grounding'
text_to_image_synthesis = 'text-to-image-synthesis' text_to_image_synthesis = 'text-to-image-synthesis'


+ 1
- 0
modelscope/utils/registry.py View File

@@ -69,6 +69,7 @@ class Registry(object):
f'{self._name}[{group_key}]') f'{self._name}[{group_key}]')


self._modules[group_key][module_name] = module_cls self._modules[group_key][module_name] = module_cls
module_cls.group_key = group_key


if module_name in self._modules[default_group]: if module_name in self._modules[default_group]:
if id(self._modules[default_group][module_name]) == id(module_cls): if id(self._modules[default_group][module_name]) == id(module_cls):


+ 7
- 9
tests/pipelines/test_base.py View File

@@ -35,9 +35,10 @@ class CustomPipelineTest(unittest.TestCase):
CustomPipeline1() CustomPipeline1()


def test_custom(self): def test_custom(self):
dummy_task = 'dummy-task'


@PIPELINES.register_module( @PIPELINES.register_module(
group_key=Tasks.image_tagging, module_name='custom-image')
group_key=dummy_task, module_name='custom-image')
class CustomImagePipeline(Pipeline): class CustomImagePipeline(Pipeline):


def __init__(self, def __init__(self,
@@ -67,32 +68,29 @@ class CustomPipelineTest(unittest.TestCase):
outputs['filename'] = inputs['url'] outputs['filename'] = inputs['url']
img = inputs['img'] img = inputs['img']
new_image = img.resize((img.width // 2, img.height // 2)) new_image = img.resize((img.width // 2, img.height // 2))
outputs['resize_image'] = np.array(new_image)
outputs['dummy_result'] = 'dummy_result'
outputs['output_png'] = np.array(new_image)
return outputs return outputs


def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs return inputs


self.assertTrue('custom-image' in PIPELINES.modules[default_group]) self.assertTrue('custom-image' in PIPELINES.modules[default_group])
add_default_pipeline_info(Tasks.image_tagging, 'custom-image')
add_default_pipeline_info(dummy_task, 'custom-image', overwrite=True)
pipe = pipeline(pipeline_name='custom-image') pipe = pipeline(pipeline_name='custom-image')
pipe2 = pipeline(Tasks.image_tagging)
pipe2 = pipeline(dummy_task)
self.assertTrue(type(pipe) is type(pipe2)) self.assertTrue(type(pipe) is type(pipe2))


img_url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.' \ img_url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.' \
'aliyuncs.com/data/test/images/image1.jpg' 'aliyuncs.com/data/test/images/image1.jpg'
output = pipe(img_url) output = pipe(img_url)
self.assertEqual(output['filename'], img_url) self.assertEqual(output['filename'], img_url)
self.assertEqual(output['resize_image'].shape, (318, 512, 3))
self.assertEqual(output['dummy_result'], 'dummy_result')
self.assertEqual(output['output_png'].shape, (318, 512, 3))


outputs = pipe([img_url for i in range(4)]) outputs = pipe([img_url for i in range(4)])
self.assertEqual(len(outputs), 4) self.assertEqual(len(outputs), 4)
for out in outputs: for out in outputs:
self.assertEqual(out['filename'], img_url) self.assertEqual(out['filename'], img_url)
self.assertEqual(out['resize_image'].shape, (318, 512, 3))
self.assertEqual(out['dummy_result'], 'dummy_result')
self.assertEqual(out['output_png'].shape, (318, 512, 3))




if __name__ == '__main__': if __name__ == '__main__':


Loading…
Cancel
Save