* add preprocessor module * add model base and builder * update task constant * add load image preprocessor and its dependency * add pipeline interface and UT covered * support default pipeline for task * add image matting pipeline * refine nlp tokenize interface * add nlp pipeline * fix UT failed * add test for Compose Link: https://code.aone.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8769235 * add preprocessor module * add test for Compose * fix citest error * fix abs class error * add model base and builder * update task constant * add load image preprocessor and its dependency * add pipeline interface and UT covered * support default pipeline for task * refine models and pipeline interface * add pipeline folder structure * add image matting pipeline * refine nlp tokenize interface * add nlp pipeline 1.add preprossor model pipeline for nlp text classification 2. add corresponding test Link: https://code.aone.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8757371 * new nlp pipeline * format pre-commit code * update easynlp pipeline * update model_name for easynlp pipeline; add test for maas_lib/utils/typeassert.py * update test_typeassert.py * refactor code 1. rename typeassert to type_assert 2. use lazy import to make easynlp dependency optional 3. refine image matting UT * fix linter test failed * update requirements.txt * fix UT failed * fix citest script to update requirementsmaster
@@ -1,4 +1,4 @@ | |||
pip install -r requirements/runtime.txt | |||
pip install -r requirements.txt | |||
pip install -r requirements/tests.txt | |||
@@ -1 +1,2 @@ | |||
from .file import File | |||
from .io import dump, dumps, load |
@@ -123,6 +123,7 @@ class HTTPStorage(Storage): | |||
"""HTTP and HTTPS storage.""" | |||
def read(self, url): | |||
# TODO @wenmeng.zwm add progress bar if file is too large | |||
r = requests.get(url) | |||
r.raise_for_status() | |||
return r.content | |||
@@ -0,0 +1,4 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from .base import Model | |||
from .builder import MODELS |
@@ -0,0 +1,29 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from abc import ABC, abstractmethod | |||
from typing import Dict, List, Tuple, Union | |||
Tensor = Union['torch.Tensor', 'tf.Tensor'] | |||
class Model(ABC): | |||
def __init__(self, *args, **kwargs): | |||
pass | |||
def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
return self.post_process(self.forward(input)) | |||
@abstractmethod | |||
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
pass | |||
def post_process(self, input: Dict[str, Tensor], | |||
**kwargs) -> Dict[str, Tensor]: | |||
# model specific postprocess, implementation is optional | |||
# will be called in Pipeline and evaluation loop(in the future) | |||
return input | |||
@classmethod | |||
def from_pretrained(cls, model_name_or_path: str, *model_args, **kwargs): | |||
raise NotImplementedError('from_preatrained has not been implemented') |
@@ -0,0 +1,22 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from maas_lib.utils.config import ConfigDict | |||
from maas_lib.utils.constant import Tasks | |||
from maas_lib.utils.registry import Registry, build_from_cfg | |||
MODELS = Registry('models') | |||
def build_model(cfg: ConfigDict, | |||
task_name: str = None, | |||
default_args: dict = None): | |||
""" build model given model config dict | |||
Args: | |||
cfg (:obj:`ConfigDict`): config dict for model object. | |||
task_name (str, optional): task name, refer to | |||
:obj:`Tasks` for more details | |||
default_args (dict, optional): Default initialization arguments. | |||
""" | |||
return build_from_cfg( | |||
cfg, MODELS, group_key=task_name, default_args=default_args) |
@@ -0,0 +1 @@ | |||
from .sequence_classification_model import * # noqa F403 |
@@ -0,0 +1,62 @@ | |||
from typing import Any, Dict, Optional, Union | |||
import numpy as np | |||
import torch | |||
from maas_lib.utils.constant import Tasks | |||
from ..base import Model | |||
from ..builder import MODELS | |||
__all__ = ['SequenceClassificationModel'] | |||
@MODELS.register_module( | |||
Tasks.text_classification, module_name=r'bert-sentiment-analysis') | |||
class SequenceClassificationModel(Model): | |||
def __init__(self, | |||
model_dir: str, | |||
model_cls: Optional[Any] = None, | |||
*args, | |||
**kwargs): | |||
# Model.__init__(self, model_dir, model_cls, first_sequence, *args, **kwargs) | |||
# Predictor.__init__(self, *args, **kwargs) | |||
"""initilize the sequence classification model from the `model_dir` path | |||
Args: | |||
model_dir (str): the model path | |||
model_cls (Optional[Any], optional): model loader, if None, use the | |||
default loader to load model weights, by default None | |||
""" | |||
super().__init__(model_dir, model_cls, *args, **kwargs) | |||
from easynlp.appzoo import SequenceClassification | |||
from easynlp.core.predictor import get_model_predictor | |||
self.model_dir = model_dir | |||
model_cls = SequenceClassification if not model_cls else model_cls | |||
self.model = get_model_predictor( | |||
model_dir=model_dir, | |||
model_cls=model_cls, | |||
input_keys=[('input_ids', torch.LongTensor), | |||
('attention_mask', torch.LongTensor), | |||
('token_type_ids', torch.LongTensor)], | |||
output_keys=['predictions', 'probabilities', 'logits']) | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: | |||
"""return the result by the model | |||
Args: | |||
input (Dict[str, Any]): the preprocessed data | |||
Returns: | |||
Dict[str, np.ndarray]: results | |||
Example: | |||
{ | |||
'predictions': array([1]), # lable 0-negative 1-positive | |||
'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), | |||
'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value | |||
} | |||
""" | |||
return self.model.predict(input) | |||
... |
@@ -0,0 +1,6 @@ | |||
from .audio import * # noqa F403 | |||
from .base import Pipeline | |||
from .builder import pipeline | |||
from .cv import * # noqa F403 | |||
from .multi_modal import * # noqa F403 | |||
from .nlp import * # noqa F403 |
@@ -0,0 +1,63 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from abc import ABC, abstractmethod | |||
from typing import Any, Dict, List, Tuple, Union | |||
from maas_lib.models import Model | |||
from maas_lib.preprocessors import Preprocessor | |||
Tensor = Union['torch.Tensor', 'tf.Tensor'] | |||
Input = Union[str, 'PIL.Image.Image', 'numpy.ndarray'] | |||
output_keys = [ | |||
] # 对于不同task的pipeline,规定标准化的输出key,用以对接postprocess,同时也用来标准化postprocess后输出的key | |||
class Pipeline(ABC): | |||
def __init__(self, | |||
config_file: str = None, | |||
model: Model = None, | |||
preprocessor: Preprocessor = None, | |||
**kwargs): | |||
self.model = model | |||
self.preprocessor = preprocessor | |||
def __call__(self, input: Union[Input, List[Input]], *args, | |||
**post_kwargs) -> Dict[str, Any]: | |||
# moodel provider should leave it as it is | |||
# maas library developer will handle this function | |||
# simple show case, need to support iterator type for both tensorflow and pytorch | |||
# input_dict = self._handle_input(input) | |||
if isinstance(input, list): | |||
output = [] | |||
for ele in input: | |||
output.append(self._process_single(ele, *args, **post_kwargs)) | |||
else: | |||
output = self._process_single(input, *args, **post_kwargs) | |||
return output | |||
def _process_single(self, input: Input, *args, | |||
**post_kwargs) -> Dict[str, Any]: | |||
out = self.preprocess(input) | |||
out = self.forward(out) | |||
out = self.postprocess(out, **post_kwargs) | |||
return out | |||
def preprocess(self, inputs: Input) -> Dict[str, Any]: | |||
""" Provide default implementation based on preprocess_cfg and user can reimplement it | |||
""" | |||
assert self.preprocessor is not None, 'preprocess method should be implemented' | |||
return self.preprocessor(inputs) | |||
def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
""" Provide default implementation using self.model and user can reimplement it | |||
""" | |||
assert self.model is not None, 'forward method should be implemented' | |||
return self.model(inputs) | |||
@abstractmethod | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
raise NotImplementedError('postprocess') |
@@ -0,0 +1,65 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Union | |||
from maas_lib.models.base import Model | |||
from maas_lib.utils.config import ConfigDict | |||
from maas_lib.utils.constant import Tasks | |||
from maas_lib.utils.registry import Registry, build_from_cfg | |||
from .base import Pipeline | |||
PIPELINES = Registry('pipelines') | |||
def build_pipeline(cfg: ConfigDict, | |||
task_name: str = None, | |||
default_args: dict = None): | |||
""" build pipeline given model config dict | |||
Args: | |||
cfg (:obj:`ConfigDict`): config dict for model object. | |||
task_name (str, optional): task name, refer to | |||
:obj:`Tasks` for more details | |||
default_args (dict, optional): Default initialization arguments. | |||
""" | |||
return build_from_cfg( | |||
cfg, PIPELINES, group_key=task_name, default_args=default_args) | |||
def pipeline(task: str = None, | |||
model: Union[str, Model] = None, | |||
config_file: str = None, | |||
pipeline_name: str = None, | |||
framework: str = None, | |||
device: int = -1, | |||
**kwargs) -> Pipeline: | |||
""" Factory method to build a obj:`Pipeline`. | |||
Args: | |||
task (str): Task name defining which pipeline will be returned. | |||
model (str or obj:`Model`): model name or model object. | |||
config_file (str, optional): path to config file. | |||
pipeline_name (str, optional): pipeline class name or alias name. | |||
framework (str, optional): framework type. | |||
device (int, optional): which device is used to do inference. | |||
Return: | |||
pipeline (obj:`Pipeline`): pipeline object for certain task. | |||
Examples: | |||
```python | |||
>>> p = pipeline('image-classification') | |||
>>> p = pipeline('text-classification', model='distilbert-base-uncased') | |||
>>> # Using model object | |||
>>> resnet = Model.from_pretrained('Resnet') | |||
>>> p = pipeline('image-classification', model=resnet) | |||
""" | |||
if task is not None and model is None and pipeline_name is None: | |||
# get default pipeline for this task | |||
assert task in PIPELINES.modules, f'No pipeline is registerd for Task {task}' | |||
pipeline_name = list(PIPELINES.modules[task].keys())[0] | |||
if pipeline_name is not None: | |||
cfg = dict(type=pipeline_name, **kwargs) | |||
return build_pipeline(cfg, task_name=task) |
@@ -0,0 +1 @@ | |||
from .image_matting import ImageMatting |
@@ -0,0 +1,67 @@ | |||
from typing import Any, Dict, List, Tuple, Union | |||
import cv2 | |||
import numpy as np | |||
import PIL | |||
import tensorflow as tf | |||
from cv2 import COLOR_GRAY2RGB | |||
from maas_lib.pipelines.base import Input | |||
from maas_lib.preprocessors import load_image | |||
from maas_lib.utils.constant import Tasks | |||
from maas_lib.utils.logger import get_logger | |||
from ..base import Pipeline | |||
from ..builder import PIPELINES | |||
if tf.__version__ >= '2.0': | |||
tf = tf.compat.v1 | |||
logger = get_logger() | |||
@PIPELINES.register_module( | |||
Tasks.image_matting, module_name=Tasks.image_matting) | |||
class ImageMatting(Pipeline): | |||
def __init__(self, model_path: str): | |||
super().__init__() | |||
config = tf.ConfigProto(allow_soft_placement=True) | |||
config.gpu_options.allow_growth = True | |||
self._session = tf.Session(config=config) | |||
with self._session.as_default(): | |||
logger.info(f'loading model from {model_path}') | |||
with tf.gfile.FastGFile(model_path, 'rb') as f: | |||
graph_def = tf.GraphDef() | |||
graph_def.ParseFromString(f.read()) | |||
tf.import_graph_def(graph_def, name='') | |||
self.output = self._session.graph.get_tensor_by_name( | |||
'output_png:0') | |||
self.input_name = 'input_image:0' | |||
logger.info('load model done') | |||
def preprocess(self, input: Input) -> Dict[str, Any]: | |||
if isinstance(input, str): | |||
img = np.array(load_image(input)) | |||
elif isinstance(input, PIL.Image.Image): | |||
img = np.array(input.convert('RGB')) | |||
elif isinstance(input, np.ndarray): | |||
if len(input.shape) == 2: | |||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) | |||
img = input[:, :, ::-1] # in rgb order | |||
else: | |||
raise TypeError(f'input should be either str, PIL.Image,' | |||
f' np.array, but got {type(input)}') | |||
img = img.astype(np.float) | |||
result = {'img': img} | |||
return result | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
with self._session.as_default(): | |||
feed_dict = {self.input_name: input['img']} | |||
output_png = self._session.run(self.output, feed_dict=feed_dict) | |||
output_png = cv2.cvtColor(output_png, cv2.COLOR_RGBA2BGRA) | |||
return {'output_png': output_png} | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
return inputs |
@@ -0,0 +1 @@ | |||
from .sequence_classification_pipeline import * # noqa F403 |
@@ -0,0 +1,77 @@ | |||
import os | |||
import uuid | |||
from typing import Any, Dict | |||
import json | |||
import numpy as np | |||
from maas_lib.models.nlp import SequenceClassificationModel | |||
from maas_lib.preprocessors import SequenceClassificationPreprocessor | |||
from maas_lib.utils.constant import Tasks | |||
from ..base import Input, Pipeline | |||
from ..builder import PIPELINES | |||
__all__ = ['SequenceClassificationPipeline'] | |||
@PIPELINES.register_module( | |||
Tasks.text_classification, module_name=r'bert-sentiment-analysis') | |||
class SequenceClassificationPipeline(Pipeline): | |||
def __init__(self, model: SequenceClassificationModel, | |||
preprocessor: SequenceClassificationPreprocessor, **kwargs): | |||
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | |||
Args: | |||
model (SequenceClassificationModel): a model instance | |||
preprocessor (SequenceClassificationPreprocessor): a preprocessor instance | |||
""" | |||
super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
from easynlp.utils import io | |||
self.label_path = os.path.join(model.model_dir, 'label_mapping.json') | |||
with io.open(self.label_path) as f: | |||
self.label_mapping = json.load(f) | |||
self.label_id_to_name = { | |||
idx: name | |||
for name, idx in self.label_mapping.items() | |||
} | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: | |||
"""process the predict results | |||
Args: | |||
inputs (Dict[str, Any]): _description_ | |||
Returns: | |||
Dict[str, str]: the predict results | |||
""" | |||
probs = inputs['probabilities'] | |||
logits = inputs['logits'] | |||
predictions = np.argsort(-probs, axis=-1) | |||
preds = predictions[0] | |||
b = 0 | |||
new_result = list() | |||
for pred in preds: | |||
new_result.append({ | |||
'pred': self.label_id_to_name[pred], | |||
'prob': float(probs[b][pred]), | |||
'logit': float(logits[b][pred]) | |||
}) | |||
new_results = list() | |||
new_results.append({ | |||
'id': | |||
inputs['id'][b] if 'id' in inputs else str(uuid.uuid4()), | |||
'output': | |||
new_result, | |||
'predictions': | |||
new_result[0]['pred'], | |||
'probabilities': | |||
','.join([str(t) for t in inputs['probabilities'][b]]), | |||
'logits': | |||
','.join([str(t) for t in inputs['logits'][b]]) | |||
}) | |||
return new_results[0] |
@@ -0,0 +1,7 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from .base import Preprocessor | |||
from .builder import PREPROCESSORS, build_preprocessor | |||
from .common import Compose | |||
from .image import LoadImage, load_image | |||
from .nlp import * # noqa F403 |
@@ -0,0 +1,14 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from abc import ABC, abstractmethod | |||
from typing import Any, Dict | |||
class Preprocessor(ABC): | |||
def __init__(self, *args, **kwargs): | |||
pass | |||
@abstractmethod | |||
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
pass |
@@ -0,0 +1,22 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from maas_lib.utils.config import ConfigDict | |||
from maas_lib.utils.constant import Fields | |||
from maas_lib.utils.registry import Registry, build_from_cfg | |||
PREPROCESSORS = Registry('preprocessors') | |||
def build_preprocessor(cfg: ConfigDict, | |||
field_name: str = None, | |||
default_args: dict = None): | |||
""" build preprocesor given model config dict | |||
Args: | |||
cfg (:obj:`ConfigDict`): config dict for model object. | |||
field_name (str, optional): application field name, refer to | |||
:obj:`Fields` for more details | |||
default_args (dict, optional): Default initialization arguments. | |||
""" | |||
return build_from_cfg( | |||
cfg, PREPROCESSORS, group_key=field_name, default_args=default_args) |
@@ -0,0 +1,54 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import time | |||
from collections.abc import Sequence | |||
from .builder import PREPROCESSORS, build_preprocessor | |||
@PREPROCESSORS.register_module() | |||
class Compose(object): | |||
"""Compose a data pipeline with a sequence of transforms. | |||
Args: | |||
transforms (list[dict | callable]): | |||
Either config dicts of transforms or transform objects. | |||
profiling (bool, optional): If set True, will profile and | |||
print preprocess time for each step. | |||
""" | |||
def __init__(self, transforms, field_name=None, profiling=False): | |||
assert isinstance(transforms, Sequence) | |||
self.profiling = profiling | |||
self.transforms = [] | |||
self.field_name = field_name | |||
for transform in transforms: | |||
if isinstance(transform, dict): | |||
if self.field_name is None: | |||
transform = build_preprocessor(transform, field_name) | |||
self.transforms.append(transform) | |||
elif callable(transform): | |||
self.transforms.append(transform) | |||
else: | |||
raise TypeError('transform must be callable or a dict, but got' | |||
f' {type(transform)}') | |||
def __call__(self, data): | |||
for t in self.transforms: | |||
if self.profiling: | |||
start = time.time() | |||
data = t(data) | |||
if self.profiling: | |||
print(f'{t} time {time.time()-start}') | |||
if data is None: | |||
return None | |||
return data | |||
def __repr__(self): | |||
format_string = self.__class__.__name__ + '(' | |||
for t in self.transforms: | |||
format_string += f'\n {t}' | |||
format_string += '\n)' | |||
return format_string |
@@ -0,0 +1,70 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import io | |||
from typing import Dict, Union | |||
from PIL import Image, ImageOps | |||
from maas_lib.fileio import File | |||
from maas_lib.utils.constant import Fields | |||
from .builder import PREPROCESSORS | |||
@PREPROCESSORS.register_module(Fields.image) | |||
class LoadImage: | |||
"""Load an image from file or url. | |||
Added or updated keys are "filename", "img", "img_shape", | |||
"ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`), | |||
"scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1). | |||
Args: | |||
mode (str): See :ref:`PIL.Mode<https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes>`. | |||
to_float32 (bool): Whether to convert the loaded image to a float32 | |||
numpy array. If set to False, the loaded image is an uint8 array. | |||
Defaults to False. | |||
""" | |||
def __init__(self, mode='rgb'): | |||
self.mode = mode.upper() | |||
def __call__(self, input: Union[str, Dict[str, str]]): | |||
"""Call functions to load image and get image meta information. | |||
Args: | |||
input (str or dict): input image path or input dict with | |||
a key `filename`. | |||
Returns: | |||
dict: The dict contains loaded image. | |||
""" | |||
if isinstance(input, dict): | |||
image_path_or_url = input['filename'] | |||
else: | |||
image_path_or_url = input | |||
bytes = File.read(image_path_or_url) | |||
# TODO @wenmeng.zwm add opencv decode as optional | |||
# we should also look at the input format which is the most commonly | |||
# used in Mind' image related models | |||
with io.BytesIO(bytes) as infile: | |||
img = Image.open(infile) | |||
img = ImageOps.exif_transpose(img) | |||
img = img.convert(self.mode) | |||
results = { | |||
'filename': image_path_or_url, | |||
'img': img, | |||
'img_shape': (img.size[1], img.size[0], 3), | |||
'img_field': 'img', | |||
} | |||
return results | |||
def __repr__(self): | |||
repr_str = (f'{self.__class__.__name__}(' f'mode={self.mode})') | |||
return repr_str | |||
def load_image(image_path_or_url: str) -> Image: | |||
""" simple interface to load an image from file or url | |||
Args: | |||
image_path_or_url (str): image file path or http url | |||
""" | |||
loader = LoadImage() | |||
return loader(image_path_or_url)['img'] |
@@ -0,0 +1,91 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import uuid | |||
from typing import Any, Dict, Union | |||
from transformers import AutoTokenizer | |||
from maas_lib.utils.constant import Fields, InputFields | |||
from maas_lib.utils.type_assert import type_assert | |||
from .base import Preprocessor | |||
from .builder import PREPROCESSORS | |||
__all__ = ['Tokenize', 'SequenceClassificationPreprocessor'] | |||
@PREPROCESSORS.register_module(Fields.nlp) | |||
class Tokenize(Preprocessor): | |||
def __init__(self, tokenizer_name) -> None: | |||
self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) | |||
def __call__(self, data: Union[str, Dict[str, Any]]) -> Dict[str, Any]: | |||
if isinstance(data, str): | |||
data = {InputFields.text: data} | |||
token_dict = self._tokenizer(data[InputFields.text]) | |||
data.update(token_dict) | |||
return data | |||
@PREPROCESSORS.register_module( | |||
Fields.nlp, module_name=r'bert-sentiment-analysis') | |||
class SequenceClassificationPreprocessor(Preprocessor): | |||
def __init__(self, model_dir: str, *args, **kwargs): | |||
"""preprocess the data via the vocab.txt from the `model_dir` path | |||
Args: | |||
model_dir (str): model path | |||
""" | |||
super().__init__(*args, **kwargs) | |||
from easynlp.modelzoo import AutoTokenizer | |||
self.model_dir: str = model_dir | |||
self.first_sequence: str = kwargs.pop('first_sequence', | |||
'first_sequence') | |||
self.second_sequence = kwargs.pop('second_sequence', 'second_sequence') | |||
self.sequence_length = kwargs.pop('sequence_length', 128) | |||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir) | |||
@type_assert(object, str) | |||
def __call__(self, data: str) -> Dict[str, Any]: | |||
"""process the raw input data | |||
Args: | |||
data (str): a sentence | |||
Example: | |||
'you are so handsome.' | |||
Returns: | |||
Dict[str, Any]: the preprocessed data | |||
""" | |||
new_data = {self.first_sequence: data} | |||
# preprocess the data for the model input | |||
rst = { | |||
'id': [], | |||
'input_ids': [], | |||
'attention_mask': [], | |||
'token_type_ids': [] | |||
} | |||
max_seq_length = self.sequence_length | |||
text_a = new_data[self.first_sequence] | |||
text_b = new_data.get(self.second_sequence, None) | |||
feature = self.tokenizer( | |||
text_a, | |||
text_b, | |||
padding='max_length', | |||
truncation=True, | |||
max_length=max_seq_length) | |||
rst['id'].append(new_data.get('id', str(uuid.uuid4()))) | |||
rst['input_ids'].append(feature['input_ids']) | |||
rst['attention_mask'].append(feature['attention_mask']) | |||
rst['token_type_ids'].append(feature['token_type_ids']) | |||
return rst |
@@ -6,6 +6,7 @@ class Fields(object): | |||
""" | |||
image = 'image' | |||
video = 'video' | |||
cv = 'cv' | |||
nlp = 'nlp' | |||
audio = 'audio' | |||
multi_modal = 'multi_modal' | |||
@@ -18,12 +19,41 @@ class Tasks(object): | |||
This should be used to register models, pipelines, trainers. | |||
""" | |||
# vision tasks | |||
image_to_text = 'image-to-text' | |||
pose_estimation = 'pose-estimation' | |||
image_classfication = 'image-classification' | |||
image_tagging = 'image-tagging' | |||
object_detection = 'object-detection' | |||
image_segmentation = 'image-segmentation' | |||
image_editing = 'image-editing' | |||
image_generation = 'image-generation' | |||
image_matting = 'image-matting' | |||
# nlp tasks | |||
sentiment_analysis = 'sentiment-analysis' | |||
fill_mask = 'fill-mask' | |||
text_classification = 'text-classification' | |||
relation_extraction = 'relation-extraction' | |||
zero_shot = 'zero-shot' | |||
translation = 'translation' | |||
token_classificatio = 'token-classification' | |||
conversational = 'conversational' | |||
text_generation = 'text-generation' | |||
table_question_answ = 'table-question-answering' | |||
feature_extraction = 'feature-extraction' | |||
sentence_similarity = 'sentence-similarity' | |||
fill_mask = 'fill-mask ' | |||
summarization = 'summarization' | |||
question_answering = 'question-answering' | |||
# audio tasks | |||
auto_speech_recognition = 'auto-speech-recognition' | |||
text_to_speech = 'text-to-speech' | |||
speech_signal_process = 'speech-signal-process' | |||
# multi-media | |||
image_captioning = 'image-captioning' | |||
visual_grounding = 'visual-grounding' | |||
text_to_image_synthesis = 'text-to-image-synthesis' | |||
class InputFields(object): | |||
@@ -1,5 +1,7 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import inspect | |||
from email.policy import default | |||
from maas_lib.utils.logger import get_logger | |||
@@ -15,10 +17,10 @@ class Registry(object): | |||
def __init__(self, name: str): | |||
self._name = name | |||
self._modules = dict() | |||
self._modules = {default_group: {}} | |||
def __repr__(self): | |||
format_str = self.__class__.__name__ + f'({self._name})\n' | |||
format_str = self.__class__.__name__ + f' ({self._name})\n' | |||
for group_name, group in self._modules.items(): | |||
format_str += f'group_name={group_name}, '\ | |||
f'modules={list(group.keys())}\n' | |||
@@ -64,11 +66,24 @@ class Registry(object): | |||
module_name = module_cls.__name__ | |||
if module_name in self._modules[group_key]: | |||
raise KeyError(f'{module_name} is already registered in' | |||
raise KeyError(f'{module_name} is already registered in ' | |||
f'{self._name}[{group_key}]') | |||
self._modules[group_key][module_name] = module_cls | |||
if module_name in self._modules[default_group]: | |||
if id(self._modules[default_group][module_name]) == id(module_cls): | |||
return | |||
else: | |||
logger.warning(f'{module_name} is already registered in ' | |||
f'{self._name}[{default_group}] and will ' | |||
'be overwritten') | |||
logger.warning(f'{self._modules[default_group][module_name]}' | |||
'to {module_cls}') | |||
# also register module in the default group for faster access | |||
# only by module name | |||
self._modules[default_group][module_name] = module_cls | |||
def register_module(self, | |||
group_key: str = default_group, | |||
module_name: str = None, | |||
@@ -165,12 +180,15 @@ def build_from_cfg(cfg, | |||
for name, value in default_args.items(): | |||
args.setdefault(name, value) | |||
if group_key is None: | |||
group_key = default_group | |||
obj_type = args.pop('type') | |||
if isinstance(obj_type, str): | |||
obj_cls = registry.get(obj_type, group_key=group_key) | |||
if obj_cls is None: | |||
raise KeyError(f'{obj_type} is not in the {registry.name}' | |||
f'registry group {group_key}') | |||
f' registry group {group_key}') | |||
elif inspect.isclass(obj_type) or inspect.isfunction(obj_type): | |||
obj_cls = obj_type | |||
else: | |||
@@ -0,0 +1,50 @@ | |||
from functools import wraps | |||
from inspect import signature | |||
def type_assert(*ty_args, **ty_kwargs): | |||
"""a decorator which is used to check the types of arguments in a function or class | |||
Examples: | |||
>>> @type_assert(str) | |||
... def main(a: str, b: list): | |||
... print(a, b) | |||
>>> main(1) | |||
Argument a must be a str | |||
>>> @type_assert(str, (int, str)) | |||
... def main(a: str, b: int | str): | |||
... print(a, b) | |||
>>> main('1', [1]) | |||
Argument b must be (<class 'int'>, <class 'str'>) | |||
>>> @type_assert(str, (int, str)) | |||
... class A: | |||
... def __init__(self, a: str, b: int | str) | |||
... print(a, b) | |||
>>> a = A('1', [1]) | |||
Argument b must be (<class 'int'>, <class 'str'>) | |||
""" | |||
def decorate(func): | |||
# If in optimized mode, disable type checking | |||
if not __debug__: | |||
return func | |||
# Map function argument names to supplied types | |||
sig = signature(func) | |||
bound_types = sig.bind_partial(*ty_args, **ty_kwargs).arguments | |||
@wraps(func) | |||
def wrapper(*args, **kwargs): | |||
bound_values = sig.bind(*args, **kwargs) | |||
# Enforce type assertions across supplied arguments | |||
for name, value in bound_values.arguments.items(): | |||
if name in bound_types: | |||
if not isinstance(value, bound_types[name]): | |||
raise TypeError('Argument {} must be {}'.format( | |||
name, bound_types[name])) | |||
return func(*args, **kwargs) | |||
return wrapper | |||
return decorate |
@@ -1 +1,2 @@ | |||
-r requirements/runtime.txt | |||
-r requirements/pipeline.txt |
@@ -0,0 +1,5 @@ | |||
http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/package/whl/easynlp-0.0.3-py2.py3-none-any.whl | |||
tensorflow | |||
torch==1.9.1 | |||
torchaudio==0.9.1 | |||
torchvision==0.10.1 |
@@ -1,5 +1,8 @@ | |||
addict | |||
numpy | |||
opencv-python-headless | |||
Pillow | |||
pyyaml | |||
requests | |||
transformers | |||
yapf |
@@ -20,5 +20,5 @@ ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids | |||
[flake8] | |||
select = B,C,E,F,P,T4,W,B9 | |||
max-line-length = 120 | |||
ignore = F401 | |||
ignore = F401,F821 | |||
exclude = docs/src,*.pyi,.git |
@@ -0,0 +1,98 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import unittest | |||
from typing import Any, Dict, List, Tuple, Union | |||
import numpy as np | |||
import PIL | |||
from maas_lib.pipelines import Pipeline, pipeline | |||
from maas_lib.pipelines.builder import PIPELINES | |||
from maas_lib.utils.constant import Tasks | |||
from maas_lib.utils.logger import get_logger | |||
from maas_lib.utils.registry import default_group | |||
logger = get_logger() | |||
Input = Union[str, 'PIL.Image', 'numpy.ndarray'] | |||
class CustomPipelineTest(unittest.TestCase): | |||
def test_abstract(self): | |||
@PIPELINES.register_module() | |||
class CustomPipeline1(Pipeline): | |||
def __init__(self, | |||
config_file: str = None, | |||
model=None, | |||
preprocessor=None, | |||
**kwargs): | |||
super().__init__(config_file, model, preprocessor, **kwargs) | |||
with self.assertRaises(TypeError): | |||
CustomPipeline1() | |||
def test_custom(self): | |||
@PIPELINES.register_module( | |||
group_key=Tasks.image_tagging, module_name='custom-image') | |||
class CustomImagePipeline(Pipeline): | |||
def __init__(self, | |||
config_file: str = None, | |||
model=None, | |||
preprocessor=None, | |||
**kwargs): | |||
super().__init__(config_file, model, preprocessor, **kwargs) | |||
def preprocess(self, input: Union[str, | |||
'PIL.Image']) -> Dict[str, Any]: | |||
""" Provide default implementation based on preprocess_cfg and user can reimplement it | |||
""" | |||
if not isinstance(input, PIL.Image.Image): | |||
from maas_lib.preprocessors import load_image | |||
data_dict = {'img': load_image(input), 'url': input} | |||
else: | |||
data_dict = {'img': input} | |||
return data_dict | |||
def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
""" Provide default implementation using self.model and user can reimplement it | |||
""" | |||
outputs = {} | |||
if 'url' in inputs: | |||
outputs['filename'] = inputs['url'] | |||
img = inputs['img'] | |||
new_image = img.resize((img.width // 2, img.height // 2)) | |||
outputs['resize_image'] = np.array(new_image) | |||
outputs['dummy_result'] = 'dummy_result' | |||
return outputs | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
return inputs | |||
self.assertTrue('custom-image' in PIPELINES.modules[default_group]) | |||
pipe = pipeline(pipeline_name='custom-image') | |||
pipe2 = pipeline(Tasks.image_tagging) | |||
self.assertTrue(type(pipe) is type(pipe2)) | |||
img_url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.' \ | |||
'aliyuncs.com/data/test/images/image1.jpg' | |||
output = pipe(img_url) | |||
self.assertEqual(output['filename'], img_url) | |||
self.assertEqual(output['resize_image'].shape, (318, 512, 3)) | |||
self.assertEqual(output['dummy_result'], 'dummy_result') | |||
outputs = pipe([img_url for i in range(4)]) | |||
self.assertEqual(len(outputs), 4) | |||
for out in outputs: | |||
self.assertEqual(out['filename'], img_url) | |||
self.assertEqual(out['resize_image'].shape, (318, 512, 3)) | |||
self.assertEqual(out['dummy_result'], 'dummy_result') | |||
if __name__ == '__main__': | |||
unittest.main() |
@@ -0,0 +1,32 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import tempfile | |||
import unittest | |||
from typing import Any, Dict, List, Tuple, Union | |||
import cv2 | |||
import numpy as np | |||
import PIL | |||
from maas_lib.fileio import File | |||
from maas_lib.pipelines import pipeline | |||
from maas_lib.utils.constant import Tasks | |||
class ImageMattingTest(unittest.TestCase): | |||
def test_run(self): | |||
model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \ | |||
'.com/data/test/maas/image_matting/matting_person.pb' | |||
with tempfile.NamedTemporaryFile('wb', suffix='.pb') as ofile: | |||
ofile.write(File.read(model_path)) | |||
img_matting = pipeline(Tasks.image_matting, model_path=ofile.name) | |||
result = img_matting( | |||
'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png' | |||
) | |||
cv2.imwrite('result.png', result['output_png']) | |||
if __name__ == '__main__': | |||
unittest.main() |
@@ -0,0 +1,48 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import os.path as osp | |||
import tempfile | |||
import unittest | |||
import zipfile | |||
from maas_lib.fileio import File | |||
from maas_lib.models.nlp import SequenceClassificationModel | |||
from maas_lib.pipelines import SequenceClassificationPipeline | |||
from maas_lib.preprocessors import SequenceClassificationPreprocessor | |||
class SequenceClassificationTest(unittest.TestCase): | |||
def predict(self, pipeline: SequenceClassificationPipeline): | |||
from easynlp.appzoo import load_dataset | |||
set = load_dataset('glue', 'sst2') | |||
data = set['test']['sentence'][:3] | |||
results = pipeline(data[0]) | |||
print(results) | |||
results = pipeline(data[1]) | |||
print(results) | |||
print(data) | |||
def test_run(self): | |||
model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \ | |||
'/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip' | |||
with tempfile.TemporaryDirectory() as tmp_dir: | |||
tmp_file = osp.join(tmp_dir, 'bert-base-sst2.zip') | |||
with open(tmp_file, 'wb') as ofile: | |||
ofile.write(File.read(model_url)) | |||
with zipfile.ZipFile(tmp_file, 'r') as zipf: | |||
zipf.extractall(tmp_dir) | |||
path = osp.join(tmp_dir, 'bert-base-sst2') | |||
print(path) | |||
model = SequenceClassificationModel(path) | |||
preprocessor = SequenceClassificationPreprocessor( | |||
path, first_sequence='sentence', second_sequence=None) | |||
pipeline = SequenceClassificationPipeline(model, preprocessor) | |||
self.predict(pipeline) | |||
if __name__ == '__main__': | |||
unittest.main() |
@@ -0,0 +1,39 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import unittest | |||
from maas_lib.preprocessors import PREPROCESSORS, Compose, Preprocessor | |||
class ComposeTest(unittest.TestCase): | |||
def test_compose(self): | |||
@PREPROCESSORS.register_module() | |||
class Tmp1(Preprocessor): | |||
def __call__(self, input): | |||
input['tmp1'] = 'tmp1' | |||
return input | |||
@PREPROCESSORS.register_module() | |||
class Tmp2(Preprocessor): | |||
def __call__(self, input): | |||
input['tmp2'] = 'tmp2' | |||
return input | |||
pipeline = [ | |||
dict(type='Tmp1'), | |||
dict(type='Tmp2'), | |||
] | |||
trans = Compose(pipeline) | |||
input = {} | |||
output = trans(input) | |||
self.assertEqual(output['tmp1'], 'tmp1') | |||
self.assertEqual(output['tmp2'], 'tmp2') | |||
if __name__ == '__main__': | |||
unittest.main() |
@@ -0,0 +1,37 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import unittest | |||
from maas_lib.preprocessors import build_preprocessor | |||
from maas_lib.utils.constant import Fields, InputFields | |||
from maas_lib.utils.logger import get_logger | |||
logger = get_logger() | |||
class NLPPreprocessorTest(unittest.TestCase): | |||
def test_tokenize(self): | |||
cfg = dict(type='Tokenize', tokenizer_name='bert-base-cased') | |||
preprocessor = build_preprocessor(cfg, Fields.nlp) | |||
input = { | |||
InputFields.text: | |||
'Do not meddle in the affairs of wizards, ' | |||
'for they are subtle and quick to anger.' | |||
} | |||
output = preprocessor(input) | |||
self.assertTrue(InputFields.text in output) | |||
self.assertEqual(output['input_ids'], [ | |||
101, 2091, 1136, 1143, 13002, 1107, 1103, 5707, 1104, 16678, 1116, | |||
117, 1111, 1152, 1132, 11515, 1105, 3613, 1106, 4470, 119, 102 | |||
]) | |||
self.assertEqual( | |||
output['token_type_ids'], | |||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) | |||
self.assertEqual( | |||
output['attention_mask'], | |||
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) | |||
if __name__ == '__main__': | |||
unittest.main() |
@@ -10,8 +10,10 @@ class RegistryTest(unittest.TestCase): | |||
def test_register_class_no_task(self): | |||
MODELS = Registry('models') | |||
self.assertTrue(MODELS.name == 'models') | |||
self.assertTrue(MODELS.modules == {}) | |||
self.assertEqual(len(MODELS.modules), 0) | |||
self.assertTrue(default_group in MODELS.modules) | |||
self.assertTrue(MODELS.modules[default_group] == {}) | |||
self.assertEqual(len(MODELS.modules), 1) | |||
@MODELS.register_module(module_name='cls-resnet') | |||
class ResNetForCls(object): | |||
@@ -47,7 +49,7 @@ class RegistryTest(unittest.TestCase): | |||
self.assertTrue(Tasks.object_detection in MODELS.modules) | |||
self.assertTrue(MODELS.get('DETR', Tasks.object_detection) is DETR) | |||
self.assertEqual(len(MODELS.modules), 3) | |||
self.assertEqual(len(MODELS.modules), 4) | |||
def test_list(self): | |||
MODELS = Registry('models') | |||
@@ -0,0 +1,22 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import unittest | |||
from typing import List, Union | |||
from maas_lib.utils.type_assert import type_assert | |||
class type_assertTest(unittest.TestCase): | |||
@type_assert(object, list, (int, str)) | |||
def a(self, a: List[int], b: Union[int, str]): | |||
print(a, b) | |||
def test_type_assert(self): | |||
with self.assertRaises(TypeError): | |||
self.a([1], 2) | |||
self.a(1, [123]) | |||
if __name__ == '__main__': | |||
unittest.main() |