* add preprocessor module * add model base and builder * update task constant * add load image preprocessor and its dependency * add pipeline interface and UT covered * support default pipeline for task * add image matting pipeline * refine nlp tokenize interface * add nlp pipeline * fix UT failed * add test for Compose Link: https://code.aone.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8769235 * add preprocessor module * add test for Compose * fix citest error * fix abs class error * add model base and builder * update task constant * add load image preprocessor and its dependency * add pipeline interface and UT covered * support default pipeline for task * refine models and pipeline interface * add pipeline folder structure * add image matting pipeline * refine nlp tokenize interface * add nlp pipeline 1. add preprocessor model pipeline for nlp text classification 2. add corresponding test Link: https://code.aone.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8757371 * new nlp pipeline * format pre-commit code * update easynlp pipeline * update model_name for easynlp pipeline; add test for maas_lib/utils/typeassert.py * update test_typeassert.py * refactor code 1. rename typeassert to type_assert 2. use lazy import to make easynlp dependency optional 3. refine image matting UT * fix linter test failed * update requirements.txt * fix UT failed * fix citest script to update requirements master
@@ -1,4 +1,4 @@ | |||||
pip install -r requirements/runtime.txt | |||||
pip install -r requirements.txt | |||||
pip install -r requirements/tests.txt | pip install -r requirements/tests.txt | ||||
@@ -1 +1,2 @@ | |||||
from .file import File | |||||
from .io import dump, dumps, load | from .io import dump, dumps, load |
@@ -123,6 +123,7 @@ class HTTPStorage(Storage): | |||||
"""HTTP and HTTPS storage.""" | """HTTP and HTTPS storage.""" | ||||
def read(self, url):
    """Download ``url`` and return the raw response body.

    Args:
        url: HTTP or HTTPS address to fetch.

    Returns:
        The ``content`` attribute of the response.

    Raises:
        Whatever ``raise_for_status`` raises for an unsuccessful response.
    """
    # TODO @wenmeng.zwm add progress bar if file is too large
    response = requests.get(url)
    # Surface HTTP error statuses instead of returning an error page body.
    response.raise_for_status()
    return response.content
@@ -0,0 +1,4 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
from .base import Model | |||||
from .builder import MODELS |
@@ -0,0 +1,29 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
from abc import ABC, abstractmethod | |||||
from typing import Dict, List, Tuple, Union | |||||
# Forward references keep both frameworks optional at import time.
Tensor = Union['torch.Tensor', 'tf.Tensor']


