diff --git a/Makefile.docker b/Makefile.docker index bbac840e..97400318 100644 --- a/Makefile.docker +++ b/Makefile.docker @@ -6,7 +6,8 @@ DOCKER_FULL_NAME = $(DOCKER_REGISTRY)/$(DOCKER_ORG)/$(DOCKER_IMAGE) # CUDA_VERSION = 11.3 # CUDNN_VERSION = 8 BASE_RUNTIME = reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 -BASE_DEVEL = reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 +# BASE_DEVEL = reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 +BASE_DEVEL = pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel MODELSCOPE_VERSION = $(shell git describe --tags --always) diff --git a/docker/pytorch.dockerfile b/docker/pytorch.dockerfile index 73c35af1..4862cab6 100644 --- a/docker/pytorch.dockerfile +++ b/docker/pytorch.dockerfile @@ -8,13 +8,29 @@ # For reference: # https://docs.docker.com/develop/develop-images/build_enhancements/ -#ARG BASE_IMAGE=reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 -#FROM ${BASE_IMAGE} as dev-base +# ARG BASE_IMAGE=reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 +# FROM ${BASE_IMAGE} as dev-base -FROM reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 as dev-base +# FROM reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 as dev-base +FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel +# FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-runtime # config pip source RUN mkdir /root/.pip COPY docker/rcfiles/pip.conf.tsinghua /root/.pip/pip.conf +COPY docker/rcfiles/sources.list.aliyun /etc/apt/sources.list + +# Install essential Ubuntu packages +RUN apt-get update &&\ + apt-get install -y software-properties-common \ + build-essential \ + git \ + wget \ + vim \ + curl \ + zip \ + zlib1g-dev \ + unzip \ + pkg-config # install modelscope and its python env WORKDIR /opt/modelscope diff --git 
a/modelscope/models/base.py b/modelscope/models/base.py index 3e361f91..88b1e3b0 100644 --- a/modelscope/models/base.py +++ b/modelscope/models/base.py @@ -20,16 +20,24 @@ class Model(ABC): self.model_dir = model_dir def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: - return self.post_process(self.forward(input)) + return self.postprocess(self.forward(input)) @abstractmethod def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: pass - def post_process(self, input: Dict[str, Tensor], - **kwargs) -> Dict[str, Tensor]: - # model specific postprocess, implementation is optional - # will be called in Pipeline and evaluation loop(in the future) + def postprocess(self, input: Dict[str, Tensor], + **kwargs) -> Dict[str, Tensor]: + """ Model specific postprocess and convert model output to + standard model outputs. + + Args: + input: input data + + Return: + dict of results: a dict containing outputs of model, each + output should have the standard output name. + """ return input @classmethod diff --git a/modelscope/models/nlp/sequence_classification_model.py b/modelscope/models/nlp/sequence_classification_model.py index 6ced7a4e..a3cc4b68 100644 --- a/modelscope/models/nlp/sequence_classification_model.py +++ b/modelscope/models/nlp/sequence_classification_model.py @@ -1,5 +1,7 @@ +import os from typing import Any, Dict +import json import numpy as np from modelscope.utils.constant import Tasks @@ -34,6 +36,11 @@ class BertForSequenceClassification(Model): ('token_type_ids', torch.LongTensor)], output_keys=['predictions', 'probabilities', 'logits']) + self.label_path = os.path.join(self.model_dir, 'label_mapping.json') + with open(self.label_path, encoding='utf-8') as f: + self.label_mapping = json.load(f) + self.id2label = {idx: name for name, idx in self.label_mapping.items()} + def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: """return the result by the model @@ -50,3 +57,13 @@ class BertForSequenceClassification(Model): } """ return 
self.model.predict(input) + + def postprocess(self, inputs: Dict[str, np.ndarray], + **kwargs) -> Dict[str, np.ndarray]: + # N x num_classes + probs = inputs['probabilities'] + result = { + 'probs': probs, + } + + return result diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py index c69afdca..1da65213 100644 --- a/modelscope/pipelines/base.py +++ b/modelscope/pipelines/base.py @@ -12,6 +12,7 @@ from modelscope.pydatasets import PyDataset from modelscope.utils.config import Config from modelscope.utils.hub import get_model_cache_dir from modelscope.utils.logger import get_logger +from .outputs import TASK_OUTPUTS from .util import is_model_name Tensor = Union['torch.Tensor', 'tf.Tensor'] @@ -106,8 +107,25 @@ class Pipeline(ABC): out = self.preprocess(input) out = self.forward(out) out = self.postprocess(out, **post_kwargs) + self._check_output(out) return out + def _check_output(self, input): + # this attribute is dynamically attached by registry + # when cls is registered in registry using task name + task_name = getattr(self, 'group_key', None) + if task_name not in TASK_OUTPUTS: + logger.warning(f'task {task_name} output keys are missing') + return + output_keys = TASK_OUTPUTS[task_name] + missing_keys = [] + for k in output_keys: + if k not in input: + missing_keys.append(k) + if len(missing_keys) > 0: + raise ValueError(f'expected output keys are {output_keys}, ' + f'those {missing_keys} are missing') + def preprocess(self, inputs: Input) -> Dict[str, Any]: """ Provide default implementation based on preprocess_cfg and user can reimplement it """ @@ -125,4 +143,14 @@ class Pipeline(ABC): @abstractmethod def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """ If current pipeline supports model reuse, common postprocess + code should be written here. + + Args: + inputs: input data + + Return: + dict of results: a dict containing outputs of model, each + output should have the standard output name. 
+ """ raise NotImplementedError('postprocess') diff --git a/modelscope/pipelines/nlp/sequence_classification_pipeline.py b/modelscope/pipelines/nlp/sequence_classification_pipeline.py index 5a14f136..9d2e4273 100644 --- a/modelscope/pipelines/nlp/sequence_classification_pipeline.py +++ b/modelscope/pipelines/nlp/sequence_classification_pipeline.py @@ -41,50 +41,29 @@ class SequenceClassificationPipeline(Pipeline): second_sequence=None) super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs) - from easynlp.utils import io - self.label_path = os.path.join(sc_model.model_dir, - 'label_mapping.json') - with io.open(self.label_path) as f: - self.label_mapping = json.load(f) - self.label_id_to_name = { - idx: name - for name, idx in self.label_mapping.items() - } + assert hasattr(self.model, 'id2label'), \ + 'id2label map should be initalizaed in init function.' - def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: + def postprocess(self, + inputs: Dict[str, Any], + topk: int = 5) -> Dict[str, str]: """process the prediction results Args: - inputs (Dict[str, Any]): _description_ + inputs (Dict[str, Any]): input data dict + topk (int): return topk classification result. 
Returns: Dict[str, str]: the prediction results """ + # NxC np.ndarray + probs = inputs['probs'][0] + num_classes = probs.shape[0] + topk = min(topk, num_classes) + top_indices = np.argpartition(probs, -topk)[-topk:] + cls_ids = top_indices[np.argsort(-probs[top_indices])] + probs = probs[cls_ids].tolist() - probs = inputs['probabilities'] - logits = inputs['logits'] - predictions = np.argsort(-probs, axis=-1) - preds = predictions[0] - b = 0 - new_result = list() - for pred in preds: - new_result.append({ - 'pred': self.label_id_to_name[pred], - 'prob': float(probs[b][pred]), - 'logit': float(logits[b][pred]) - }) - new_results = list() - new_results.append({ - 'id': - inputs['id'][b] if 'id' in inputs else str(uuid.uuid4()), - 'output': - new_result, - 'predictions': - new_result[0]['pred'], - 'probabilities': - ','.join([str(t) for t in inputs['probabilities'][b]]), - 'logits': - ','.join([str(t) for t in inputs['logits'][b]]) - }) + cls_names = [self.model.id2label[cid] for cid in cls_ids] - return new_results[0] + return {'scores': probs, 'labels': cls_names} diff --git a/modelscope/pipelines/nlp/text_generation_pipeline.py b/modelscope/pipelines/nlp/text_generation_pipeline.py index 7ad2b67f..ea30a115 100644 --- a/modelscope/pipelines/nlp/text_generation_pipeline.py +++ b/modelscope/pipelines/nlp/text_generation_pipeline.py @@ -56,4 +56,4 @@ class TextGenerationPipeline(Pipeline): '').split('[SEP]')[0].replace('[CLS]', '').replace('[SEP]', '').replace('[UNK]', '') - return {'pred_string': pred_string} + return {'text': pred_string} diff --git a/modelscope/pipelines/outputs.py b/modelscope/pipelines/outputs.py new file mode 100644 index 00000000..1389abd3 --- /dev/null +++ b/modelscope/pipelines/outputs.py @@ -0,0 +1,98 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +from modelscope.utils.constant import Tasks + +TASK_OUTPUTS = { + + # ============ vision tasks =================== + + # image classification result for single sample + # { + # "labels": ["dog", "horse", "cow", "cat"], + # "scores": [0.9, 0.1, 0.05, 0.05] + # } + Tasks.image_classification: ['scores', 'labels'], + Tasks.image_tagging: ['scores', 'labels'], + + # object detection result for single sample + # { + # "boxes": [ + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # ], + # "labels": ["dog", "horse", "cow", "cat"], + # "scores": [0.9, 0.1, 0.05, 0.05] + # } + Tasks.object_detection: ['scores', 'labels', 'boxes'], + + # instance segmentation result for single sample + # { + # "masks": [ + # np.array in bgr channel order + # ], + # "labels": ["dog", "horse", "cow", "cat"], + # "scores": [0.9, 0.1, 0.05, 0.05] + # } + Tasks.image_segmentation: ['scores', 'labels', 'masks'], + + # image generation/editing/matting result for single sample + # { + # "output_png": np.array with shape(h, w, 4) + # for matting or (h, w, 3) for general purpose + # } + Tasks.image_editing: ['output_png'], + Tasks.image_matting: ['output_png'], + Tasks.image_generation: ['output_png'], + + # pose estimation result for single sample + # { + # "poses": np.array with shape [num_pose, num_keypoint, 3], + # each keypoint is an array [x, y, score] + # "boxes": np.array with shape [num_pose, 4], each box is + # [x1, y1, x2, y2] + # } + Tasks.pose_estimation: ['poses', 'boxes'], + + # ============ nlp tasks =================== + + # text classification result for single sample + # { + # "labels": ["happy", "sad", "calm", "angry"], + # "scores": [0.9, 0.1, 0.05, 0.05] + # } + Tasks.text_classification: ['scores', 'labels'], + + # text generation result for single sample + # { + # "text": "this is text generated by a model." 
+ # } + Tasks.text_generation: ['text'], + + # ============ audio tasks =================== + + # ============ multi-modal tasks =================== + + # image caption result for single sample + # { + # "caption": "this is an image caption text." + # } + Tasks.image_captioning: ['caption'], + + # visual grounding result for single sample + # { + # "boxes": [ + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # [x1, y1, x2, y2], + # ], + # "scores": [0.9, 0.1, 0.05, 0.05] + # } + Tasks.visual_grounding: ['boxes', 'scores'], + + # text_to_image result for a single sample + # { + # "image": np.ndarray with shape [height, width, 3] + # } + Tasks.text_to_image_synthesis: ['image'] +} diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 2fcfee95..6ce835c5 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -51,7 +51,7 @@ class Tasks(object): text_to_speech = 'text-to-speech' speech_signal_process = 'speech-signal-process' - # multi-media + # multi-modal tasks image_captioning = 'image-captioning' visual_grounding = 'visual-grounding' text_to_image_synthesis = 'text-to-image-synthesis' diff --git a/modelscope/utils/registry.py b/modelscope/utils/registry.py index 888564c7..319e54cb 100644 --- a/modelscope/utils/registry.py +++ b/modelscope/utils/registry.py @@ -69,6 +69,7 @@ class Registry(object): f'{self._name}[{group_key}]') self._modules[group_key][module_name] = module_cls + module_cls.group_key = group_key if module_name in self._modules[default_group]: if id(self._modules[default_group][module_name]) == id(module_cls): diff --git a/tests/pipelines/test_base.py b/tests/pipelines/test_base.py index 14f646a9..73aebfdf 100644 --- a/tests/pipelines/test_base.py +++ b/tests/pipelines/test_base.py @@ -35,9 +35,10 @@ class CustomPipelineTest(unittest.TestCase): CustomPipeline1() def test_custom(self): + dummy_task = 'dummy-task' @PIPELINES.register_module( - group_key=Tasks.image_tagging, module_name='custom-image') + 
group_key=dummy_task, module_name='custom-image') class CustomImagePipeline(Pipeline): def __init__(self, @@ -67,32 +68,29 @@ class CustomPipelineTest(unittest.TestCase): outputs['filename'] = inputs['url'] img = inputs['img'] new_image = img.resize((img.width // 2, img.height // 2)) - outputs['resize_image'] = np.array(new_image) - outputs['dummy_result'] = 'dummy_result' + outputs['output_png'] = np.array(new_image) return outputs def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: return inputs self.assertTrue('custom-image' in PIPELINES.modules[default_group]) - add_default_pipeline_info(Tasks.image_tagging, 'custom-image') + add_default_pipeline_info(dummy_task, 'custom-image', overwrite=True) pipe = pipeline(pipeline_name='custom-image') - pipe2 = pipeline(Tasks.image_tagging) + pipe2 = pipeline(dummy_task) self.assertTrue(type(pipe) is type(pipe2)) img_url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.' \ 'aliyuncs.com/data/test/images/image1.jpg' output = pipe(img_url) self.assertEqual(output['filename'], img_url) - self.assertEqual(output['resize_image'].shape, (318, 512, 3)) - self.assertEqual(output['dummy_result'], 'dummy_result') + self.assertEqual(output['output_png'].shape, (318, 512, 3)) outputs = pipe([img_url for i in range(4)]) self.assertEqual(len(outputs), 4) for out in outputs: self.assertEqual(out['filename'], img_url) - self.assertEqual(out['resize_image'].shape, (318, 512, 3)) - self.assertEqual(out['dummy_result'], 'dummy_result') + self.assertEqual(out['output_png'].shape, (318, 512, 3)) if __name__ == '__main__':