chinese word segmentation

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9051491

* add word segmentation
* Merge branch 'master' of http://gitlab.alibaba-inc.com/Ali-MaaS/MaaS-lib
* test with model hub
* merge with master
* update some description and test levels
* adding purge logic in test
* merge with master
* update variables definition
* generic word segmentation model as token classification model
* add output check
@@ -1,3 +1,4 @@
from .sentence_similarity_model import * # noqa F403
from .sequence_classification_model import * # noqa F403
from .text_generation_model import * # noqa F403
from .token_classification_model import * # noqa F403
@@ -0,0 +1,57 @@
import os
from typing import Any, Dict, Union

import numpy as np
import torch
from sofa import SbertConfig, SbertForTokenClassification

from modelscope.utils.constant import Tasks
from ..base import Model, Tensor
from ..builder import MODELS

__all__ = ['StructBertForTokenClassification']


@MODELS.register_module(
    Tasks.word_segmentation,
    module_name=r'structbert-chinese-word-segmentation')
class StructBertForTokenClassification(Model):

    def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the word segmentation model from the `model_dir` path. | |||
Args: | |||
model_dir (str): the model path. | |||
model_cls (Optional[Any], optional): model loader, if None, use the | |||
default loader to load model weights, by default None. | |||
""" | |||
        super().__init__(model_dir, *args, **kwargs)
        self.model_dir = model_dir
        self.model = SbertForTokenClassification.from_pretrained(
            self.model_dir)
        self.config = SbertConfig.from_pretrained(self.model_dir)

    def forward(self, input: Dict[str,
                                  Any]) -> Dict[str, Union[str, np.ndarray]]:
"""return the result by the model | |||
Args: | |||
input (Dict[str, Any]): the preprocessed data | |||
Returns: | |||
Dict[str, Union[str,np.ndarray]]: results | |||
Example: | |||
{ | |||
'predictions': array([1,4]), # lable 0-negative 1-positive | |||
'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value | |||
'text': str(今天), | |||
} | |||
""" | |||
        input_ids = torch.tensor(input['input_ids']).unsqueeze(0)  # add batch dim
        output = self.model(input_ids)
        logits = output.logits
        # pick the highest-scoring label id for every token of the single sample
        pred = torch.argmax(logits[0], dim=-1)
        pred = pred.numpy()
        rst = {'predictions': pred, 'logits': logits, 'text': input['text']}
        return rst
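For orientation, a minimal sketch of driving this model directly with a preprocessed input; the local path is a placeholder and the printed values are illustrative (the full flow is exercised in the test file at the end of this review):

```python
from modelscope.models.nlp import StructBertForTokenClassification
from modelscope.preprocessors import TokenClassifcationPreprocessor

# placeholder path: a downloaded snapshot of the word-segmentation model
cache_path = '/path/to/nlp_structbert_word-segmentation_chinese-base'

preprocessor = TokenClassifcationPreprocessor(cache_path)
model = StructBertForTokenClassification(cache_path)

inputs = preprocessor('今天天气不错')  # dict with text / input_ids / attention_mask / token_type_ids
outputs = model.forward(inputs)        # forward() batches input_ids and runs the Sbert backbone
print(outputs['predictions'])          # per-token label ids, one per character plus special tokens
```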
@@ -13,6 +13,9 @@ PIPELINES = Registry('pipelines')

DEFAULT_MODEL_FOR_PIPELINE = {
    # TaskName: (pipeline_module_name, model_repo)
    Tasks.word_segmentation:
    ('structbert-chinese-word-segmentation',
     'damo/nlp_structbert_word-segmentation_chinese-base'),
    Tasks.sentence_similarity:
    ('sbert-base-chinese-sentence-similarity',
     'damo/nlp_structbert_sentence-similarity_chinese-base'),
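This mapping is what lets the task name alone resolve to a model. A small sketch of that path, assuming access to the model hub (the expected output mirrors the TASK_OUTPUTS example further down):

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# no model given: pipeline() falls back to the DEFAULT_MODEL_FOR_PIPELINE entry,
# i.e. damo/nlp_structbert_word-segmentation_chinese-base
word_segmentation = pipeline(task=Tasks.word_segmentation)
print(word_segmentation(input='今天天气不错,适合出去游玩'))
# -> {'output': '今天 天气 不错 , 适合 出去 游玩'}
```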
@@ -1,3 +1,4 @@
from .sentence_similarity_pipeline import * # noqa F403
from .sequence_classification_pipeline import * # noqa F403
from .text_generation_pipeline import * # noqa F403
from .word_segmentation_pipeline import * # noqa F403
@@ -0,0 +1,71 @@
from typing import Any, Dict, Optional, Union

import numpy as np

from modelscope.models import Model
from modelscope.models.nlp import StructBertForTokenClassification
from modelscope.preprocessors import TokenClassifcationPreprocessor
from modelscope.utils.constant import Tasks
from ..base import Pipeline, Tensor
from ..builder import PIPELINES

__all__ = ['WordSegmentationPipeline']


@PIPELINES.register_module(
    Tasks.word_segmentation,
    module_name=r'structbert-chinese-word-segmentation')
class WordSegmentationPipeline(Pipeline):

    def __init__(self,
                 model: Union[StructBertForTokenClassification, str],
                 preprocessor: Optional[TokenClassifcationPreprocessor] = None,
                 **kwargs):
"""use `model` and `preprocessor` to create a nlp word segmentation pipeline for prediction | |||
Args: | |||
model (StructBertForTokenClassification): a model instance | |||
preprocessor (TokenClassifcationPreprocessor): a preprocessor instance | |||
""" | |||
        model = model if isinstance(
            model,
            StructBertForTokenClassification) else Model.from_pretrained(model)
        if preprocessor is None:
            preprocessor = TokenClassifcationPreprocessor(model.model_dir)
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
        self.tokenizer = preprocessor.tokenizer
        self.config = model.config
        self.id2label = self.config.id2label
    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Convert the model prediction into a space-separated segmentation result.

        Args:
            inputs (Dict[str, Any]): the forward() output, containing 'predictions', 'logits' and 'text'

        Returns:
            Dict[str, str]: the prediction results
        """
        pred_list = inputs['predictions']
        labels = []
        for pre in pred_list:
            labels.append(self.id2label[pre])
        # drop the labels of the [CLS] and [SEP] tokens
        labels = labels[1:-1]
        chunks = []
        chunk = ''
        assert len(inputs['text']) == len(labels)
        for token, label in zip(inputs['text'], labels):
            if label[0] == 'B' or label[0] == 'I':
                # begin or inside a word: keep accumulating characters
                chunk += token
            else:
                # end of a word or a single-character word: close the current chunk
                chunk += token
                chunks.append(chunk)
                chunk = ''
        if chunk:
            chunks.append(chunk)
        seg_result = ' '.join(chunks)
        rst = {
            'output': seg_result,
        }
        return rst
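To make the chunking rule concrete, a standalone sketch of the same loop on a hand-written example; the label names are assumed BIES-style tags, not read from the real model config:

```python
text = '今天天气不错'
labels = ['B-CWS', 'E-CWS', 'B-CWS', 'E-CWS', 'B-CWS', 'E-CWS']  # assumed tag set

chunks, chunk = [], ''
for token, label in zip(text, labels):
    chunk += token
    if label[0] not in ('B', 'I'):  # 'E' or 'S' closes the current word
        chunks.append(chunk)
        chunk = ''
if chunk:
    chunks.append(chunk)

print(' '.join(chunks))  # 今天 天气 不错
```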
@@ -69,6 +69,19 @@ TASK_OUTPUTS = {
    # }
    Tasks.text_generation: ['text'],

    # word segmentation result for single sample
    # {
    #     "output": "今天 天气 不错 , 适合 出去 游玩"
    # }
    Tasks.word_segmentation: ['output'],

    # sentence similarity result for single sample
    # {
    #     "labels": "1",
    #     "scores": 0.9
    # }
    Tasks.sentence_similarity: ['scores', 'labels'],

    # ============ audio tasks ===================

    # ============ multi-modal tasks ===================
@@ -12,7 +12,7 @@ from .builder import PREPROCESSORS

__all__ = [
    'Tokenize', 'SequenceClassificationPreprocessor',
    'TextGenerationPreprocessor', 'TokenClassifcationPreprocessor'
]
@@ -171,3 +171,51 @@ class TextGenerationPreprocessor(Preprocessor):
            rst['token_type_ids'].append(feature['token_type_ids'])
        return {k: torch.tensor(v) for k, v in rst.items()}


@PREPROCESSORS.register_module(
    Fields.nlp, module_name=r'bert-token-classification')
class TokenClassifcationPreprocessor(Preprocessor):

    def __init__(self, model_dir: str, *args, **kwargs):
        """Preprocess the data via the vocab.txt from the `model_dir` path.

        Args:
            model_dir (str): model path
        """
        super().__init__(*args, **kwargs)
        from sofa import SbertTokenizer
        self.model_dir: str = model_dir
        self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)
    @type_assert(object, str)
    def __call__(self, data: str) -> Dict[str, Any]:
        """Process the raw input data.

        Args:
            data (str): a sentence
                Example:
                    '今天天气不错,适合出去游玩'

        Returns:
            Dict[str, Any]: the preprocessed data
        """
        # preprocess the data for the model input
        text = data.replace(' ', '').strip()
        tokens = []
        # tokenize character by character so that each character gets its own label
        for token in text:
            token = self.tokenizer.tokenize(token)
            tokens.extend(token)
        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        # add the special tokens ([CLS] ... [SEP]) expected by the model
        input_ids = self.tokenizer.build_inputs_with_special_tokens(input_ids)
        attention_mask = [1] * len(input_ids)
        token_type_ids = [0] * len(input_ids)
        return {
            'text': text,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_type_ids
        }
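A small sketch of what this preprocessor returns for a short sentence; the path is a placeholder and the values printed are only illustrative:

```python
from modelscope.preprocessors import TokenClassifcationPreprocessor

preprocessor = TokenClassifcationPreprocessor('/path/to/model_dir')  # dir must contain the tokenizer files
features = preprocessor('今天 天气 不错')

# spaces are stripped before tokenization, so 'text' becomes '今天天气不错';
# 'input_ids' has one id per character plus the two special tokens,
# and 'attention_mask' / 'token_type_ids' have the same length.
print(features['text'], len(features['input_ids']))
```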
@@ -30,6 +30,7 @@ class Tasks(object):
    image_matting = 'image-matting'

    # nlp tasks
    word_segmentation = 'word-segmentation'
    sentiment_analysis = 'sentiment-analysis'
    sentence_similarity = 'sentence-similarity'
    text_classification = 'text-classification'
@@ -0,0 +1,62 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import shutil
import unittest

from maas_hub.snapshot_download import snapshot_download

from modelscope.models import Model
from modelscope.models.nlp import StructBertForTokenClassification
from modelscope.pipelines import WordSegmentationPipeline, pipeline
from modelscope.preprocessors import TokenClassifcationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.hub import get_model_cache_dir
from modelscope.utils.test_utils import test_level


class WordSegmentationTest(unittest.TestCase):
    model_id = 'damo/nlp_structbert_word-segmentation_chinese-base'
    sentence = '今天天气不错,适合出去游玩'
    def setUp(self) -> None:
        # switch to False if downloading every time is not desired
        purge_cache = True
        if purge_cache:
            shutil.rmtree(
                get_model_cache_dir(self.model_id), ignore_errors=True)
    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_by_direct_model_download(self):
        cache_path = snapshot_download(self.model_id)
        tokenizer = TokenClassifcationPreprocessor(cache_path)
        model = StructBertForTokenClassification(
            cache_path, tokenizer=tokenizer)
        pipeline1 = WordSegmentationPipeline(model, preprocessor=tokenizer)
        pipeline2 = pipeline(
            Tasks.word_segmentation, model=model, preprocessor=tokenizer)
        print(f'sentence: {self.sentence}\n'
              f'pipeline1: {pipeline1(input=self.sentence)}')
        print()
        print(f'pipeline2: {pipeline2(input=self.sentence)}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id)
        tokenizer = TokenClassifcationPreprocessor(model.model_dir)
        pipeline_ins = pipeline(
            task=Tasks.word_segmentation, model=model, preprocessor=tokenizer)
        print(pipeline_ins(input=self.sentence))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_name(self):
        pipeline_ins = pipeline(
            task=Tasks.word_segmentation, model=self.model_id)
        print(pipeline_ins(input=self.sentence))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_default_model(self):
        pipeline_ins = pipeline(task=Tasks.word_segmentation)
        print(pipeline_ins(input=self.sentence))


if __name__ == '__main__':
    unittest.main()