class Model(ABC):
    """Abstract base class for models.

    Subclasses implement :meth:`forward`; :meth:`post_process` is an optional
    hook. Calling an instance runs ``post_process(forward(input))``.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore arbitrary arguments so subclasses can forward theirs."""
        pass

    def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Run :meth:`forward` and pass its result through :meth:`post_process`.

        Args:
            input: mapping handed unchanged to :meth:`forward`.

        Returns:
            Whatever :meth:`post_process` returns.
        """
        return self.post_process(self.forward(input))

    @abstractmethod
    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Compute the model output for ``input``; must be overridden."""
        pass

    def post_process(self, input: Dict[str, Tensor],
                     **kwargs) -> Dict[str, Tensor]:
        """Model specific post-processing; the default returns ``input`` as is.

        Implementation is optional. Intended to be called in Pipeline and,
        in the future, in the evaluation loop.
        """
        return input

    @classmethod
    def from_pretrained(cls, model_name_or_path: str, *model_args, **kwargs):
        """Alternate constructor placeholder.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        # The message previously misspelled the method name.
        raise NotImplementedError('from_pretrained has not been implemented')
@@ -0,0 +1,22 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
from maas_lib.utils.config import ConfigDict | |||||
from maas_lib.utils.constant import Tasks | |||||
from maas_lib.utils.registry import Registry, build_from_cfg | |||||
MODELS = Registry('models') | |||||
def build_model(cfg: ConfigDict,
                task_name: str = None,
                default_args: dict = None):
    """Build a model object from a config dict via the ``MODELS`` registry.

    Args:
        cfg (:obj:`ConfigDict`): config dict for model object.
        task_name (str, optional): task name, used as the registry group key;
            refer to :obj:`Tasks` for more details.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        The result of ``build_from_cfg`` for the given config.
    """
    lookup = dict(group_key=task_name, default_args=default_args)
    return build_from_cfg(cfg, MODELS, **lookup)
@@ -0,0 +1 @@ | |||||
from .sequence_classification_model import * # noqa F403 |
@@ -0,0 +1,62 @@ | |||||
from typing import Any, Dict, Optional, Union | |||||
import numpy as np | |||||
import torch | |||||
from maas_lib.utils.constant import Tasks | |||||
from ..base import Model | |||||
from ..builder import MODELS | |||||
__all__ = ['SequenceClassificationModel'] | |||||
@MODELS.register_module(
    Tasks.text_classification, module_name=r'bert-sentiment-analysis')
class SequenceClassificationModel(Model):
    """Sequence classification model wrapping an easynlp predictor."""

    def __init__(self,
                 model_dir: str,
                 model_cls: Optional[Any] = None,
                 *args,
                 **kwargs):
        """Initialize the sequence classification model from `model_dir`.

        Args:
            model_dir (str): the model path
            model_cls (Optional[Any], optional): model loader, if None, use the
                default loader to load model weights, by default None
        """
        super().__init__(model_dir, model_cls, *args, **kwargs)
        # Imported here rather than at module level so that importing this
        # module does not itself require easynlp.
        from easynlp.appzoo import SequenceClassification
        from easynlp.core.predictor import get_model_predictor

        self.model_dir = model_dir
        if not model_cls:
            model_cls = SequenceClassification

        tensor_fields = ('input_ids', 'attention_mask', 'token_type_ids')
        self.model = get_model_predictor(
            model_dir=model_dir,
            model_cls=model_cls,
            input_keys=[(field, torch.LongTensor) for field in tensor_fields],
            output_keys=['predictions', 'probabilities', 'logits'])

    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
        """Return the prediction of the wrapped predictor.

        Args:
            input (Dict[str, Any]): the preprocessed data

        Returns:
            Dict[str, np.ndarray]: results
                Example:
                    {
                        'predictions': array([1]), # label 0-negative 1-positive
                        'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
                        'logits': array([[-0.53860897,  1.5029076 ]], dtype=float32) # true value
                    }
        """
        predictor = self.model
        return predictor.predict(input)
... |
@@ -0,0 +1,6 @@ | |||||
from .audio import * # noqa F403 | |||||
from .base import Pipeline | |||||
from .builder import pipeline | |||||
from .cv import * # noqa F403 | |||||
from .multi_modal import * # noqa F403 | |||||
from .nlp import * # noqa F403 |
@@ -0,0 +1,63 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
from abc import ABC, abstractmethod | |||||
from typing import Any, Dict, List, Tuple, Union | |||||
from maas_lib.models import Model | |||||
from maas_lib.preprocessors import Preprocessor | |||||
# Forward references keep the heavy frameworks optional at import time.
Tensor = Union['torch.Tensor', 'tf.Tensor']
Input = Union[str, 'PIL.Image.Image', 'numpy.ndarray']

# Standardized output keys for the pipelines of the different tasks: they are
# meant to connect the model output to postprocess and to standardize the
# keys returned after postprocess.
output_keys = []


class Pipeline(ABC):
    """Abstract base class chaining preprocess -> forward -> postprocess.

    Subclasses must implement :meth:`postprocess`. :meth:`preprocess` and
    :meth:`forward` have default implementations that delegate to the
    ``preprocessor`` and ``model`` given at construction time.
    """

    def __init__(self,
                 config_file: str = None,
                 model: 'Model' = None,
                 preprocessor: 'Preprocessor' = None,
                 **kwargs):
        """Store the model and preprocessor used by the default steps.

        Args:
            config_file: accepted for interface compatibility; not used here.
            model: object called by the default :meth:`forward`.
            preprocessor: object called by the default :meth:`preprocess`.
            **kwargs: accepted and ignored by this base class.
        """
        self.model = model
        self.preprocessor = preprocessor

    def __call__(self, input: Union[Input, List[Input]], *args,
                 **post_kwargs) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """Run the pipeline on one input or on every element of a list.

        Model providers should leave this method as it is; the library
        maintainers handle it.

        Args:
            input: a single input, or a list of inputs processed one by one.
            *args: forwarded to :meth:`_process_single`.
            **post_kwargs: forwarded to :meth:`postprocess`.

        Returns:
            One result dict for a single input, or a list of result dicts
            (one per element, same order) when ``input`` is a list.
        """
        # Simple show case; iterator inputs for both tensorflow and pytorch
        # still need to be supported.
        if isinstance(input, list):
            return [
                self._process_single(ele, *args, **post_kwargs)
                for ele in input
            ]
        return self._process_single(input, *args, **post_kwargs)

    def _process_single(self, input: Input, *args,
                        **post_kwargs) -> Dict[str, Any]:
        """Run preprocess, forward and postprocess on a single input."""
        out = self.preprocess(input)
        out = self.forward(out)
        out = self.postprocess(out, **post_kwargs)
        return out

    def preprocess(self, inputs: Input) -> Dict[str, Any]:
        """Default implementation calling ``self.preprocessor``.

        Subclasses created without a preprocessor must override this method.
        """
        assert self.preprocessor is not None, 'preprocess method should be implemented'
        return self.preprocessor(inputs)

    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Default implementation calling ``self.model``.

        Subclasses created without a model must override this method.
        """
        assert self.model is not None, 'forward method should be implemented'
        return self.model(inputs)

    @abstractmethod
    def postprocess(self, inputs: Dict[str, Any],
                    **post_kwargs) -> Dict[str, Any]:
        """Convert the output of :meth:`forward` into the final result.

        Args:
            inputs: the dict returned by :meth:`forward`.
            **post_kwargs: the keyword arguments given to :meth:`__call__`;
                they are forwarded here, so the abstract signature accepts
                them.
        """
        raise NotImplementedError('postprocess')
@@ -0,0 +1,65 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
from typing import Union | |||||
from maas_lib.models.base import Model | |||||
from maas_lib.utils.config import ConfigDict | |||||
from maas_lib.utils.constant import Tasks | |||||
from maas_lib.utils.registry import Registry, build_from_cfg | |||||
from .base import Pipeline | |||||
PIPELINES = Registry('pipelines') | |||||
def build_pipeline(cfg: ConfigDict,
                   task_name: str = None,
                   default_args: dict = None):
    """Build a pipeline object from a config dict via the ``PIPELINES`` registry.

    Args:
        cfg (:obj:`ConfigDict`): config dict for pipeline object.
        task_name (str, optional): task name, used as the registry group key;
            refer to :obj:`Tasks` for more details.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        The result of ``build_from_cfg`` for the given config.
    """
    lookup = dict(group_key=task_name, default_args=default_args)
    return build_from_cfg(cfg, PIPELINES, **lookup)
