
[to #41401401] add preprocessor, model and pipeline

* add preprocessor module
* add model base and builder
* update task constant
* add load image preprocessor and its dependency
* add pipeline interface and UT covered
* support default pipeline for task
* add image matting pipeline
* refine nlp tokenize interface
* add nlp pipeline
* fix failed UTs
* add test for Compose

Link: https://code.aone.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8769235

* add preprocessor module

* add test for Compose

* fix citest error

* fix abs class error

* add model base and builder

* update task constant

* add load image preprocessor and its dependency

* add pipeline interface and UT covered

* support default pipeline for task

* refine models and pipeline interface

* add pipeline folder structure

* add image matting pipeline

* refine nlp tokenize interface

* add nlp pipeline 

1. add preprocessor, model and pipeline for nlp text classification
2. add corresponding tests

Link: https://code.aone.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8757371

* new nlp pipeline

* format pre-commit code

* update easynlp pipeline

* update model_name for easynlp pipeline; add test for maas_lib/utils/typeassert.py

* update test_typeassert.py

* refactor code

1. rename typeassert to type_assert
2. use lazy import to make easynlp dependency optional
3. refine image matting UT

* fix failed linter test

* update requirements.txt

* fix failed UT

* fix citest script to update requirements
master
wenmeng.zwm, huangjun.hj · 3 years ago
parent commit 5e469008fd
39 changed files with 1053 additions and 10 deletions
  1. .dev_scripts/citest.sh (+1, -1)
  2. maas_lib/fileio/__init__.py (+1, -0)
  3. maas_lib/fileio/file.py (+1, -0)
  4. maas_lib/models/__init__.py (+4, -0)
  5. maas_lib/models/base.py (+29, -0)
  6. maas_lib/models/builder.py (+22, -0)
  7. maas_lib/models/nlp/__init__.py (+1, -0)
  8. maas_lib/models/nlp/sequence_classification_model.py (+62, -0)
  9. maas_lib/pipelines/__init__.py (+6, -0)
  10. maas_lib/pipelines/audio/__file__.py (+0, -0)
  11. maas_lib/pipelines/base.py (+63, -0)
  12. maas_lib/pipelines/builder.py (+65, -0)
  13. maas_lib/pipelines/cv/__init__.py (+1, -0)
  14. maas_lib/pipelines/cv/image_matting.py (+67, -0)
  15. maas_lib/pipelines/multi_modal/__init__.py (+0, -0)
  16. maas_lib/pipelines/nlp/__init__.py (+1, -0)
  17. maas_lib/pipelines/nlp/sequence_classification_pipeline.py (+77, -0)
  18. maas_lib/preprocessors/__init__.py (+7, -0)
  19. maas_lib/preprocessors/base.py (+14, -0)
  20. maas_lib/preprocessors/builder.py (+22, -0)
  21. maas_lib/preprocessors/common.py (+54, -0)
  22. maas_lib/preprocessors/image.py (+70, -0)
  23. maas_lib/preprocessors/nlp.py (+91, -0)
  24. maas_lib/utils/constant.py (+31, -1)
  25. maas_lib/utils/registry.py (+22, -4)
  26. maas_lib/utils/type_assert.py (+50, -0)
  27. requirements.txt (+1, -0)
  28. requirements/pipeline.txt (+5, -0)
  29. requirements/runtime.txt (+3, -0)
  30. setup.cfg (+1, -1)
  31. tests/pipelines/__init__.py (+0, -0)
  32. tests/pipelines/test_base.py (+98, -0)
  33. tests/pipelines/test_image_matting.py (+32, -0)
  34. tests/pipelines/test_text_classification.py (+48, -0)
  35. tests/preprocessors/__init__.py (+0, -0)
  36. tests/preprocessors/test_common.py (+39, -0)
  37. tests/preprocessors/test_nlp.py (+37, -0)
  38. tests/utils/test_registry.py (+5, -3)
  39. tests/utils/test_type_assert.py (+22, -0)

+ 1
- 1
.dev_scripts/citest.sh View File

@@ -1,4 +1,4 @@
pip install -r requirements/runtime.txt
pip install -r requirements.txt
pip install -r requirements/tests.txt




+ 1
- 0
maas_lib/fileio/__init__.py View File

@@ -1 +1,2 @@
from .file import File
from .io import dump, dumps, load

+ 1
- 0
maas_lib/fileio/file.py View File

@@ -123,6 +123,7 @@ class HTTPStorage(Storage):
"""HTTP and HTTPS storage."""

def read(self, url):
# TODO @wenmeng.zwm add progress bar if file is too large
r = requests.get(url)
r.raise_for_status()
return r.content


+ 4
- 0
maas_lib/models/__init__.py View File

@@ -0,0 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from .base import Model
from .builder import MODELS

+ 29
- 0
maas_lib/models/base.py View File

@@ -0,0 +1,29 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Union

Tensor = Union['torch.Tensor', 'tf.Tensor']


class Model(ABC):

def __init__(self, *args, **kwargs):
pass

def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
return self.post_process(self.forward(input))

@abstractmethod
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
pass

def post_process(self, input: Dict[str, Tensor],
**kwargs) -> Dict[str, Tensor]:
# model specific postprocess, implementation is optional
# will be called in Pipeline and evaluation loop(in the future)
return input

@classmethod
def from_pretrained(cls, model_name_or_path: str, *model_args, **kwargs):
raise NotImplementedError('from_pretrained has not been implemented')
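
As a quick orientation, here is a minimal sketch of how a concrete model could subclass this base; `DoubleModel` and its tensor math are illustrative, not part of this commit:

```python
import torch

from maas_lib.models import Model


class DoubleModel(Model):
    """Hypothetical model: forward is the only abstract method to fill in."""

    def forward(self, input):
        # a real model would run a network here; we just double each tensor
        return {name: tensor * 2 for name, tensor in input.items()}


model = DoubleModel()
# __call__ runs forward and then the optional post_process hook
out = model({'x': torch.ones(2)})
```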

+ 22
- 0
maas_lib/models/builder.py View File

@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from maas_lib.utils.config import ConfigDict
from maas_lib.utils.constant import Tasks
from maas_lib.utils.registry import Registry, build_from_cfg

MODELS = Registry('models')


def build_model(cfg: ConfigDict,
task_name: str = None,
default_args: dict = None):
""" build model given model config dict

Args:
cfg (:obj:`ConfigDict`): config dict for model object.
task_name (str, optional): task name, refer to
:obj:`Tasks` for more details
default_args (dict, optional): Default initialization arguments.
"""
return build_from_cfg(
cfg, MODELS, group_key=task_name, default_args=default_args)
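
A hedged sketch of how registration and building are meant to compose; `ToyModel` and its config are illustrative, not from this commit:

```python
from maas_lib.models.builder import MODELS, build_model
from maas_lib.utils.constant import Tasks


@MODELS.register_module(Tasks.text_classification, module_name='toy-model')
class ToyModel:

    def __init__(self, hidden_size=8):
        self.hidden_size = hidden_size


# 'type' selects the registered module; the remaining keys become init kwargs
cfg = dict(type='toy-model', hidden_size=16)
model = build_model(cfg, task_name=Tasks.text_classification)
assert model.hidden_size == 16
```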

+ 1
- 0
maas_lib/models/nlp/__init__.py View File

@@ -0,0 +1 @@
from .sequence_classification_model import * # noqa F403

+ 62
- 0
maas_lib/models/nlp/sequence_classification_model.py View File

@@ -0,0 +1,62 @@
from typing import Any, Dict, Optional, Union

import numpy as np
import torch

from maas_lib.utils.constant import Tasks
from ..base import Model
from ..builder import MODELS

__all__ = ['SequenceClassificationModel']


@MODELS.register_module(
Tasks.text_classification, module_name=r'bert-sentiment-analysis')
class SequenceClassificationModel(Model):

def __init__(self,
model_dir: str,
model_cls: Optional[Any] = None,
*args,
**kwargs):
"""initialize the sequence classification model from the `model_dir` path

Args:
model_dir (str): the model path
model_cls (Optional[Any], optional): model loader, if None, use the
default loader to load model weights, by default None
"""

super().__init__(model_dir, model_cls, *args, **kwargs)

from easynlp.appzoo import SequenceClassification
from easynlp.core.predictor import get_model_predictor
self.model_dir = model_dir
model_cls = SequenceClassification if not model_cls else model_cls
self.model = get_model_predictor(
model_dir=model_dir,
model_cls=model_cls,
input_keys=[('input_ids', torch.LongTensor),
('attention_mask', torch.LongTensor),
('token_type_ids', torch.LongTensor)],
output_keys=['predictions', 'probabilities', 'logits'])

def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
"""return the result by the model

Args:
input (Dict[str, Any]): the preprocessed data

Returns:
Dict[str, np.ndarray]: results
Example:
{
'predictions': array([1]), # label: 0-negative, 1-positive
'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
}
"""
return self.model.predict(input)
...

+ 6
- 0
maas_lib/pipelines/__init__.py View File

@@ -0,0 +1,6 @@
from .audio import * # noqa F403
from .base import Pipeline
from .builder import pipeline
from .cv import * # noqa F403
from .multi_modal import * # noqa F403
from .nlp import * # noqa F403

+ 0
- 0
maas_lib/pipelines/audio/__file__.py View File


+ 63
- 0
maas_lib/pipelines/base.py View File

@@ -0,0 +1,63 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple, Union

from maas_lib.models import Model
from maas_lib.preprocessors import Preprocessor

Tensor = Union['torch.Tensor', 'tf.Tensor']
Input = Union[str, 'PIL.Image.Image', 'numpy.ndarray']

output_keys = [
] # for each task's pipeline, define standardized output keys that feed into postprocess and also standardize the keys of its output


class Pipeline(ABC):

def __init__(self,
config_file: str = None,
model: Model = None,
preprocessor: Preprocessor = None,
**kwargs):
self.model = model
self.preprocessor = preprocessor

def __call__(self, input: Union[Input, List[Input]], *args,
**post_kwargs) -> Dict[str, Any]:
# model provider should leave it as it is
# maas library developer will handle this function

# simple showcase; iterator inputs still need to be supported for both tensorflow and pytorch
# input_dict = self._handle_input(input)
if isinstance(input, list):
output = []
for ele in input:
output.append(self._process_single(ele, *args, **post_kwargs))
else:
output = self._process_single(input, *args, **post_kwargs)
return output

def _process_single(self, input: Input, *args,
**post_kwargs) -> Dict[str, Any]:
out = self.preprocess(input)
out = self.forward(out)
out = self.postprocess(out, **post_kwargs)
return out

def preprocess(self, inputs: Input) -> Dict[str, Any]:
""" Provide default implementation based on preprocess_cfg and user can reimplement it

"""
assert self.preprocessor is not None, 'preprocess method should be implemented'
return self.preprocessor(inputs)

def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
""" Provide default implementation using self.model and user can reimplement it
"""
assert self.model is not None, 'forward method should be implemented'
return self.model(inputs)

@abstractmethod
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
raise NotImplementedError('postprocess')

+ 65
- 0
maas_lib/pipelines/builder.py View File

@@ -0,0 +1,65 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Union

from maas_lib.models.base import Model
from maas_lib.utils.config import ConfigDict
from maas_lib.utils.constant import Tasks
from maas_lib.utils.registry import Registry, build_from_cfg
from .base import Pipeline

PIPELINES = Registry('pipelines')


def build_pipeline(cfg: ConfigDict,
task_name: str = None,
default_args: dict = None):
""" build pipeline given model config dict

Args:
cfg (:obj:`ConfigDict`): config dict for model object.
task_name (str, optional): task name, refer to
:obj:`Tasks` for more details
default_args (dict, optional): Default initialization arguments.
"""
return build_from_cfg(
cfg, PIPELINES, group_key=task_name, default_args=default_args)


def pipeline(task: str = None,
model: Union[str, Model] = None,
config_file: str = None,
pipeline_name: str = None,
framework: str = None,
device: int = -1,
**kwargs) -> Pipeline:
""" Factory method to build a obj:`Pipeline`.


Args:
task (str): Task name defining which pipeline will be returned.
model (str or obj:`Model`): model name or model object.
config_file (str, optional): path to config file.
pipeline_name (str, optional): pipeline class name or alias name.
framework (str, optional): framework type.
device (int, optional): which device is used to do inference.

Return:
pipeline (obj:`Pipeline`): pipeline object for certain task.

Examples:
```python
>>> p = pipeline('image-classification')
>>> p = pipeline('text-classification', model='distilbert-base-uncased')
>>> # Using model object
>>> resnet = Model.from_pretrained('Resnet')
>>> p = pipeline('image-classification', model=resnet)
```
"""
if task is not None and model is None and pipeline_name is None:
# get default pipeline for this task
assert task in PIPELINES.modules, f'No pipeline is registered for Task {task}'
pipeline_name = list(PIPELINES.modules[task].keys())[0]

if pipeline_name is not None:
cfg = dict(type=pipeline_name, **kwargs)
return build_pipeline(cfg, task_name=task)

+ 1
- 0
maas_lib/pipelines/cv/__init__.py View File

@@ -0,0 +1 @@
from .image_matting import ImageMatting

+ 67
- 0
maas_lib/pipelines/cv/image_matting.py View File

@@ -0,0 +1,67 @@
from typing import Any, Dict, List, Tuple, Union

import cv2
import numpy as np
import PIL
import tensorflow as tf

from maas_lib.pipelines.base import Input
from maas_lib.preprocessors import load_image
from maas_lib.utils.constant import Tasks
from maas_lib.utils.logger import get_logger
from ..base import Pipeline
from ..builder import PIPELINES

if tf.__version__ >= '2.0':
tf = tf.compat.v1

logger = get_logger()


@PIPELINES.register_module(
Tasks.image_matting, module_name=Tasks.image_matting)
class ImageMatting(Pipeline):

def __init__(self, model_path: str):
super().__init__()

config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True
self._session = tf.Session(config=config)
with self._session.as_default():
logger.info(f'loading model from {model_path}')
with tf.gfile.FastGFile(model_path, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
self.output = self._session.graph.get_tensor_by_name(
'output_png:0')
self.input_name = 'input_image:0'
logger.info('load model done')

def preprocess(self, input: Input) -> Dict[str, Any]:
if isinstance(input, str):
img = np.array(load_image(input))
elif isinstance(input, PIL.Image.Image):
img = np.array(input.convert('RGB'))
elif isinstance(input, np.ndarray):
if len(input.shape) == 2:
input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
img = input[:, :, ::-1] # in rgb order
else:
raise TypeError(f'input should be either str, PIL.Image,'
f' np.array, but got {type(input)}')
img = img.astype(np.float)
result = {'img': img}
return result

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
with self._session.as_default():
feed_dict = {self.input_name: input['img']}
output_png = self._session.run(self.output, feed_dict=feed_dict)
output_png = cv2.cvtColor(output_png, cv2.COLOR_RGBA2BGRA)
return {'output_png': output_png}

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs

+ 0
- 0
maas_lib/pipelines/multi_modal/__init__.py View File


+ 1
- 0
maas_lib/pipelines/nlp/__init__.py View File

@@ -0,0 +1 @@
from .sequence_classification_pipeline import * # noqa F403

+ 77
- 0
maas_lib/pipelines/nlp/sequence_classification_pipeline.py View File

@@ -0,0 +1,77 @@
import os
import uuid
from typing import Any, Dict

import json
import numpy as np

from maas_lib.models.nlp import SequenceClassificationModel
from maas_lib.preprocessors import SequenceClassificationPreprocessor
from maas_lib.utils.constant import Tasks
from ..base import Input, Pipeline
from ..builder import PIPELINES

__all__ = ['SequenceClassificationPipeline']


@PIPELINES.register_module(
Tasks.text_classification, module_name=r'bert-sentiment-analysis')
class SequenceClassificationPipeline(Pipeline):

def __init__(self, model: SequenceClassificationModel,
preprocessor: SequenceClassificationPreprocessor, **kwargs):
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction

Args:
model (SequenceClassificationModel): a model instance
preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
"""

super().__init__(model=model, preprocessor=preprocessor, **kwargs)

from easynlp.utils import io
self.label_path = os.path.join(model.model_dir, 'label_mapping.json')
with io.open(self.label_path) as f:
self.label_mapping = json.load(f)
self.label_id_to_name = {
idx: name
for name, idx in self.label_mapping.items()
}

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""process the predict results

Args:
inputs (Dict[str, Any]): the model output, containing probabilities and logits

Returns:
Dict[str, str]: the predict results
"""

probs = inputs['probabilities']
logits = inputs['logits']
predictions = np.argsort(-probs, axis=-1)
preds = predictions[0]
b = 0
new_result = list()
for pred in preds:
new_result.append({
'pred': self.label_id_to_name[pred],
'prob': float(probs[b][pred]),
'logit': float(logits[b][pred])
})
new_results = list()
new_results.append({
'id':
inputs['id'][b] if 'id' in inputs else str(uuid.uuid4()),
'output':
new_result,
'predictions':
new_result[0]['pred'],
'probabilities':
','.join([str(t) for t in inputs['probabilities'][b]]),
'logits':
','.join([str(t) for t in inputs['logits'][b]])
})

return new_results[0]

+ 7
- 0
maas_lib/preprocessors/__init__.py View File

@@ -0,0 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from .base import Preprocessor
from .builder import PREPROCESSORS, build_preprocessor
from .common import Compose
from .image import LoadImage, load_image
from .nlp import * # noqa F403

+ 14
- 0
maas_lib/preprocessors/base.py View File

@@ -0,0 +1,14 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from abc import ABC, abstractmethod
from typing import Any, Dict


class Preprocessor(ABC):

def __init__(self, *args, **kwargs):
pass

@abstractmethod
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
pass

+ 22
- 0
maas_lib/preprocessors/builder.py View File

@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from maas_lib.utils.config import ConfigDict
from maas_lib.utils.constant import Fields
from maas_lib.utils.registry import Registry, build_from_cfg

PREPROCESSORS = Registry('preprocessors')


def build_preprocessor(cfg: ConfigDict,
field_name: str = None,
default_args: dict = None):
""" build preprocesor given model config dict

Args:
cfg (:obj:`ConfigDict`): config dict for model object.
field_name (str, optional): application field name, refer to
:obj:`Fields` for more details
default_args (dict, optional): Default initialization arguments.
"""
return build_from_cfg(
cfg, PREPROCESSORS, group_key=field_name, default_args=default_args)

+ 54
- 0
maas_lib/preprocessors/common.py View File

@@ -0,0 +1,54 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import time
from collections.abc import Sequence

from .builder import PREPROCESSORS, build_preprocessor


@PREPROCESSORS.register_module()
class Compose(object):
"""Compose a data pipeline with a sequence of transforms.
Args:
transforms (list[dict | callable]):
Either config dicts of transforms or transform objects.
field_name (str, optional): registry group key used to build
transforms from config dicts.
profiling (bool, optional): If set to True, profile and
print the preprocess time of each step.
"""

def __init__(self, transforms, field_name=None, profiling=False):
assert isinstance(transforms, Sequence)
self.profiling = profiling
self.transforms = []
self.field_name = field_name
for transform in transforms:
if isinstance(transform, dict):
transform = build_preprocessor(transform, self.field_name)
self.transforms.append(transform)
elif callable(transform):
self.transforms.append(transform)
else:
raise TypeError('transform must be callable or a dict, but got'
f' {type(transform)}')

def __call__(self, data):
for t in self.transforms:
if self.profiling:
start = time.time()

data = t(data)

if self.profiling:
print(f'{t} time {time.time()-start}')

if data is None:
return None
return data

def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += f'\n {t}'
format_string += '\n)'
return format_string

+ 70
- 0
maas_lib/preprocessors/image.py View File

@@ -0,0 +1,70 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import io
from typing import Dict, Union

from PIL import Image, ImageOps

from maas_lib.fileio import File
from maas_lib.utils.constant import Fields
from .builder import PREPROCESSORS


@PREPROCESSORS.register_module(Fields.image)
class LoadImage:
"""Load an image from file or url.
Added or updated keys are "filename", "img", "img_shape" and "img_field".
Args:
mode (str): See :ref:`PIL.Mode<https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes>`.
Defaults to 'rgb'.
"""

def __init__(self, mode='rgb'):
self.mode = mode.upper()

def __call__(self, input: Union[str, Dict[str, str]]):
"""Call functions to load image and get image meta information.
Args:
input (str or dict): input image path or input dict with
a key `filename`.
Returns:
dict: The dict contains loaded image.
"""
if isinstance(input, dict):
image_path_or_url = input['filename']
else:
image_path_or_url = input

bytes = File.read(image_path_or_url)
# TODO @wenmeng.zwm add opencv decode as optional
# we should also look at the input format which is the most commonly
# used in Mind' image related models
with io.BytesIO(bytes) as infile:
img = Image.open(infile)
img = ImageOps.exif_transpose(img)
img = img.convert(self.mode)

results = {
'filename': image_path_or_url,
'img': img,
'img_shape': (img.size[1], img.size[0], 3),
'img_field': 'img',
}
return results

def __repr__(self):
repr_str = (f'{self.__class__.__name__}(' f'mode={self.mode})')
return repr_str


def load_image(image_path_or_url: str) -> Image.Image:
""" simple interface to load an image from file or url

Args:
image_path_or_url (str): image file path or http url
"""
loader = LoadImage()
return loader(image_path_or_url)['img']
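
Usage is a one-liner; the URL below is a placeholder, not from this commit:

```python
from maas_lib.preprocessors import load_image

# accepts a local file path or an http(s) url, returns a PIL image in RGB mode
img = load_image('https://example.com/sample.jpg')
print(img.size)
```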

+ 91
- 0
maas_lib/preprocessors/nlp.py View File

@@ -0,0 +1,91 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import uuid
from typing import Any, Dict, Union

from transformers import AutoTokenizer

from maas_lib.utils.constant import Fields, InputFields
from maas_lib.utils.type_assert import type_assert
from .base import Preprocessor
from .builder import PREPROCESSORS

__all__ = ['Tokenize', 'SequenceClassificationPreprocessor']


@PREPROCESSORS.register_module(Fields.nlp)
class Tokenize(Preprocessor):

def __init__(self, tokenizer_name) -> None:
self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

def __call__(self, data: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
if isinstance(data, str):
data = {InputFields.text: data}
token_dict = self._tokenizer(data[InputFields.text])
data.update(token_dict)
return data


@PREPROCESSORS.register_module(
Fields.nlp, module_name=r'bert-sentiment-analysis')
class SequenceClassificationPreprocessor(Preprocessor):

def __init__(self, model_dir: str, *args, **kwargs):
"""preprocess the data via the vocab.txt from the `model_dir` path

Args:
model_dir (str): model path
"""

super().__init__(*args, **kwargs)

from easynlp.modelzoo import AutoTokenizer
self.model_dir: str = model_dir
self.first_sequence: str = kwargs.pop('first_sequence',
'first_sequence')
self.second_sequence = kwargs.pop('second_sequence', 'second_sequence')
self.sequence_length = kwargs.pop('sequence_length', 128)

self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)

@type_assert(object, str)
def __call__(self, data: str) -> Dict[str, Any]:
"""process the raw input data

Args:
data (str): a sentence
Example:
'you are so handsome.'

Returns:
Dict[str, Any]: the preprocessed data
"""

new_data = {self.first_sequence: data}
# preprocess the data for the model input

rst = {
'id': [],
'input_ids': [],
'attention_mask': [],
'token_type_ids': []
}

max_seq_length = self.sequence_length

text_a = new_data[self.first_sequence]
text_b = new_data.get(self.second_sequence, None)
feature = self.tokenizer(
text_a,
text_b,
padding='max_length',
truncation=True,
max_length=max_seq_length)

rst['id'].append(new_data.get('id', str(uuid.uuid4())))
rst['input_ids'].append(feature['input_ids'])
rst['attention_mask'].append(feature['attention_mask'])
rst['token_type_ids'].append(feature['token_type_ids'])

return rst
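
A hedged usage sketch (the model directory is a placeholder; easynlp must be installed):

```python
from maas_lib.preprocessors import SequenceClassificationPreprocessor

preprocessor = SequenceClassificationPreprocessor(
    '/path/to/bert-base-sst2',  # placeholder for an extracted model dir
    first_sequence='sentence',
    second_sequence=None)
features = preprocessor('you are so handsome.')
# features holds singleton batches: id, input_ids, attention_mask, token_type_ids
```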

+ 31
- 1
maas_lib/utils/constant.py View File

@@ -6,6 +6,7 @@ class Fields(object):
"""
image = 'image'
video = 'video'
cv = 'cv'
nlp = 'nlp'
audio = 'audio'
multi_modal = 'multi_modal'
@@ -18,12 +19,41 @@ class Tasks(object):
This should be used to register models, pipelines, trainers.
"""
# vision tasks
image_to_text = 'image-to-text'
pose_estimation = 'pose-estimation'
image_classification = 'image-classification'
image_tagging = 'image-tagging'
object_detection = 'object-detection'
image_segmentation = 'image-segmentation'
image_editing = 'image-editing'
image_generation = 'image-generation'
image_matting = 'image-matting'

# nlp tasks
sentiment_analysis = 'sentiment-analysis'
fill_mask = 'fill-mask'
text_classification = 'text-classification'
relation_extraction = 'relation-extraction'
zero_shot = 'zero-shot'
translation = 'translation'
token_classification = 'token-classification'
conversational = 'conversational'
text_generation = 'text-generation'
table_question_answering = 'table-question-answering'
feature_extraction = 'feature-extraction'
sentence_similarity = 'sentence-similarity'
summarization = 'summarization'
question_answering = 'question-answering'

# audio tasks
auto_speech_recognition = 'auto-speech-recognition'
text_to_speech = 'text-to-speech'
speech_signal_process = 'speech-signal-process'

# multi-media
image_captioning = 'image-captioning'
visual_grounding = 'visual-grounding'
text_to_image_synthesis = 'text-to-image-synthesis'


class InputFields(object):


+ 22
- 4
maas_lib/utils/registry.py View File

@@ -1,5 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import inspect

from maas_lib.utils.logger import get_logger

@@ -15,10 +17,10 @@ class Registry(object):

def __init__(self, name: str):
self._name = name
self._modules = dict()
self._modules = {default_group: {}}

def __repr__(self):
format_str = self.__class__.__name__ + f'({self._name})\n'
format_str = self.__class__.__name__ + f' ({self._name})\n'
for group_name, group in self._modules.items():
format_str += f'group_name={group_name}, '\
f'modules={list(group.keys())}\n'
@@ -64,11 +66,24 @@ class Registry(object):
module_name = module_cls.__name__

if module_name in self._modules[group_key]:
raise KeyError(f'{module_name} is already registered in'
raise KeyError(f'{module_name} is already registered in '
f'{self._name}[{group_key}]')

self._modules[group_key][module_name] = module_cls

if module_name in self._modules[default_group]:
if id(self._modules[default_group][module_name]) == id(module_cls):
return
else:
logger.warning(f'{module_name} is already registered in '
f'{self._name}[{default_group}] and will '
'be overwritten')
logger.warning(f'{self._modules[default_group][module_name]}'
f' to {module_cls}')
# also register module in the default group for faster access
# only by module name
self._modules[default_group][module_name] = module_cls

def register_module(self,
group_key: str = default_group,
module_name: str = None,
@@ -165,12 +180,15 @@ def build_from_cfg(cfg,
for name, value in default_args.items():
args.setdefault(name, value)

if group_key is None:
group_key = default_group

obj_type = args.pop('type')
if isinstance(obj_type, str):
obj_cls = registry.get(obj_type, group_key=group_key)
if obj_cls is None:
raise KeyError(f'{obj_type} is not in the {registry.name}'
f'registry group {group_key}')
f' registry group {group_key}')
elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
obj_cls = obj_type
else:
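
The net effect of the default-group bookkeeping, as a small sketch (names here are illustrative):

```python
from maas_lib.utils.registry import Registry, default_group

PIPELINES = Registry('pipelines')


@PIPELINES.register_module(group_key='image-tagging', module_name='toy-pipeline')
class ToyPipeline:
    pass


# the module is reachable both via its task group and via the default group
assert PIPELINES.get('toy-pipeline', 'image-tagging') is ToyPipeline
assert PIPELINES.get('toy-pipeline', default_group) is ToyPipeline
```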


+ 50
- 0
maas_lib/utils/type_assert.py View File

@@ -0,0 +1,50 @@
from functools import wraps
from inspect import signature


def type_assert(*ty_args, **ty_kwargs):
"""a decorator which is used to check the types of arguments in a function or class
Examples:
>>> @type_assert(str)
... def main(a: str, b: list):
... print(a, b)
>>> main(1)
Argument a must be a str

>>> @type_assert(str, (int, str))
... def main(a: str, b: int | str):
... print(a, b)
>>> main('1', [1])
Argument b must be (<class 'int'>, <class 'str'>)

>>> @type_assert(str, (int, str))
... class A:
... def __init__(self, a: str, b: int | str):
... print(a, b)
>>> a = A('1', [1])
Argument b must be (<class 'int'>, <class 'str'>)
"""

def decorate(func):
# If in optimized mode, disable type checking
if not __debug__:
return func

# Map function argument names to supplied types
sig = signature(func)
bound_types = sig.bind_partial(*ty_args, **ty_kwargs).arguments

@wraps(func)
def wrapper(*args, **kwargs):
bound_values = sig.bind(*args, **kwargs)
# Enforce type assertions across supplied arguments
for name, value in bound_values.arguments.items():
if name in bound_types:
if not isinstance(value, bound_types[name]):
raise TypeError('Argument {} must be {}'.format(
name, bound_types[name]))
return func(*args, **kwargs)

return wrapper

return decorate

+ 1
- 0
requirements.txt View File

@@ -1 +1,2 @@
-r requirements/runtime.txt
-r requirements/pipeline.txt

+ 5
- 0
requirements/pipeline.txt View File

@@ -0,0 +1,5 @@
http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/package/whl/easynlp-0.0.3-py2.py3-none-any.whl
tensorflow
torch==1.9.1
torchaudio==0.9.1
torchvision==0.10.1

+ 3
- 0
requirements/runtime.txt View File

@@ -1,5 +1,8 @@
addict
numpy
opencv-python-headless
Pillow
pyyaml
requests
transformers
yapf

+ 1
- 1
setup.cfg View File

@@ -20,5 +20,5 @@ ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids
[flake8]
select = B,C,E,F,P,T4,W,B9
max-line-length = 120
ignore = F401
ignore = F401,F821
exclude = docs/src,*.pyi,.git

+ 0
- 0
tests/pipelines/__init__.py View File


+ 98
- 0
tests/pipelines/test_base.py View File

@@ -0,0 +1,98 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import unittest
from typing import Any, Dict, List, Tuple, Union

import numpy as np
import PIL

from maas_lib.pipelines import Pipeline, pipeline
from maas_lib.pipelines.builder import PIPELINES
from maas_lib.utils.constant import Tasks
from maas_lib.utils.logger import get_logger
from maas_lib.utils.registry import default_group

logger = get_logger()

Input = Union[str, 'PIL.Image', 'numpy.ndarray']


class CustomPipelineTest(unittest.TestCase):

def test_abstract(self):

@PIPELINES.register_module()
class CustomPipeline1(Pipeline):

def __init__(self,
config_file: str = None,
model=None,
preprocessor=None,
**kwargs):
super().__init__(config_file, model, preprocessor, **kwargs)

with self.assertRaises(TypeError):
CustomPipeline1()

def test_custom(self):

@PIPELINES.register_module(
group_key=Tasks.image_tagging, module_name='custom-image')
class CustomImagePipeline(Pipeline):

def __init__(self,
config_file: str = None,
model=None,
preprocessor=None,
**kwargs):
super().__init__(config_file, model, preprocessor, **kwargs)

def preprocess(self, input: Union[str,
'PIL.Image']) -> Dict[str, Any]:
""" Provide default implementation based on preprocess_cfg and user can reimplement it

"""
if not isinstance(input, PIL.Image.Image):
from maas_lib.preprocessors import load_image
data_dict = {'img': load_image(input), 'url': input}
else:
data_dict = {'img': input}
return data_dict

def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
""" Provide default implementation using self.model and user can reimplement it
"""
outputs = {}
if 'url' in inputs:
outputs['filename'] = inputs['url']
img = inputs['img']
new_image = img.resize((img.width // 2, img.height // 2))
outputs['resize_image'] = np.array(new_image)
outputs['dummy_result'] = 'dummy_result'
return outputs

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs

self.assertTrue('custom-image' in PIPELINES.modules[default_group])
pipe = pipeline(pipeline_name='custom-image')
pipe2 = pipeline(Tasks.image_tagging)
self.assertTrue(type(pipe) is type(pipe2))

img_url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.' \
'aliyuncs.com/data/test/images/image1.jpg'
output = pipe(img_url)
self.assertEqual(output['filename'], img_url)
self.assertEqual(output['resize_image'].shape, (318, 512, 3))
self.assertEqual(output['dummy_result'], 'dummy_result')

outputs = pipe([img_url for i in range(4)])
self.assertEqual(len(outputs), 4)
for out in outputs:
self.assertEqual(out['filename'], img_url)
self.assertEqual(out['resize_image'].shape, (318, 512, 3))
self.assertEqual(out['dummy_result'], 'dummy_result')


if __name__ == '__main__':
unittest.main()

+ 32
- 0
tests/pipelines/test_image_matting.py View File

@@ -0,0 +1,32 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import tempfile
import unittest
from typing import Any, Dict, List, Tuple, Union

import cv2
import numpy as np
import PIL

from maas_lib.fileio import File
from maas_lib.pipelines import pipeline
from maas_lib.utils.constant import Tasks


class ImageMattingTest(unittest.TestCase):

def test_run(self):
model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \
'.com/data/test/maas/image_matting/matting_person.pb'
with tempfile.NamedTemporaryFile('wb', suffix='.pb') as ofile:
ofile.write(File.read(model_path))
img_matting = pipeline(Tasks.image_matting, model_path=ofile.name)

result = img_matting(
'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'
)
cv2.imwrite('result.png', result['output_png'])


if __name__ == '__main__':
unittest.main()

+ 48
- 0
tests/pipelines/test_text_classification.py View File

@@ -0,0 +1,48 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import tempfile
import unittest
import zipfile

from maas_lib.fileio import File
from maas_lib.models.nlp import SequenceClassificationModel
from maas_lib.pipelines import SequenceClassificationPipeline
from maas_lib.preprocessors import SequenceClassificationPreprocessor


class SequenceClassificationTest(unittest.TestCase):

def predict(self, pipeline: SequenceClassificationPipeline):
from easynlp.appzoo import load_dataset

dataset = load_dataset('glue', 'sst2')
data = dataset['test']['sentence'][:3]

results = pipeline(data[0])
print(results)
results = pipeline(data[1])
print(results)

print(data)

def test_run(self):
model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \
'/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip'
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_file = osp.join(tmp_dir, 'bert-base-sst2.zip')
with open(tmp_file, 'wb') as ofile:
ofile.write(File.read(model_url))
with zipfile.ZipFile(tmp_file, 'r') as zipf:
zipf.extractall(tmp_dir)
path = osp.join(tmp_dir, 'bert-base-sst2')
print(path)
model = SequenceClassificationModel(path)
preprocessor = SequenceClassificationPreprocessor(
path, first_sequence='sentence', second_sequence=None)
pipeline = SequenceClassificationPipeline(model, preprocessor)
self.predict(pipeline)


if __name__ == '__main__':
unittest.main()

+ 0
- 0
tests/preprocessors/__init__.py View File


+ 39
- 0
tests/preprocessors/test_common.py View File

@@ -0,0 +1,39 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import unittest

from maas_lib.preprocessors import PREPROCESSORS, Compose, Preprocessor


class ComposeTest(unittest.TestCase):

def test_compose(self):

@PREPROCESSORS.register_module()
class Tmp1(Preprocessor):

def __call__(self, input):
input['tmp1'] = 'tmp1'
return input

@PREPROCESSORS.register_module()
class Tmp2(Preprocessor):

def __call__(self, input):
input['tmp2'] = 'tmp2'
return input

pipeline = [
dict(type='Tmp1'),
dict(type='Tmp2'),
]
trans = Compose(pipeline)

input = {}
output = trans(input)
self.assertEqual(output['tmp1'], 'tmp1')
self.assertEqual(output['tmp2'], 'tmp2')


if __name__ == '__main__':
unittest.main()

+ 37
- 0
tests/preprocessors/test_nlp.py View File

@@ -0,0 +1,37 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import unittest

from maas_lib.preprocessors import build_preprocessor
from maas_lib.utils.constant import Fields, InputFields
from maas_lib.utils.logger import get_logger

logger = get_logger()


class NLPPreprocessorTest(unittest.TestCase):

def test_tokenize(self):
cfg = dict(type='Tokenize', tokenizer_name='bert-base-cased')
preprocessor = build_preprocessor(cfg, Fields.nlp)
input = {
InputFields.text:
'Do not meddle in the affairs of wizards, '
'for they are subtle and quick to anger.'
}
output = preprocessor(input)
self.assertTrue(InputFields.text in output)
self.assertEqual(output['input_ids'], [
101, 2091, 1136, 1143, 13002, 1107, 1103, 5707, 1104, 16678, 1116,
117, 1111, 1152, 1132, 11515, 1105, 3613, 1106, 4470, 119, 102
])
self.assertEqual(
output['token_type_ids'],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
self.assertEqual(
output['attention_mask'],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])


if __name__ == '__main__':
unittest.main()

+ 5
- 3
tests/utils/test_registry.py View File

@@ -10,8 +10,10 @@ class RegistryTest(unittest.TestCase):
def test_register_class_no_task(self):
MODELS = Registry('models')
self.assertTrue(MODELS.name == 'models')
self.assertTrue(MODELS.modules == {})
self.assertEqual(len(MODELS.modules), 0)
self.assertTrue(default_group in MODELS.modules)
self.assertTrue(MODELS.modules[default_group] == {})

self.assertEqual(len(MODELS.modules), 1)

@MODELS.register_module(module_name='cls-resnet')
class ResNetForCls(object):
@@ -47,7 +49,7 @@ class RegistryTest(unittest.TestCase):
self.assertTrue(Tasks.object_detection in MODELS.modules)
self.assertTrue(MODELS.get('DETR', Tasks.object_detection) is DETR)

self.assertEqual(len(MODELS.modules), 3)
self.assertEqual(len(MODELS.modules), 4)

def test_list(self):
MODELS = Registry('models')


+ 22
- 0
tests/utils/test_type_assert.py View File

@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import unittest
from typing import List, Union

from maas_lib.utils.type_assert import type_assert


class type_assertTest(unittest.TestCase):

@type_assert(object, list, (int, str))
def a(self, a: List[int], b: Union[int, str]):
print(a, b)

def test_type_assert(self):
with self.assertRaises(TypeError):
self.a([1], 2)
self.a(1, [123])


if __name__ == '__main__':
unittest.main()
