* add trainer interface
* add trainer script
* add model init support for pipeline
* add pipeline tutorial and fix bugs
* add text classification evaluation to maas lib
* add quickstart and prepare env doc
* relax requirements for torch and sentencepiece
* merge release/0.1 and fix conflict
* modelhub support for model and pipeline

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8868339
@@ -0,0 +1,131 @@
{
  "framework": "pytorch",
  "task": "image_classification",
  "model": {
    "type": "Resnet50ForImageClassification",
    "pretrained": null,
    "backbone": {
      "type": "ResNet",
      "depth": 50,
      "out_indices": [4],
      "norm_cfg": {
        "type": "BN"
      }
    },
    "head": {
      "type": "ClsHead",
      "with_avg_pool": true,
      "in_channels": 2048,
      "loss_config": {
        "type": "CrossEntropyLossWithLabelSmooth",
        "label_smooth": 0
      },
      "num_classes": 1000
    }
  },
  "dataset": {
    "train": {
      "type": "ClsDataset",
      "data_source": {
        "list_file": "data/imagenet_raw/meta/train_labeled.txt",
        "root": "data/imagenet_raw/train/",
        "type": "ClsSourceImageList"
      }
    },
    "val": {
      "type": "ClsDataset",
      "data_source": {
        "list_file": "data/imagenet_raw/meta/val_labeled.txt",
        "root": "data/imagenet_raw/validation/",
        "type": "ClsSourceImageList"
      }
    }
  },
  "preprocessor": {
    "train": [
      {
        "type": "RandomResizedCrop",
        "size": 224
      },
      {
        "type": "RandomHorizontalFlip"
      },
      {
        "type": "ToTensor"
      },
      {
        "type": "Normalize",
        "mean": [0.485, 0.456, 0.406],
        "std": [0.229, 0.224, 0.225]
      },
      {
        "type": "Collect",
        "keys": ["img", "gt_labels"]
      }
    ],
    "val": [
      {
        "type": "Resize",
        "size": 256
      },
      {
        "type": "CenterCrop",
        "size": 224
      },
      {
        "type": "ToTensor"
      },
      {
        "type": "Normalize",
        "mean": [0.485, 0.456, 0.406],
        "std": [0.229, 0.224, 0.225]
      },
      {
        "type": "Collect",
        "keys": ["img", "gt_labels"]
      }
    ]
  },
  "train": {
    "batch_size": 32,
    "learning_rate": 0.00001,
    "lr_scheduler_type": "cosine",
    "num_epochs": 20
  },
  "evaluation": {
    "batch_size": 32,
    "metrics": ["accuracy", "precision", "recall"]
  }
}
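For reference, a minimal sketch of how a config like this is consumed; `Config.from_file` and the attribute-style access are the APIs used by the trainers in this PR, and the file path is the one used by the dummy-trainer test below:

```python
from maas_lib.utils.config import Config

# parse the JSON config shown above
cfg = Config.from_file('configs/examples/train.json')
print(cfg.task)                 # image_classification
print(cfg.model.type)           # Resnet50ForImageClassification
print(cfg.train.batch_size)     # 32
print(cfg.evaluation.metrics)   # ['accuracy', 'precision', 'recall']
```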
@@ -0,0 +1,59 @@
# In the current version many arguments are not yet consumed by pipelines,
# so the tag `[being used]` marks the arguments that are actually used.
version: v0.1

framework: pytorch
task: text-classification

model:
  path: bert-base-sst2
  attention_probs_dropout_prob: 0.1
  bos_token_id: 0
  eos_token_id: 2
  hidden_act: elu
  hidden_dropout_prob: 0.1
  hidden_size: 768
  initializer_range: 0.02
  intermediate_size: 3072
  layer_norm_eps: 1e-05
  max_position_embeddings: 514
  model_type: roberta
  num_attention_heads: 12
  num_hidden_layers: 12
  pad_token_id: 1
  type_vocab_size: 1
  vocab_size: 50265
  num_classes: 5

col_index: &col_indexs
  text_col: 0
  label_col: 1

dataset:
  train:
    <<: *col_indexs
    file: ~
  valid:
    <<: *col_indexs
    file: glue/sst2 # [being used]
  test:
    <<: *col_indexs
    file: ~

preprocessor:
  type: Tokenize
  tokenizer_name: /workspace/bert-base-sst2

train:
  batch_size: 256
  learning_rate: 0.00001
  lr_scheduler_type: cosine
  num_steps: 100000

evaluation: # [being used]
  model_path: .cache/easynlp/bert-base-sst2
  max_sequence_length: 128
  batch_size: 32
  metrics:
    - accuracy
    - f1
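This config drives the text-classification trainer added below; a sketch of the evaluation flow, mirroring the trainer test in this PR (`bert-sentiment-analysis` is the registered trainer name, and the config path is the one the test uses):

```python
from maas_lib.trainers import build_trainer

trainer = build_trainer(
    'bert-sentiment-analysis',
    dict(cfg_file='configs/nlp/sequence_classification_trainer.yaml'))
# reads the evaluation.* fields from the YAML above;
# returns e.g. {'accuracy': ..., 'f1': ...}
metrics = trainer.evaluate()
```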
@@ -11,6 +11,7 @@ MaasLib doc
   :maxdepth: 2
   :caption: USER GUIDE

   quick_start.md
   develop.md

.. toctree::
@@ -0,0 +1,64 @@
# Quick Start

## Environment Setup

Option 1: install the whl package by running

```shell
pip install http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/maas_lib-0.1.0-py3-none-any.whl
```

Option 2: run from source, suited to local development and debugging; changes to the source take effect immediately

```shell
git clone git@gitlab.alibaba-inc.com:Ali-MaaS/MaaS-lib.git maaslib
cd maaslib
git fetch origin release/0.1
git checkout release/0.1
# install dependencies
pip install -r requirements.txt
# set PYTHONPATH
export PYTHONPATH=`pwd`
```

Note: on Apple Silicon (arm) Macs the requirements currently fail to install due to dependency version issues; please test on an Intel Mac or on a Linux CPU/GPU machine.
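A minimal import check to verify the environment (an illustrative sanity check, assuming the steps above succeeded):

```python
# these imports should succeed if the environment is set up correctly
from maas_lib.pipelines import pipeline
from maas_lib.trainers import build_trainer
```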
## Training
to be done

## Evaluation
to be done

## Inference
to be done

<!-- The pipeline function provides a concise inference interface, for example:
Note: the interface shown here is the one available once the modelhub integration is complete; it cannot be used yet. For working pipeline examples see the [pipeline tutorial](tutorials/pipeline.md).
```python
import cv2
from maas_lib.pipelines import pipeline
# create a pipeline from a task name
img_matting = pipeline('image-matting')
# create a pipeline from a task name and a model name
img_matting = pipeline('image-matting', model='damo/image-matting-person')
# create a pipeline with a custom model and preprocessor
model = Model.from_pretrained('damo/xxx')
preprocessor = Preprocessor.from_pretrained(cfg)
img_matting = pipeline('image-matting', model=model, preprocessor=preprocessor)
# inference
result = img_matting(
    'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'
)
# save the result image
cv2.imwrite('result.png', result['output_png'])
``` -->
@@ -11,6 +11,9 @@
* inference with a specific preprocessor and a specific model
* inference examples for different task scenarios

## Environment Setup

See [Quick Start](../quick_start.md) for detailed steps.

## Basic Pipeline Usage

1. The pipeline function accepts a task name, loads the default model for that task, and creates the corresponding Pipeline object
@@ -21,7 +24,7 @@
```shell
wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/matting_person.pb
```
Run the python command
Run the following python code
```python
>>> from maas_lib.pipelines import pipeline
>>> img_matting = pipeline(task='image-matting', model_path='matting_person.pb')
@@ -36,7 +39,7 @@
A pipeline object also accepts a list of inputs and returns the corresponding list of outputs, one element per input sample
```python
results = img_matting(
>>> results = img_matting(
    [
        'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png',
        'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png',
@@ -46,8 +49,8 @@
If a pipeline has post-processing parameters, they can also be passed in at call time.
```python
pipe = pipeline(task_name)
result = pipe(input, post_process_args)
>>> pipe = pipeline(task_name)
>>> result = pipe(input, post_process_args)
```

## Inference with a Specific Preprocessor and Model
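A sketch of this usage for text classification, mirroring the test added in this PR (the model id `damo/bert-base-sst2` and the preprocessor arguments come from that test):

```python
>>> from maas_lib.models import Model
>>> from maas_lib.pipelines import pipeline
>>> from maas_lib.preprocessors import SequenceClassificationPreprocessor
>>> model = Model.from_pretrained('damo/bert-base-sst2')
>>> preprocessor = SequenceClassificationPreprocessor(
...     model.model_dir, first_sequence='sentence', second_sequence=None)
>>> pipe = pipeline(task='text-classification', model=model, preprocessor=preprocessor)
>>> pipe('Hello world!')
```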
@@ -1,4 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from .base import Model
from .builder import MODELS
from .builder import MODELS, build_model
@@ -1,15 +1,23 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Union

from maas_hub.file_download import model_file_download
from maas_hub.snapshot_download import snapshot_download

from maas_lib.models.builder import build_model
from maas_lib.utils.config import Config
from maas_lib.utils.constant import CONFIGFILE

Tensor = Union['torch.Tensor', 'tf.Tensor']


class Model(ABC):

    def __init__(self, *args, **kwargs):
        pass
    def __init__(self, model_dir, *args, **kwargs):
        self.model_dir = model_dir

    def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        return self.post_process(self.forward(input))
@@ -26,4 +34,22 @@ class Model(ABC):

    @classmethod
    def from_pretrained(cls, model_name_or_path: str, *model_args, **kwargs):
        raise NotImplementedError('from_pretrained has not been implemented')
        """ Instantiate a model from a local directory or a remote model repo.
        """
        if osp.exists(model_name_or_path):
            local_model_dir = model_name_or_path
        else:
            local_model_dir = snapshot_download(model_name_or_path)
        # else:
        #     raise ValueError(
        #         'Remote model repo {model_name_or_path} does not exist')
        cfg = Config.from_file(osp.join(local_model_dir, CONFIGFILE))
        task_name = cfg.task
        model_cfg = cfg.model
        # TODO @wenmeng.zwm we may need to manually initialize the model after building
        if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
            model_cfg.type = model_cfg.model_type
        model_cfg.model_dir = local_model_dir
        return build_model(model_cfg, task_name)
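A usage sketch of the new `from_pretrained` path (the model id is the one used in the tests; a local directory works too, as long as it contains `maas_config.json`):

```python
from maas_lib.models import Model

# remote repo id: snapshot_download fetches the whole model directory
model = Model.from_pretrained('damo/bert-base-sst2')
# a local directory containing maas_config.json is used as-is
model = Model.from_pretrained(model.model_dir)
```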
@@ -14,30 +14,21 @@ __all__ = ['SequenceClassificationModel']
    Tasks.text_classification, module_name=r'bert-sentiment-analysis')
class SequenceClassificationModel(Model):

    def __init__(self,
                 model_dir: str,
                 model_cls: Optional[Any] = None,
                 *args,
                 **kwargs):
    def __init__(self, model_dir: str, *args, **kwargs):
        # Model.__init__(self, model_dir, model_cls, first_sequence, *args, **kwargs)
        # Predictor.__init__(self, *args, **kwargs)
        """initialize the sequence classification model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
            model_cls (Optional[Any], optional): model loader, if None, use the
                default loader to load model weights, by default None.
        """
        super().__init__(model_dir, model_cls, *args, **kwargs)
        super().__init__(model_dir, *args, **kwargs)
        from easynlp.appzoo import SequenceClassification
        from easynlp.core.predictor import get_model_predictor
        self.model_dir = model_dir
        model_cls = SequenceClassification if not model_cls else model_cls
        self.model = get_model_predictor(
            model_dir=model_dir,
            model_cls=model_cls,
            model_dir=self.model_dir,
            model_cls=SequenceClassification,
            input_keys=[('input_ids', torch.LongTensor),
                        ('attention_mask', torch.LongTensor),
                        ('token_type_ids', torch.LongTensor)],
@@ -59,4 +50,3 @@ class SequenceClassificationModel(Model):
            }
        """
        return self.model.predict(input)
...
@@ -1,10 +1,17 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple, Union

from maas_hub.snapshot_download import snapshot_download

from maas_lib.models import Model
from maas_lib.preprocessors import Preprocessor
from maas_lib.utils.config import Config
from maas_lib.utils.constant import CONFIGFILE
from .util import is_model_name

Tensor = Union['torch.Tensor', 'tf.Tensor']
Input = Union[str, 'PIL.Image.Image', 'numpy.ndarray']
@@ -17,10 +24,38 @@ class Pipeline(ABC):

    def __init__(self,
                 config_file: str = None,
                 model: Model = None,
                 model: Union[Model, str] = None,
                 preprocessor: Preprocessor = None,
                 **kwargs):
        self.model = model
        """ Base class for pipelines.

        If config_file is provided, the model and preprocessor will be
        instantiated from the corresponding config. Otherwise the model
        and preprocessor will be constructed separately.

        Args:
            config_file(str, optional): Filepath to the configuration file.
            model: Model name or model object.
            preprocessor: Preprocessor object.
        """
        if config_file is not None:
            self.cfg = Config.from_file(config_file)
        if isinstance(model, str):
            if not osp.exists(model):
                model = snapshot_download(model)
            if is_model_name(model):
                self.model = Model.from_pretrained(model)
            else:
                self.model = model
        elif isinstance(model, Model):
            self.model = model
        else:
            if model:
                raise ValueError(
                    f'model type is either str or Model, but got type {type(model)}'
                )
        self.preprocessor = preprocessor

    def __call__(self, input: Union[Input, List[Input]], *args,
@@ -1,12 +1,17 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from typing import Union

import json

from maas_hub.file_download import model_file_download

from maas_lib.models.base import Model
from maas_lib.utils.config import ConfigDict
from maas_lib.utils.constant import Tasks
from maas_lib.utils.config import Config, ConfigDict
from maas_lib.utils.constant import CONFIGFILE, Tasks
from maas_lib.utils.registry import Registry, build_from_cfg

from .base import Pipeline
from .util import is_model_name

PIPELINES = Registry('pipelines')
@@ -57,23 +62,26 @@ def pipeline(task: str = None,
    >>> resnet = Model.from_pretrained('Resnet')
    >>> p = pipeline('image-classification', model=resnet)
    """
    if task is not None and pipeline_name is None:
        if model is None or isinstance(model, Model):
            # get the default pipeline for this task
            assert task in PIPELINES.modules, f'No pipeline is registered for Task {task}'
            pipeline_name = list(PIPELINES.modules[task].keys())[0]
            cfg = dict(type=pipeline_name, **kwargs)
            if model is not None:
                cfg['model'] = model
            if preprocessor is not None:
                cfg['preprocessor'] = preprocessor
        else:
            assert isinstance(model, str), \
                f'model should be either str or Model, but got {type(model)}'
            # TODO @wenmeng.zwm determine pipeline_name according to task and model
    elif pipeline_name is not None:
        cfg = dict(type=pipeline_name)
    else:
    if task is None and pipeline_name is None:
        raise ValueError('task or pipeline_name is required')

    if pipeline_name is None:
        # get the default pipeline for this task
        assert task in PIPELINES.modules, f'No pipeline is registered for Task {task}'
        pipeline_name = get_default_pipeline(task)

    cfg = ConfigDict(type=pipeline_name)

    if model:
        assert isinstance(model, (str, Model)), \
            f'model should be either str or Model, but got {type(model)}'
        cfg.model = model

    if preprocessor is not None:
        cfg.preprocessor = preprocessor

    return build_pipeline(cfg, task_name=task)


def get_default_pipeline(task):
    return list(PIPELINES.modules[task].keys())[0]
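The call patterns this enables, repeating the docstring example plus the hub-id form used in the tests:

```python
from maas_lib.models import Model
from maas_lib.pipelines import pipeline

# default pipeline for a task; the model id is resolved through the hub
img_matting = pipeline(task='image-matting', model='damo/image-matting-person')

# as in the docstring above: pass an already constructed Model object
resnet = Model.from_pretrained('Resnet')
p = pipeline('image-classification', model=resnet)
```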
@@ -1,3 +1,4 @@
import os.path as osp
from typing import Any, Dict, List, Tuple, Union

import cv2
@@ -23,8 +24,9 @@ logger = get_logger()
    Tasks.image_matting, module_name=Tasks.image_matting)
class ImageMatting(Pipeline):

    def __init__(self, model_path: str):
        super().__init__()
    def __init__(self, model: str):
        super().__init__(model=model)
        model_path = osp.join(self.model, 'matting_person.pb')
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
@@ -0,0 +1,29 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp

import json

from maas_hub.file_download import model_file_download

from maas_lib.utils.constant import CONFIGFILE


def is_model_name(model):
    if osp.exists(model):
        if osp.exists(osp.join(model, CONFIGFILE)):
            return True
        else:
            return False
    else:
        # try:
        #     cfg_file = model_file_download(model, CONFIGFILE)
        # except Exception:
        #     cfg_file = None
        # TODO @wenmeng.zwm use an exception instead of
        # the following tricky logic
        cfg_file = model_file_download(model, CONFIGFILE)
        with open(cfg_file, 'r') as infile:
            cfg = json.load(infile)
        if 'Code' in cfg:
            return False
        else:
            return True
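A sketch of the two branches (paths illustrative; `maas_config.json` is the CONFIGFILE constant added in this PR):

```python
from maas_lib.pipelines.util import is_model_name

# local path: True only if the directory contains maas_config.json
is_model_name('.cache/easynlp/bert-base-sst2')
# otherwise the config is downloaded from the hub and checked for an
# error payload (a 'Code' field) to decide whether the repo exists
is_model_name('damo/bert-base-sst2')
```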
@@ -0,0 +1,30 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse

from maas_lib.trainers import build_trainer


def parse_args():
    parser = argparse.ArgumentParser(description='evaluate a model')
    parser.add_argument('config', help='config file path', type=str)
    parser.add_argument(
        '--trainer_name', help='name for trainer', type=str, default=None)
    parser.add_argument(
        '--checkpoint_path',
        help='checkpoint to be evaluated',
        type=str,
        default=None)
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    kwargs = dict(cfg_file=args.config)
    trainer = build_trainer(args.trainer_name, kwargs)
    trainer.evaluate(args.checkpoint_path)


if __name__ == '__main__':
    main()
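Assuming this script is saved as `evaluate.py` (location illustrative), an invocation against the YAML config above might look like `python evaluate.py configs/nlp/sequence_classification_trainer.yaml --trainer_name bert-sentiment-analysis --checkpoint_path .cache/easynlp/bert-base-sst2`.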
@@ -0,0 +1,25 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse

from maas_lib.trainers import build_trainer


def parse_args():
    parser = argparse.ArgumentParser(description='Train a model')
    parser.add_argument('config', help='config file path', type=str)
    parser.add_argument(
        '--trainer_name', help='name for trainer', type=str, default=None)
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    kwargs = dict(cfg_file=args.config)
    trainer = build_trainer(args.trainer_name, kwargs)
    trainer.train()


if __name__ == '__main__':
    main()
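Correspondingly (script name illustrative): `python train.py configs/examples/train.json --trainer_name dummy`.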
@@ -0,0 +1,3 @@
from .base import DummyTrainer
from .builder import build_trainer
from .nlp import SequenceClassificationTrainer
@@ -0,0 +1,86 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Optional, Tuple, Union

from maas_lib.trainers.builder import TRAINERS
from maas_lib.utils.config import Config


class BaseTrainer(ABC):
    """ Base class for trainers; it cannot be instantiated directly.

    BaseTrainer defines the necessary interface and provides default
    implementations for basic initialization, such as parsing the config
    file and parsing command-line args.
    """

    def __init__(self, cfg_file: str, arg_parse_fn: Optional[Callable] = None):
        """ Trainer basic initialization; should be called in derived classes.

        Args:
            cfg_file: Path to the configuration file.
            arg_parse_fn: Same as ``parse_fn`` in :obj:`Config.to_args`.
        """
        self.cfg = Config.from_file(cfg_file)
        if arg_parse_fn:
            self.args = self.cfg.to_args(arg_parse_fn)
        else:
            self.args = None

    @abstractmethod
    def train(self, *args, **kwargs):
        """ Train (and evaluate) process.

        The train process should be implemented for a specific task or
        model; related parameters have been initialized in
        ``BaseTrainer.__init__`` and should be used in this function.
        """
        pass

    @abstractmethod
    def evaluate(self, checkpoint_path: str, *args,
                 **kwargs) -> Dict[str, float]:
        """ Evaluation process.

        The evaluation process should be implemented for a specific task or
        model; related parameters have been initialized in
        ``BaseTrainer.__init__`` and should be used in this function.
        """
        pass


@TRAINERS.register_module(module_name='dummy')
class DummyTrainer(BaseTrainer):

    def __init__(self, cfg_file: str, *args, **kwargs):
        """ Dummy trainer.

        Args:
            cfg_file: Path to the configuration file.
        """
        super().__init__(cfg_file)

    def train(self, *args, **kwargs):
        """ Train (and evaluate) process.

        The train process should be implemented for a specific task or
        model; related parameters have been initialized in
        ``BaseTrainer.__init__`` and should be used in this function.
        """
        cfg = self.cfg.train
        print(f'train cfg {cfg}')

    def evaluate(self,
                 checkpoint_path: str = None,
                 *args,
                 **kwargs) -> Dict[str, float]:
        """ Evaluation process.

        The evaluation process should be implemented for a specific task or
        model; related parameters have been initialized in
        ``BaseTrainer.__init__`` and should be used in this function.
        """
        cfg = self.cfg.evaluation
        print(f'eval cfg {cfg}')
        print(f'checkpoint_path {checkpoint_path}')
@@ -0,0 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from maas_lib.utils.config import ConfigDict
from maas_lib.utils.constant import Tasks
from maas_lib.utils.registry import Registry, build_from_cfg

TRAINERS = Registry('trainers')


def build_trainer(name: str = None, default_args: dict = None):
    """ Build a trainer given a trainer name.

    Args:
        name (str, optional): Trainer name; if None, the default trainer
            will be used.
        default_args (dict, optional): Default initialization arguments.
    """
    if name is None:
        name = 'Trainer'
    cfg = dict(type=name)
    return build_from_cfg(cfg, TRAINERS, default_args=default_args)
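A usage sketch matching the dummy-trainer test added in this PR:

```python
from maas_lib.trainers import build_trainer

# 'dummy' is the module_name registered by DummyTrainer above
trainer = build_trainer('dummy', dict(cfg_file='configs/examples/train.json'))
trainer.train()
trainer.evaluate()
```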
@@ -0,0 +1 @@
from .sequence_classification_trainer import SequenceClassificationTrainer
@@ -0,0 +1,226 @@
import time
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np

from maas_lib.utils.constant import Tasks
from maas_lib.utils.logger import get_logger
from ..base import BaseTrainer
from ..builder import TRAINERS

# __all__ = ["SequenceClassificationTrainer"]

PATH = None
logger = get_logger(PATH)


@TRAINERS.register_module(
    Tasks.text_classification, module_name=r'bert-sentiment-analysis')
class SequenceClassificationTrainer(BaseTrainer):

    def __init__(self, cfg_file: str, *args, **kwargs):
        """ A trainer for sequence classification.

        Based on a config file (*.yaml or *.json), the trainer trains or
        evaluates a model on a dataset.

        Args:
            cfg_file (str): the path of the config file
        """
        super().__init__(cfg_file)

    def train(self, *args, **kwargs):
        logger.info('Train')
        ...

    def __attr_is_exist(self, attr: str) -> Tuple[str, Union[str, bool]]:
        """get an attribute from the config; if the attribute does not exist, return False

        Example:
            >>> self.__attr_is_exist("model path")
            out: ("model-path", "/workspace/bert-base-sst2")
            >>> self.__attr_is_exist("model weights")
            out: ("model-weights", False)

        Args:
            attr (str): attribute str, "model path" -> config["model"]["path"]

        Returns:
            Tuple[str, Union[str, bool]]: [target attribute name, the target attribute or False]
        """
        paths = attr.split(' ')
        attr_str: str = '-'.join(paths)
        target = self.cfg[paths[0]] if hasattr(self.cfg, paths[0]) else None

        for path_ in paths[1:]:
            if not hasattr(target, path_):
                return attr_str, False
            target = target[path_]

        if target and target != '':
            return attr_str, target
        return attr_str, False

    def evaluate(self,
                 checkpoint_path: Optional[str] = None,
                 *args,
                 **kwargs) -> Dict[str, float]:
        """evaluate a dataset

        Evaluate a dataset via a specific model loaded from `checkpoint_path`;
        if `checkpoint_path` is not given, it is read from the config file.

        Args:
            checkpoint_path (Optional[str], optional): the model path. Defaults to None.

        Returns:
            Dict[str, float]: the evaluation results
                Example:
                    {"accuracy": 0.5091743119266054, "f1": 0.673780487804878}
        """
        import torch
        from easynlp.appzoo import load_dataset
        from easynlp.appzoo.dataset import GeneralDataset
        from easynlp.appzoo.sequence_classification.model import SequenceClassification
        from easynlp.utils import losses
        from sklearn.metrics import f1_score
        from torch.utils.data import DataLoader

        raise_str = 'Attribute {} is not given in config file!'
        metrics = self.__attr_is_exist('evaluation metrics')
        eval_batch_size = self.__attr_is_exist('evaluation batch_size')
        test_dataset_path = self.__attr_is_exist('dataset valid file')

        attrs = [metrics, eval_batch_size, test_dataset_path]
        for attr_ in attrs:
            if not attr_[-1]:
                raise AttributeError(raise_str.format(attr_[0]))

        if not checkpoint_path:
            checkpoint_path = self.__attr_is_exist('evaluation model_path')[-1]
        if not checkpoint_path:
            raise ValueError(
                'Argument checkpoint_path must be passed if evaluation-model_path is not given in the config file!'
            )
        max_sequence_length = kwargs.get(
            'max_sequence_length',
            self.__attr_is_exist('evaluation max_sequence_length')[-1])
        if not max_sequence_length:
            raise ValueError(
                'Argument max_sequence_length must be passed '
                'if evaluation-max_sequence_length does not exist in the config file!'
            )

        # get the raw online dataset
        raw_dataset = load_dataset(*test_dataset_path[-1].split('/'))
        valid_dataset = raw_dataset['validation']
        # generate a standard dataloader
        pre_dataset = GeneralDataset(valid_dataset, checkpoint_path,
                                     max_sequence_length)
        valid_dataloader = DataLoader(
            pre_dataset,
            batch_size=eval_batch_size[-1],
            shuffle=False,
            collate_fn=pre_dataset.batch_fn)
        # generate a model
        model = SequenceClassification(checkpoint_path)

        # copied from easynlp (start)
        model.eval()
        total_loss = 0
        total_steps = 0
        total_samples = 0

        hit_num = 0
        total_num = 0

        logits_list = list()
        y_trues = list()

        total_spent_time = 0.0
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        model.to(device)
        for _step, batch in enumerate(valid_dataloader):
            try:
                batch = {
                    # key: val.cuda() if isinstance(val, torch.Tensor) else val
                    # for key, val in batch.items()
                    key:
                    val.to(device) if isinstance(val, torch.Tensor) else val
                    for key, val in batch.items()
                }
            except RuntimeError:
                batch = {key: val for key, val in batch.items()}

            infer_start_time = time.time()
            with torch.no_grad():
                label_ids = batch.pop('label_ids')
                outputs = model(batch)
            infer_end_time = time.time()
            total_spent_time += infer_end_time - infer_start_time

            assert 'logits' in outputs
            logits = outputs['logits']

            y_trues.extend(label_ids.tolist())
            logits_list.extend(logits.tolist())

            hit_num += torch.sum(
                torch.argmax(logits, dim=-1) == label_ids).item()
            total_num += label_ids.shape[0]

            if len(logits.shape) == 1 or logits.shape[-1] == 1:
                tmp_loss = losses.mse_loss(logits, label_ids)
            elif len(logits.shape) == 2:
                tmp_loss = losses.cross_entropy(logits, label_ids)
            else:
                raise RuntimeError

            total_loss += tmp_loss.mean().item()
            total_steps += 1
            total_samples += valid_dataloader.batch_size
            if (_step + 1) % 100 == 0:
                total_step = len(
                    valid_dataloader.dataset) // valid_dataloader.batch_size
                logger.info('Eval: {}/{} steps finished'.format(
                    _step + 1, total_step))

        logger.info('Inference time = {:.2f}s, [{:.4f} ms / sample] '.format(
            total_spent_time, total_spent_time * 1000 / total_samples))

        eval_loss = total_loss / total_steps
        logger.info('Eval loss: {}'.format(eval_loss))

        logits_list = np.array(logits_list)

        eval_outputs = list()
        for metric in metrics[-1]:
            if metric.endswith('accuracy'):
                acc = hit_num / total_num
                logger.info('Accuracy: {}'.format(acc))
                eval_outputs.append(('accuracy', acc))
            elif metric == 'f1':
                if model.config.num_labels == 2:
                    f1 = f1_score(y_trues, np.argmax(logits_list, axis=-1))
                    logger.info('F1: {}'.format(f1))
                    eval_outputs.append(('f1', f1))
                else:
                    f1 = f1_score(
                        y_trues,
                        np.argmax(logits_list, axis=-1),
                        average='macro')
                    logger.info('Macro F1: {}'.format(f1))
                    eval_outputs.append(('macro-f1', f1))
                    f1 = f1_score(
                        y_trues,
                        np.argmax(logits_list, axis=-1),
                        average='micro')
                    logger.info('Micro F1: {}'.format(f1))
                    eval_outputs.append(('micro-f1', f1))
            else:
                raise NotImplementedError('Metric %s not implemented' % metric)
        # copied from easynlp (end)

        return dict(eval_outputs)
@@ -62,3 +62,9 @@ class InputFields(object):
    img = 'img'
    text = 'text'
    audio = 'audio'


# configuration filename: we use maas_config.json instead to
# avoid conflicts with the huggingface config file
CONFIGFILE = 'maas_config.json'
@@ -1,5 +1,6 @@
http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/package/whl/easynlp-0.0.3-py2.py3-none-any.whl
#https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/package/whl/easynlp-0.0.4-py2.py3-none-any.whl
tensorflow
torch==1.9.1
torchaudio==0.9.1
torchvision==0.10.1
#--find-links https://download.pytorch.org/whl/torch_stable.html
torch<1.10,>=1.8.0
torchaudio
torchvision
@@ -1,4 +1,5 @@
addict
https://maashub.oss-cn-hangzhou.aliyuncs.com/releases/maas_hub-0.1.0.dev0-py2.py3-none-any.whl
numpy
opencv-python-headless
Pillow
@@ -113,26 +113,39 @@ def parse_requirements(fname='requirements.txt', with_version=True):
                if line.startswith('http'):
                    print('skip http requirements %s' % line)
                    continue
                if line and not line.startswith('#'):
                if line and not line.startswith('#') and not line.startswith(
                        '--'):
                    for info in parse_line(line):
                        yield info
                elif line and line.startswith('--find-links'):
                    eles = line.split()
                    for e in eles:
                        e = e.strip()
                        if 'http' in e:
                            info = dict(dependency_links=e)
                            yield info

    def gen_packages_items():
        items = []
        deps_link = []
        if exists(require_fpath):
            for info in parse_require_file(require_fpath):
                parts = [info['package']]
                if with_version and 'version' in info:
                    parts.extend(info['version'])
                if not sys.version.startswith('3.4'):
                    # apparently package_deps are broken in 3.4
                    platform_deps = info.get('platform_deps')
                    if platform_deps is not None:
                        parts.append(';' + platform_deps)
                item = ''.join(parts)
                yield item

    packages = list(gen_packages_items())
    return packages
                if 'dependency_links' not in info:
                    parts = [info['package']]
                    if with_version and 'version' in info:
                        parts.extend(info['version'])
                    if not sys.version.startswith('3.4'):
                        # apparently package_deps are broken in 3.4
                        platform_deps = info.get('platform_deps')
                        if platform_deps is not None:
                            parts.append(';' + platform_deps)
                    item = ''.join(parts)
                    items.append(item)
                else:
                    deps_link.append(info['dependency_links'])
        return items, deps_link

    return gen_packages_items()


def pack_resource():
@@ -155,7 +168,7 @@ if __name__ == '__main__':
    # write_version_py()
    pack_resource()
    os.chdir('package')
    install_requires = parse_requirements('requirements.txt')
    install_requires, deps_link = parse_requirements('requirements.txt')
    setup(
        name='maas-lib',
        version=get_version(),
@@ -180,4 +193,5 @@ if __name__ == '__main__':
        license='Apache License 2.0',
        tests_require=parse_requirements('requirements/tests.txt')[0],
        install_requires=install_requires,
        dependency_links=deps_link,
        zip_safe=False)
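Illustratively (values assumed), a requirements file containing `--find-links https://download.pytorch.org/whl/torch_stable.html` and `torch<1.10,>=1.8.0` now yields `install_requires == ['torch<1.10,>=1.8.0']` and `deps_link == ['https://download.pytorch.org/whl/torch_stable.html']`, so the find-links URL ends up in setup's `dependency_links` instead of being parsed as a package spec.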
@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import tempfile
import unittest
from typing import Any, Dict, List, Tuple, Union
@@ -18,15 +19,26 @@ class ImageMattingTest(unittest.TestCase):
    def test_run(self):
        model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \
                     '.com/data/test/maas/image_matting/matting_person.pb'
        with tempfile.NamedTemporaryFile('wb', suffix='.pb') as ofile:
            ofile.write(File.read(model_path))
            img_matting = pipeline(Tasks.image_matting, model_path=ofile.name)
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_file = osp.join(tmp_dir, 'matting_person.pb')
            with open(model_file, 'wb') as ofile:
                ofile.write(File.read(model_path))
            img_matting = pipeline(Tasks.image_matting, model=tmp_dir)
            result = img_matting(
                'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'
            )
            cv2.imwrite('result.png', result['output_png'])

    def test_run_modelhub(self):
        img_matting = pipeline(
            Tasks.image_matting, model='damo/image-matting-person')
        result = img_matting(
            'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'
        )
        cv2.imwrite('result.png', result['output_png'])


if __name__ == '__main__':
    unittest.main()
@@ -1,11 +1,11 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import tempfile
import unittest
import zipfile
from pathlib import Path

from maas_lib.fileio import File
from maas_lib.models import Model
from maas_lib.models.nlp import SequenceClassificationModel
from maas_lib.pipelines import SequenceClassificationPipeline, pipeline
from maas_lib.preprocessors import SequenceClassificationPreprocessor
@@ -13,15 +13,15 @@ from maas_lib.preprocessors import SequenceClassificationPreprocessor
class SequenceClassificationTest(unittest.TestCase):

    def predict(self, pipeline: SequenceClassificationPipeline):
    def predict(self, pipeline_ins: SequenceClassificationPipeline):
        from easynlp.appzoo import load_dataset
        set = load_dataset('glue', 'sst2')
        data = set['test']['sentence'][:3]
        results = pipeline(data[0])
        results = pipeline_ins(data[0])
        print(results)
        results = pipeline(data[1])
        results = pipeline_ins(data[1])
        print(results)
        print(data)
@@ -29,22 +29,34 @@ class SequenceClassificationTest(unittest.TestCase):
    def test_run(self):
        model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \
                    '/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip'
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_file = osp.join(tmp_dir, 'bert-base-sst2.zip')
            with open(tmp_file, 'wb') as ofile:
        cache_path_str = r'.cache/easynlp/bert-base-sst2.zip'
        cache_path = Path(cache_path_str)
        if not cache_path.exists():
            cache_path.parent.mkdir(parents=True, exist_ok=True)
            cache_path.touch(exist_ok=True)
            with cache_path.open('wb') as ofile:
                ofile.write(File.read(model_url))
            with zipfile.ZipFile(tmp_file, 'r') as zipf:
                zipf.extractall(tmp_dir)
            path = osp.join(tmp_dir, 'bert-base-sst2')
            print(path)
            model = SequenceClassificationModel(path)
            preprocessor = SequenceClassificationPreprocessor(
                path, first_sequence='sentence', second_sequence=None)
            pipeline1 = SequenceClassificationPipeline(model, preprocessor)
            self.predict(pipeline1)
            pipeline2 = pipeline(
                'text-classification', model=model, preprocessor=preprocessor)
            print(pipeline2('Hello world!'))
            with zipfile.ZipFile(cache_path_str, 'r') as zipf:
                zipf.extractall(cache_path.parent)
        path = r'.cache/easynlp/bert-base-sst2'
        model = SequenceClassificationModel(path)
        preprocessor = SequenceClassificationPreprocessor(
            path, first_sequence='sentence', second_sequence=None)
        pipeline1 = SequenceClassificationPipeline(model, preprocessor)
        self.predict(pipeline1)
        pipeline2 = pipeline(
            'text-classification', model=model, preprocessor=preprocessor)
        print(pipeline2('Hello world!'))

    def test_run_modelhub(self):
        model = Model.from_pretrained('damo/bert-base-sst2')
        preprocessor = SequenceClassificationPreprocessor(
            model.model_dir, first_sequence='sentence', second_sequence=None)
        pipeline_ins = pipeline(
            task='text-classification', model=model, preprocessor=preprocessor)
        self.predict(pipeline_ins)


if __name__ == '__main__':
@@ -0,0 +1,38 @@
import unittest
import zipfile
from pathlib import Path

from maas_lib.fileio import File
from maas_lib.trainers import build_trainer
from maas_lib.utils.logger import get_logger

logger = get_logger()


class SequenceClassificationTrainerTest(unittest.TestCase):

    def test_sequence_classification(self):
        model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \
                    '/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip'
        cache_path_str = r'.cache/easynlp/bert-base-sst2.zip'
        cache_path = Path(cache_path_str)
        if not cache_path.exists():
            cache_path.parent.mkdir(parents=True, exist_ok=True)
            cache_path.touch(exist_ok=True)
            with cache_path.open('wb') as ofile:
                ofile.write(File.read(model_url))
            with zipfile.ZipFile(cache_path_str, 'r') as zipf:
                zipf.extractall(cache_path.parent)

        path: str = './configs/nlp/sequence_classification_trainer.yaml'
        default_args = dict(cfg_file=path)
        trainer = build_trainer('bert-sentiment-analysis', default_args)
        trainer.train()
        trainer.evaluate()


if __name__ == '__main__':
    unittest.main()
@@ -0,0 +1,19 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from maas_lib.trainers import build_trainer


class DummyTrainerTest(unittest.TestCase):

    def test_dummy(self):
        default_args = dict(cfg_file='configs/examples/train.json')
        trainer = build_trainer('dummy', default_args)
        trainer.train()
        trainer.evaluate()


if __name__ == '__main__':
    unittest.main()