def pipeline(task: str = None,
             model: Union[str, Model] = None,
             config_file: str = None,
             pipeline_name: str = None,
             framework: str = None,
             device: int = -1,
             **kwargs) -> Pipeline:
    """ Factory method to build a obj:`Pipeline`.

    When only ``task`` is given, the first pipeline registered for that task
    is used as the default. Otherwise ``pipeline_name`` selects the pipeline.

    NOTE: ``model``, ``config_file``, ``framework`` and ``device`` are
    accepted but are not forwarded to the pipeline by this function; only
    ``**kwargs`` reach the pipeline constructor.

    Args:
        task (str): Task name defining which pipeline will be returned.
        model (str or obj:`Model`): model name or model object.
        config_file (str, optional): path to config file.
        pipeline_name (str, optional): pipeline class name or alias name.
        framework (str, optional): framework type.
        device (int, optional): which device is used to do inference.
        **kwargs: initialization arguments of the selected pipeline.

    Return:
        pipeline (obj:`Pipeline`): pipeline object for certain task.

    Raises:
        AssertionError: if the default pipeline is requested for a task that
            has no registered pipeline.
        ValueError: if no pipeline name can be determined, i.e.
            ``pipeline_name`` is None and the default lookup does not apply.

    Examples:
    ```python
    >>> # default pipeline registered for the task
    >>> p = pipeline('image-classification')
    >>> # explicitly selected pipeline
    >>> p = pipeline('text-classification', pipeline_name='my-pipeline')
    ```
    """
    if task is not None and model is None and pipeline_name is None:
        # get default pipeline for this task
        assert task in PIPELINES.modules, f'No pipeline is registerd for Task {task}'
        pipeline_name = list(PIPELINES.modules[task].keys())[0]

    if pipeline_name is None:
        # Previously this path fell through to an unbound ``cfg`` and failed
        # with an opaque UnboundLocalError.
        raise ValueError(
            'Unable to determine which pipeline to build: pass '
            '`pipeline_name`, or pass only `task` to use its default '
            'pipeline.')

    cfg = dict(type=pipeline_name, **kwargs)
    return build_pipeline(cfg, task_name=task)
@@ -0,0 +1 @@ | |||||
from .image_matting import ImageMatting |
@@ -0,0 +1,67 @@ | |||||
from typing import Any, Dict, List, Tuple, Union | |||||
import cv2 | |||||
import numpy as np | |||||
import PIL | |||||
import tensorflow as tf | |||||
from cv2 import COLOR_GRAY2RGB | |||||
from maas_lib.pipelines.base import Input | |||||
from maas_lib.preprocessors import load_image | |||||
from maas_lib.utils.constant import Tasks | |||||
from maas_lib.utils.logger import get_logger | |||||
from ..base import Pipeline | |||||
from ..builder import PIPELINES | |||||
# Compare the major version as an integer: comparing version strings is
# lexicographic and would mis-order e.g. '10.0' and '2.0'.
if int(tf.__version__.split('.')[0]) >= 2:
    tf = tf.compat.v1

logger = get_logger()


@PIPELINES.register_module(
    Tasks.image_matting, module_name=Tasks.image_matting)
class ImageMatting(Pipeline):
    """Image matting pipeline running a serialized TensorFlow graph."""

    def __init__(self, model_path: str):
        """Create a session and import the graph stored at ``model_path``.

        Args:
            model_path: path of the serialized ``GraphDef`` file.
        """
        super().__init__()
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        self._session = tf.Session(config=config)
        with self._session.as_default():
            logger.info(f'loading model from {model_path}')
            with tf.gfile.FastGFile(model_path, 'rb') as f:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
                tf.import_graph_def(graph_def, name='')
                self.output = self._session.graph.get_tensor_by_name(
                    'output_png:0')
                self.input_name = 'input_image:0'
            logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        """Convert the input into a float array stored under key ``'img'``.

        Args:
            input: an image path or url, a ``PIL.Image.Image`` or a numpy
                array. Arrays are assumed to be in OpenCV channel order
                (BGR, or single-channel grayscale) and are converted to RGB.

        Returns:
            ``{'img': <float64 array>}``.

        Raises:
            TypeError: if ``input`` is none of the supported types.
        """
        if isinstance(input, str):
            img = np.array(load_image(input))
        elif isinstance(input, PIL.Image.Image):
            img = np.array(input.convert('RGB'))
        elif isinstance(input, np.ndarray):
            if len(input.shape) == 2:
                # Expand grayscale to three channels first. The original
                # code converted the not-yet-assigned ``img`` here and then
                # discarded the result.
                input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
            img = input[:, :, ::-1]  # in rgb order
        else:
            raise TypeError(f'input should be either str, PIL.Image,'
                            f' np.array, but got {type(input)}')
        # ``np.float`` was only an alias of the builtin ``float`` (float64)
        # and has been removed from NumPy.
        img = img.astype(np.float64)
        result = {'img': img}
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Run the graph on ``input['img']``.

        Returns:
            ``{'output_png': <array>}`` with the fetched output converted
            from RGBA to BGRA channel order.
        """
        with self._session.as_default():
            feed_dict = {self.input_name: input['img']}
            output_png = self._session.run(self.output, feed_dict=feed_dict)
            output_png = cv2.cvtColor(output_png, cv2.COLOR_RGBA2BGRA)
            return {'output_png': output_png}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return the forward output unchanged."""
        return inputs
@@ -0,0 +1 @@ | |||||
from .sequence_classification_pipeline import * # noqa F403 |
@@ -0,0 +1,77 @@ | |||||
import os | |||||
import uuid | |||||
from typing import Any, Dict | |||||
import json | |||||
import numpy as np | |||||
from maas_lib.models.nlp import SequenceClassificationModel | |||||
from maas_lib.preprocessors import SequenceClassificationPreprocessor | |||||
from maas_lib.utils.constant import Tasks | |||||
from ..base import Input, Pipeline | |||||
from ..builder import PIPELINES | |||||
__all__ = ['SequenceClassificationPipeline'] | |||||
@PIPELINES.register_module(
    Tasks.text_classification, module_name=r'bert-sentiment-analysis')
class SequenceClassificationPipeline(Pipeline):
    """Text classification pipeline mapping predicted ids to label names."""

    def __init__(self, model: SequenceClassificationModel,
                 preprocessor: SequenceClassificationPreprocessor, **kwargs):
        """Use `model` and `preprocessor` to create a nlp text classification pipeline.

        The label names are read from ``label_mapping.json`` located in
        ``model.model_dir``.

        Args:
            model (SequenceClassificationModel): a model instance
            preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
        """
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)

        from easynlp.utils import io
        self.label_path = os.path.join(model.model_dir, 'label_mapping.json')
        with io.open(self.label_path) as f:
            self.label_mapping = json.load(f)
        # Invert name -> id into id -> name for lookups in postprocess.
        self.label_id_to_name = {
            label_id: label_name
            for label_name, label_id in self.label_mapping.items()
        }

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Turn raw model outputs into a result dict for the first sample.

        Args:
            inputs (Dict[str, Any]): model outputs containing
                ``'probabilities'`` and ``'logits'`` and optionally ``'id'``.

        Returns:
            Dict[str, str]: the predict results. ``'output'`` lists every
            label ordered by decreasing probability, ``'predictions'`` is the
            top label name, ``'probabilities'`` and ``'logits'`` are
            comma-joined strings of the raw values.
        """
        probs = inputs['probabilities']
        logits = inputs['logits']
        sample = 0  # only the first sample of the batch is reported
        ranked_ids = np.argsort(-probs, axis=-1)[sample]

        ranked = [{
            'pred': self.label_id_to_name[label_id],
            'prob': float(probs[sample][label_id]),
            'logit': float(logits[sample][label_id])
        } for label_id in ranked_ids]

        return {
            'id':
            inputs['id'][sample] if 'id' in inputs else str(uuid.uuid4()),
            'output':
            ranked,
            'predictions':
            ranked[0]['pred'],
            'probabilities':
            ','.join([str(t) for t in inputs['probabilities'][sample]]),
            'logits':
            ','.join([str(t) for t in inputs['logits'][sample]])
        }
@@ -0,0 +1,7 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
from .base import Preprocessor | |||||
from .builder import PREPROCESSORS, build_preprocessor | |||||
from .common import Compose | |||||
from .image import LoadImage, load_image | |||||
from .nlp import * # noqa F403 |
@@ -0,0 +1,14 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
from abc import ABC, abstractmethod | |||||
from typing import Any, Dict | |||||
class Preprocessor(ABC):
    """Abstract base class for preprocessors; subclasses implement ``__call__``."""

    def __init__(self, *args, **kwargs):
        """Accept and ignore arbitrary arguments so subclasses can forward theirs."""

    @abstractmethod
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform ``data``; must be overridden by subclasses."""
@@ -0,0 +1,22 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
from maas_lib.utils.config import ConfigDict | |||||
from maas_lib.utils.constant import Fields | |||||
from maas_lib.utils.registry import Registry, build_from_cfg | |||||
PREPROCESSORS = Registry('preprocessors') | |||||
def build_preprocessor(cfg: ConfigDict,
                       field_name: str = None,
                       default_args: dict = None):
    """Build a preprocessor from a config dict via the ``PREPROCESSORS`` registry.

    Args:
        cfg (:obj:`ConfigDict`): config dict for preprocessor object.
        field_name (str, optional): application field name, used as the
            registry group key; refer to :obj:`Fields` for more details.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        The result of ``build_from_cfg`` for the given config.
    """
    lookup = dict(group_key=field_name, default_args=default_args)
    return build_from_cfg(cfg, PREPROCESSORS, **lookup)
@@ -0,0 +1,54 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import time | |||||
from collections.abc import Sequence | |||||
from .builder import PREPROCESSORS, build_preprocessor | |||||
@PREPROCESSORS.register_module()
class Compose(object):
    """Compose a data pipeline with a sequence of transforms.

    Args:
        transforms (list[dict | callable]):
            Either config dicts of transforms or transform objects.
        field_name (str, optional): registry group passed to
            ``build_preprocessor`` when a transform is given as a config dict.
        profiling (bool, optional): If set True, will profile and
            print preprocess time for each step.

    Raises:
        TypeError: if a transform is neither a dict nor callable.
    """

    def __init__(self, transforms, field_name=None, profiling=False):
        assert isinstance(transforms, Sequence)
        self.profiling = profiling
        self.transforms = []
        self.field_name = field_name
        for transform in transforms:
            if isinstance(transform, dict):
                # Always build config dicts. They used to be built only when
                # ``field_name`` was None, so with a field name the raw dict
                # was stored and failed later when called as a transform.
                transform = build_preprocessor(transform, self.field_name)
            elif not callable(transform):
                raise TypeError('transform must be callable or a dict, but got'
                                f' {type(transform)}')
            self.transforms.append(transform)

    def __call__(self, data):
        """Apply every transform in order.

        Returns:
            The output of the last transform, or ``None`` as soon as any
            transform returns ``None`` (the remaining ones are skipped).
        """
        for t in self.transforms:
            if self.profiling:
                start = time.time()

            data = t(data)

            if self.profiling:
                print(f'{t} time {time.time()-start}')

            if data is None:
                return None
        return data

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += f'\n    {t}'
        format_string += '\n)'
        return format_string
@@ -0,0 +1,70 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import io | |||||
from typing import Dict, Union | |||||
from PIL import Image, ImageOps | |||||
from maas_lib.fileio import File | |||||
from maas_lib.utils.constant import Fields | |||||
from .builder import PREPROCESSORS | |||||
@PREPROCESSORS.register_module(Fields.image)
class LoadImage:
    """Load an image from file or url.

    The returned dict has the keys "filename", "img", "img_shape" and
    "img_field".

    Args:
        mode (str): See :ref:`PIL.Mode<https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes>`.
            The value is upper-cased before use.
    """

    def __init__(self, mode='rgb'):
        self.mode = mode.upper()

    def __call__(self, input: Union[str, Dict[str, str]]):
        """Call functions to load image and get image meta information.

        Args:
            input (str or dict): input image path or input dict with
                a key `filename`.

        Returns:
            dict: The dict contains loaded image.
        """
        image_path_or_url = (
            input['filename'] if isinstance(input, dict) else input)

        raw = File.read(image_path_or_url)
        # TODO @wenmeng.zwm add opencv decode as optional
        # we should also look at the input format which is the most commonly
        # used in Mind' image related models
        with io.BytesIO(raw) as stream:
            # Apply the EXIF orientation, then convert while the stream is
            # still open.
            img = ImageOps.exif_transpose(Image.open(stream))
            img = img.convert(self.mode)

        width, height = img.size[0], img.size[1]
        return {
            'filename': image_path_or_url,
            'img': img,
            # NOTE(review): the channel count is fixed to 3 regardless of
            # ``mode`` -- confirm this is intended for non-RGB modes.
            'img_shape': (height, width, 3),
            'img_field': 'img',
        }

    def __repr__(self):
        return f'{self.__class__.__name__}(mode={self.mode})'
def load_image(image_path_or_url: str) -> Image.Image:
    """ simple interface to load an image from file or url

    Args:
        image_path_or_url (str): image file path or http url

    Returns:
        The ``'img'`` entry produced by a default :class:`LoadImage`.
    """
    # The return annotation used to name the ``PIL.Image`` module rather
    # than the ``Image.Image`` class.
    loader = LoadImage()
    return loader(image_path_or_url)['img']
@@ -0,0 +1,91 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import uuid | |||||
from typing import Any, Dict, Union | |||||
from transformers import AutoTokenizer | |||||
from maas_lib.utils.constant import Fields, InputFields | |||||
from maas_lib.utils.type_assert import type_assert | |||||
from .base import Preprocessor | |||||
from .builder import PREPROCESSORS | |||||
__all__ = ['Tokenize', 'SequenceClassificationPreprocessor'] | |||||
@PREPROCESSORS.register_module(Fields.nlp)
class Tokenize(Preprocessor):
    """Tokenize the text entry of the input with a pretrained tokenizer."""

    def __init__(self, tokenizer_name) -> None:
        """Load the tokenizer identified by ``tokenizer_name``."""
        self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def __call__(self, data: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
        """Add the tokenizer output to ``data``.

        Args:
            data: a raw string, or a dict holding the text under
                ``InputFields.text``. A dict is updated in place.

        Returns:
            The dict extended with the entries returned by the tokenizer.
        """
        if isinstance(data, str):
            data = {InputFields.text: data}

        encoded = self._tokenizer(data[InputFields.text])
        data.update(encoded)
        return data
@PREPROCESSORS.register_module(
    Fields.nlp, module_name=r'bert-sentiment-analysis')
class SequenceClassificationPreprocessor(Preprocessor):
    """Tokenize a sentence into the inputs of a sequence classification model."""

    def __init__(self, model_dir: str, *args, **kwargs):
        """Preprocess the data with the tokenizer loaded from `model_dir`.

        Args:
            model_dir (str): model path
            **kwargs: optional ``first_sequence`` / ``second_sequence`` key
                names (defaulting to their own names) and ``sequence_length``
                (default 128).
        """
        super().__init__(*args, **kwargs)

        # Imported here rather than at module level so that importing this
        # module does not itself require easynlp.
        from easynlp.modelzoo import AutoTokenizer
        self.model_dir: str = model_dir
        self.first_sequence: str = kwargs.pop('first_sequence',
                                              'first_sequence')
        self.second_sequence = kwargs.pop('second_sequence', 'second_sequence')
        self.sequence_length = kwargs.pop('sequence_length', 128)

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)

    @type_assert(object, str)
    def __call__(self, data: str) -> Dict[str, Any]:
        """Process the raw input data.

        Args:
            data (str): a sentence
                Example:
                    'you are so handsome.'

        Returns:
            Dict[str, Any]: the preprocessed data, a dict with the keys
            ``'id'``, ``'input_ids'``, ``'attention_mask'`` and
            ``'token_type_ids'``, each holding a one-element list.
        """
        sample = {self.first_sequence: data}

        text_a = sample[self.first_sequence]
        text_b = sample.get(self.second_sequence, None)
        # Pad or truncate to the configured sequence length.
        feature = self.tokenizer(
            text_a,
            text_b,
            padding='max_length',
            truncation=True,
            max_length=self.sequence_length)

        return {
            'id': [sample.get('id', str(uuid.uuid4()))],
            'input_ids': [feature['input_ids']],
            'attention_mask': [feature['attention_mask']],
            'token_type_ids': [feature['token_type_ids']]
        }
@@ -6,6 +6,7 @@ class Fields(object): | |||||
""" | """ | ||||
image = 'image' | image = 'image' | ||||
video = 'video' | video = 'video' | ||||
cv = 'cv' | |||||
nlp = 'nlp' | nlp = 'nlp' | ||||
audio = 'audio' | audio = 'audio' | ||||
multi_modal = 'multi_modal' | multi_modal = 'multi_modal' | ||||
@@ -18,12 +19,41 @@ class Tasks(object): | |||||
This should be used to register models, pipelines, trainers. | This should be used to register models, pipelines, trainers. | ||||
""" | """ | ||||
# vision tasks | # vision tasks | ||||
image_to_text = 'image-to-text'
pose_estimation = 'pose-estimation'
image_classfication = 'image-classification'
# Correctly spelled alias; the misspelled name above is kept for existing users.
image_classification = image_classfication
image_tagging = 'image-tagging'
object_detection = 'object-detection'
image_segmentation = 'image-segmentation'
image_editing = 'image-editing'
image_generation = 'image-generation'
image_matting = 'image-matting'

# nlp tasks
sentiment_analysis = 'sentiment-analysis'
fill_mask = 'fill-mask'
text_classification = 'text-classification'
relation_extraction = 'relation-extraction'
zero_shot = 'zero-shot'
translation = 'translation'
token_classificatio = 'token-classification'
# Full-length alias; the truncated name above is kept for existing users.
token_classification = token_classificatio
conversational = 'conversational'
text_generation = 'text-generation'
table_question_answ = 'table-question-answering'
# Full-length alias; the truncated name above is kept for existing users.
table_question_answering = table_question_answ
feature_extraction = 'feature-extraction'
sentence_similarity = 'sentence-similarity'
# A second `fill_mask = 'fill-mask '` (with a trailing space) used to sit
# here and overrode the correct value defined above; it has been removed.
summarization = 'summarization'
question_answering = 'question-answering'

# audio tasks
auto_speech_recognition = 'auto-speech-recognition'
text_to_speech = 'text-to-speech'
speech_signal_process = 'speech-signal-process'

# multi-media
image_captioning = 'image-captioning'
visual_grounding = 'visual-grounding'
text_to_image_synthesis = 'text-to-image-synthesis'
class InputFields(object): | class InputFields(object): | ||||
@@ -1,5 +1,7 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
import inspect | import inspect | ||||
from email.policy import default | |||||
from maas_lib.utils.logger import get_logger | from maas_lib.utils.logger import get_logger | ||||
@@ -15,10 +17,10 @@ class Registry(object): | |||||
def __init__(self, name: str): | def __init__(self, name: str): | ||||
self._name = name | self._name = name | ||||
self._modules = dict() | |||||
self._modules = {default_group: {}} | |||||
def __repr__(self): | def __repr__(self): | ||||
format_str = self.__class__.__name__ + f'({self._name})\n' | |||||
format_str = self.__class__.__name__ + f' ({self._name})\n' | |||||
for group_name, group in self._modules.items(): | for group_name, group in self._modules.items(): | ||||
format_str += f'group_name={group_name}, '\ | format_str += f'group_name={group_name}, '\ | ||||
f'modules={list(group.keys())}\n' | f'modules={list(group.keys())}\n' | ||||
@@ -64,11 +66,24 @@ class Registry(object): | |||||
module_name = module_cls.__name__ | module_name = module_cls.__name__ | ||||
if module_name in self._modules[group_key]: | if module_name in self._modules[group_key]: | ||||
raise KeyError(f'{module_name} is already registered in' | |||||
raise KeyError(f'{module_name} is already registered in ' | |||||
f'{self._name}[{group_key}]') | f'{self._name}[{group_key}]') | ||||
self._modules[group_key][module_name] = module_cls | self._modules[group_key][module_name] = module_cls | ||||
if module_name in self._modules[default_group]: | |||||
if id(self._modules[default_group][module_name]) == id(module_cls): | |||||
return | |||||
else: | |||||
logger.warning(f'{module_name} is already registered in ' | |||||
f'{self._name}[{default_group}] and will ' | |||||
'be overwritten') | |||||
logger.warning(f'{self._modules[default_group][module_name]}' | |||||
'to {module_cls}') | |||||
# also register module in the default group for faster access | |||||
# only by module name | |||||
self._modules[default_group][module_name] = module_cls | |||||
def register_module(self, | def register_module(self, | ||||
group_key: str = default_group, | group_key: str = default_group, | ||||
module_name: str = None, | module_name: str = None, | ||||
@@ -165,12 +180,15 @@ def build_from_cfg(cfg, | |||||
for name, value in default_args.items(): | for name, value in default_args.items(): | ||||
args.setdefault(name, value) | args.setdefault(name, value) | ||||
if group_key is None: | |||||
group_key = default_group | |||||
obj_type = args.pop('type') | obj_type = args.pop('type') | ||||
if isinstance(obj_type, str): | if isinstance(obj_type, str): | ||||
obj_cls = registry.get(obj_type, group_key=group_key) | obj_cls = registry.get(obj_type, group_key=group_key) | ||||
if obj_cls is None: | if obj_cls is None: | ||||
raise KeyError(f'{obj_type} is not in the {registry.name}' | raise KeyError(f'{obj_type} is not in the {registry.name}' | ||||
f'registry group {group_key}') | |||||
f' registry group {group_key}') | |||||
elif inspect.isclass(obj_type) or inspect.isfunction(obj_type): | elif inspect.isclass(obj_type) or inspect.isfunction(obj_type): | ||||
obj_cls = obj_type | obj_cls = obj_type | ||||
else: | else: | ||||
@@ -0,0 +1,50 @@ | |||||
from functools import wraps | |||||
from inspect import signature | |||||
def type_assert(*ty_args, **ty_kwargs):
    """A decorator checking the types of the arguments of a function or class.

    The given types are matched against the parameters of the decorated
    callable, positionally or by keyword. Parameters without a given type are
    not checked. In optimized mode (``python -O``) no check is installed.

    Examples:

        >>> @type_assert(str)
        ... def main(a: str, b: list):
        ...     print(a, b)
        >>> main(1, [])
        Traceback (most recent call last):
            ...
        TypeError: Argument a must be <class 'str'>

        >>> @type_assert(str, (int, str))
        ... def main(a: str, b: int | str):
        ...     print(a, b)
        >>> main('1', [1])
        Traceback (most recent call last):
            ...
        TypeError: Argument b must be (<class 'int'>, <class 'str'>)

        >>> @type_assert(str, (int, str))
        ... class A:
        ...     def __init__(self, a: str, b: int | str):
        ...         print(a, b)
        >>> a = A('1', [1])
        Traceback (most recent call last):
            ...
        TypeError: Argument b must be (<class 'int'>, <class 'str'>)
    """

    def decorate(func):
        if __debug__:
            # Pair each parameter name with the type supplied for it.
            sig = signature(func)
            expected = sig.bind_partial(*ty_args, **ty_kwargs).arguments

            @wraps(func)
            def wrapper(*args, **kwargs):
                supplied = sig.bind(*args, **kwargs).arguments
                for name, value in supplied.items():
                    if name not in expected:
                        continue
                    required = expected[name]
                    if isinstance(value, required):
                        continue
                    raise TypeError('Argument {} must be {}'.format(
                        name, required))
                return func(*args, **kwargs)

            return wrapper

        # Optimized mode: hand the callable back untouched.
        return func

    return decorate
@@ -1 +1,2 @@ | |||||
-r requirements/runtime.txt | -r requirements/runtime.txt | ||||
-r requirements/pipeline.txt |
@@ -0,0 +1,5 @@ | |||||
http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/package/whl/easynlp-0.0.3-py2.py3-none-any.whl | |||||
tensorflow | |||||
torch==1.9.1 | |||||
torchaudio==0.9.1 | |||||
torchvision==0.10.1 |
@@ -1,5 +1,8 @@ | |||||
addict | addict | ||||
numpy | numpy | ||||
opencv-python-headless | |||||
Pillow | |||||
pyyaml | pyyaml | ||||
requests | requests | ||||
transformers | |||||
yapf | yapf |
@@ -20,5 +20,5 @@ ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids | |||||
[flake8] | [flake8] | ||||
select = B,C,E,F,P,T4,W,B9 | select = B,C,E,F,P,T4,W,B9 | ||||
max-line-length = 120 | max-line-length = 120 | ||||
ignore = F401 | |||||
ignore = F401,F821 | |||||
exclude = docs/src,*.pyi,.git | exclude = docs/src,*.pyi,.git |
@@ -0,0 +1,98 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import unittest | |||||
from typing import Any, Dict, List, Tuple, Union | |||||
import numpy as np | |||||
import PIL | |||||
from maas_lib.pipelines import Pipeline, pipeline | |||||
from maas_lib.pipelines.builder import PIPELINES | |||||
from maas_lib.utils.constant import Tasks | |||||
from maas_lib.utils.logger import get_logger | |||||
from maas_lib.utils.registry import default_group | |||||
logger = get_logger() | |||||
Input = Union[str, 'PIL.Image', 'numpy.ndarray'] | |||||
class CustomPipelineTest(unittest.TestCase):
    """Checks registration and use of user-defined ``Pipeline`` subclasses."""

    def test_abstract(self):
        """A subclass that only defines ``__init__`` cannot be instantiated.

        NOTE(review): presumably ``Pipeline`` declares abstract methods that
        ``CustomPipeline1`` leaves unimplemented - confirm against the base class.
        """

        @PIPELINES.register_module()
        class CustomPipeline1(Pipeline):

            def __init__(self,
                         config_file: str = None,
                         model=None,
                         preprocessor=None,
                         **kwargs):
                super().__init__(config_file, model, preprocessor, **kwargs)

        with self.assertRaises(TypeError):
            CustomPipeline1()

    def test_custom(self):
        """Register a custom image pipeline and run it on one and on several inputs."""

        @PIPELINES.register_module(
            group_key=Tasks.image_tagging, module_name='custom-image')
        class CustomImagePipeline(Pipeline):

            def __init__(self,
                         config_file: str = None,
                         model=None,
                         preprocessor=None,
                         **kwargs):
                super().__init__(config_file, model, preprocessor, **kwargs)

            def preprocess(self, input: Union[str,
                                              'PIL.Image']) -> Dict[str, Any]:
                """Wrap the input into a dict, loading the image when it is not one yet."""
                if isinstance(input, PIL.Image.Image):
                    return {'img': input}

                from maas_lib.preprocessors import load_image
                return {'img': load_image(input), 'url': input}

            def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
                """Return a half-size copy of the image plus a fixed dummy field."""
                outputs = {'filename': inputs['url']} if 'url' in inputs else {}

                source = inputs['img']
                half_size = (source.width // 2, source.height // 2)
                outputs['resize_image'] = np.array(source.resize(half_size))
                outputs['dummy_result'] = 'dummy_result'
                return outputs

            def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
                return inputs

        self.assertIn('custom-image', PIPELINES.modules[default_group])

        # Looking the pipeline up by name or by task must give the same class.
        pipe = pipeline(pipeline_name='custom-image')
        pipe2 = pipeline(Tasks.image_tagging)
        self.assertIs(type(pipe), type(pipe2))

        img_url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.' \
                  'aliyuncs.com/data/test/images/image1.jpg'

        def check_result(result):
            self.assertEqual(result['filename'], img_url)
            self.assertEqual(result['resize_image'].shape, (318, 512, 3))
            self.assertEqual(result['dummy_result'], 'dummy_result')

        # Single input.
        check_result(pipe(img_url))

        # A list of inputs yields one result per element.
        batch_results = pipe([img_url] * 4)
        self.assertEqual(len(batch_results), 4)
        for single in batch_results:
            check_result(single)
if __name__ == '__main__': | |||||
unittest.main() |
@@ -0,0 +1,32 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import tempfile | |||||
import unittest | |||||
from typing import Any, Dict, List, Tuple, Union | |||||
import cv2 | |||||
import numpy as np | |||||
import PIL | |||||
from maas_lib.fileio import File | |||||
from maas_lib.pipelines import pipeline | |||||
from maas_lib.utils.constant import Tasks | |||||
class ImageMattingTest(unittest.TestCase):
    """Smoke test for the image matting pipeline; needs network access."""

    def test_run(self):
        """Fetch the matting model, run it on a sample image and save the output.

        The result is written to ``result.png`` in the current working directory.
        """
        model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \
            '.com/data/test/maas/image_matting/matting_person.pb'
        with tempfile.NamedTemporaryFile('wb', suffix='.pb') as ofile:
            ofile.write(File.read(model_path))
            # The pipeline only receives the file *name*, so whatever is still
            # sitting in the write buffer must reach the disk before it is read.
            ofile.flush()

            # Keep everything inside the ``with`` block: the temporary file is
            # removed as soon as the block exits.
            img_matting = pipeline(Tasks.image_matting, model_path=ofile.name)
            result = img_matting(
                'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'
            )
            cv2.imwrite('result.png', result['output_png'])
if __name__ == '__main__': | |||||
unittest.main() |
@@ -0,0 +1,48 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import os | |||||
import os.path as osp | |||||
import tempfile | |||||
import unittest | |||||
import zipfile | |||||
from maas_lib.fileio import File | |||||
from maas_lib.models.nlp import SequenceClassificationModel | |||||
from maas_lib.pipelines import SequenceClassificationPipeline | |||||
from maas_lib.preprocessors import SequenceClassificationPreprocessor | |||||
class SequenceClassificationTest(unittest.TestCase):
    """Runs the sequence classification pipeline end to end; needs network access."""

    def predict(self, pipeline: SequenceClassificationPipeline):
        """Feed the first two SST-2 test sentences to ``pipeline`` and print the results."""
        # Imported lazily so that easynlp is only needed when this test runs.
        from easynlp.appzoo import load_dataset

        dataset = load_dataset('glue', 'sst2')
        sentences = dataset['test']['sentence'][:3]

        for index in (0, 1):
            print(pipeline(sentences[index]))

        print(sentences)

    def test_run(self):
        """Download and unpack the model archive, then build and run the pipeline."""
        model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \
                    '/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip'

        # All work happens inside the block: the directory vanishes on exit.
        with tempfile.TemporaryDirectory() as tmp_dir:
            archive_path = osp.join(tmp_dir, 'bert-base-sst2.zip')

            with open(archive_path, 'wb') as archive:
                archive.write(File.read(model_url))

            with zipfile.ZipFile(archive_path, 'r') as bundle:
                bundle.extractall(tmp_dir)

            model_dir = osp.join(tmp_dir, 'bert-base-sst2')
            print(model_dir)

            classifier = SequenceClassificationPipeline(
                SequenceClassificationModel(model_dir),
                SequenceClassificationPreprocessor(
                    model_dir, first_sequence='sentence',
                    second_sequence=None))
            self.predict(classifier)
if __name__ == '__main__': | |||||
unittest.main() |
@@ -0,0 +1,39 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import unittest | |||||
from maas_lib.preprocessors import PREPROCESSORS, Compose, Preprocessor | |||||
class ComposeTest(unittest.TestCase):
    """Checks that ``Compose`` chains registered preprocessors."""

    def test_compose(self):
        """Two registered steps each add their own key to the same dict."""

        @PREPROCESSORS.register_module()
        class Tmp1(Preprocessor):

            def __call__(self, input):
                input['tmp1'] = 'tmp1'
                return input

        @PREPROCESSORS.register_module()
        class Tmp2(Preprocessor):

            def __call__(self, input):
                input['tmp2'] = 'tmp2'
                return input

        # Steps are referenced by the names they were registered under.
        steps = [{'type': 'Tmp1'}, {'type': 'Tmp2'}]
        composed = Compose(steps)

        result = composed({})

        for key in ('tmp1', 'tmp2'):
            self.assertEqual(result[key], key)
if __name__ == '__main__': | |||||
unittest.main() |
@@ -0,0 +1,37 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import unittest | |||||
from maas_lib.preprocessors import build_preprocessor | |||||
from maas_lib.utils.constant import Fields, InputFields | |||||
from maas_lib.utils.logger import get_logger | |||||
logger = get_logger() | |||||
class NLPPreprocessorTest(unittest.TestCase):
    """Checks the ``Tokenize`` preprocessor built for the NLP field."""

    def test_tokenize(self):
        """Tokenize one sentence with ``bert-base-cased`` and compare every output field."""
        tokenizer_cfg = {
            'type': 'Tokenize',
            'tokenizer_name': 'bert-base-cased',
        }
        preprocessor = build_preprocessor(tokenizer_cfg, Fields.nlp)

        sentence = ('Do not meddle in the affairs of wizards, '
                    'for they are subtle and quick to anger.')
        output = preprocessor({InputFields.text: sentence})

        # The original text is expected to be kept next to the new fields.
        self.assertIn(InputFields.text, output)

        expected_input_ids = [
            101, 2091, 1136, 1143, 13002, 1107, 1103, 5707, 1104, 16678, 1116,
            117, 1111, 1152, 1132, 11515, 1105, 3613, 1106, 4470, 119, 102
        ]
        expected_token_type_ids = [
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
        ]
        expected_attention_mask = [
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1
        ]

        self.assertEqual(output['input_ids'], expected_input_ids)
        self.assertEqual(output['token_type_ids'], expected_token_type_ids)
        self.assertEqual(output['attention_mask'], expected_attention_mask)
if __name__ == '__main__': | |||||
unittest.main() |
@@ -10,8 +10,10 @@ class RegistryTest(unittest.TestCase): | |||||
def test_register_class_no_task(self): | def test_register_class_no_task(self): | ||||
MODELS = Registry('models') | MODELS = Registry('models') | ||||
self.assertTrue(MODELS.name == 'models') | self.assertTrue(MODELS.name == 'models') | ||||
self.assertTrue(MODELS.modules == {}) | |||||
self.assertEqual(len(MODELS.modules), 0) | |||||
self.assertTrue(default_group in MODELS.modules) | |||||
self.assertTrue(MODELS.modules[default_group] == {}) | |||||
self.assertEqual(len(MODELS.modules), 1) | |||||
@MODELS.register_module(module_name='cls-resnet') | @MODELS.register_module(module_name='cls-resnet') | ||||
class ResNetForCls(object): | class ResNetForCls(object): | ||||
@@ -47,7 +49,7 @@ class RegistryTest(unittest.TestCase): | |||||
self.assertTrue(Tasks.object_detection in MODELS.modules) | self.assertTrue(Tasks.object_detection in MODELS.modules) | ||||
self.assertTrue(MODELS.get('DETR', Tasks.object_detection) is DETR) | self.assertTrue(MODELS.get('DETR', Tasks.object_detection) is DETR) | ||||
self.assertEqual(len(MODELS.modules), 3) | |||||
self.assertEqual(len(MODELS.modules), 4) | |||||
def test_list(self): | def test_list(self): | ||||
MODELS = Registry('models') | MODELS = Registry('models') | ||||
@@ -0,0 +1,22 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import unittest | |||||
from typing import List, Union | |||||
from maas_lib.utils.type_assert import type_assert | |||||
class type_assertTest(unittest.TestCase):
    """Checks ``type_assert`` on a bound method."""

    # ``object`` is the slot for ``self``; ``a`` must be a list and ``b`` an
    # int or a str.
    @type_assert(object, list, (int, str))
    def a(self, a: List[int], b: Union[int, str]):
        print(a, b)

    def test_type_assert(self):
        """Well-typed calls pass; a wrong type in either argument raises ``TypeError``."""
        # Kept outside ``assertRaises`` so an unexpected TypeError here fails
        # the test instead of being mistaken for the expected one.
        self.a([1], 2)

        # First checked argument has the wrong type.
        with self.assertRaises(TypeError):
            self.a(1, [123])

        # Second checked argument has the wrong type.
        with self.assertRaises(TypeError):
            self.a([1], [123])
if __name__ == '__main__': | |||||
unittest.main() |