
[to #42322933] Chinese word segmentation

Chinese word segmentation
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9051491

* add word segmentation

* Merge branch 'master' of http://gitlab.alibaba-inc.com/Ali-MaaS/MaaS-lib

* test with model hub

* merge with master

* update some description and test levels

* adding purge logic in test

* merge with master

* update variables definition

* generic word segmentation model as token classification model

* add output check
Branch: master
Authored by zhangzhicheng.zzc, huangjun.hj · 3 years ago
Commit eb3209a79a
9 changed files with 258 additions and 1 deletion:

  1. +1  -0  modelscope/models/nlp/__init__.py
  2. +57 -0  modelscope/models/nlp/token_classification_model.py
  3. +3  -0  modelscope/pipelines/builder.py
  4. +1  -0  modelscope/pipelines/nlp/__init__.py
  5. +71 -0  modelscope/pipelines/nlp/word_segmentation_pipeline.py
  6. +13 -0  modelscope/pipelines/outputs.py
  7. +49 -1  modelscope/preprocessors/nlp.py
  8. +1  -0  modelscope/utils/constant.py
  9. +62 -0  tests/pipelines/test_word_segmentation.py

+1 -0 modelscope/models/nlp/__init__.py

@@ -1,3 +1,4 @@
from .sentence_similarity_model import * # noqa F403
from .sequence_classification_model import * # noqa F403
from .text_generation_model import * # noqa F403
from .token_classification_model import * # noqa F403

+57 -0 modelscope/models/nlp/token_classification_model.py

@@ -0,0 +1,57 @@
import os
from typing import Any, Dict, Union

import numpy as np
import torch
from sofa import SbertConfig, SbertForTokenClassification

from modelscope.utils.constant import Tasks
from ..base import Model, Tensor
from ..builder import MODELS

__all__ = ['StructBertForTokenClassification']


@MODELS.register_module(
    Tasks.word_segmentation,
    module_name=r'structbert-chinese-word-segmentation')
class StructBertForTokenClassification(Model):

    def __init__(self, model_dir: str, *args, **kwargs):
        """Initialize the word segmentation model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        self.model_dir = model_dir
        self.model = SbertForTokenClassification.from_pretrained(
            self.model_dir)
        self.config = SbertConfig.from_pretrained(self.model_dir)

    def forward(self,
                input: Dict[str, Any]) -> Dict[str, Union[str, np.ndarray]]:
        """Run the model on the preprocessed input and return its prediction.

        Args:
            input (Dict[str, Any]): the preprocessed data

        Returns:
            Dict[str, Union[str, np.ndarray]]: results
                Example:
                    {
                        'predictions': array([1, 4]),  # label id per token
                        'logits': array([[-0.53860897, 1.5029076]], dtype=float32),
                        'text': '今天',
                    }
        """
        input_ids = torch.tensor(input['input_ids']).unsqueeze(0)
        output = self.model(input_ids)
        logits = output.logits
        pred = torch.argmax(logits[0], dim=-1)
        pred = pred.numpy()

        rst = {'predictions': pred, 'logits': logits, 'text': input['text']}
        return rst
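
Taken together with the preprocessor added below, the model can also be driven directly, outside the pipeline. A minimal sketch, assuming a local directory that holds the StructBERT checkpoint (the path here is hypothetical; in the tests it comes from snapshot_download):

from modelscope.models.nlp import StructBertForTokenClassification
from modelscope.preprocessors import TokenClassifcationPreprocessor

# hypothetical local checkpoint directory
model_dir = '/path/to/nlp_structbert_word-segmentation_chinese-base'

preprocessor = TokenClassifcationPreprocessor(model_dir)
model = StructBertForTokenClassification(model_dir)

features = preprocessor('今天天气不错')
outputs = model(features)
# outputs holds 'predictions', 'logits' and 'text' as described above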

+3 -0 modelscope/pipelines/builder.py

@@ -13,6 +13,9 @@ PIPELINES = Registry('pipelines')

DEFAULT_MODEL_FOR_PIPELINE = {
    # TaskName: (pipeline_module_name, model_repo)
    Tasks.word_segmentation:
    ('structbert-chinese-word-segmentation',
     'damo/nlp_structbert_word-segmentation_chinese-base'),
    Tasks.sentence_similarity:
    ('sbert-base-chinese-sentence-similarity',
     'damo/nlp_structbert_sentence-similarity_chinese-base'),
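
With this default registered, the bare task name is enough to build a working pipeline; the builder resolves the DAMO repo above and downloads it from the hub. A short sketch, mirroring test_run_with_default_model in the new test file:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# no model argument: the default word segmentation model is used
word_segmentation = pipeline(task=Tasks.word_segmentation)
print(word_segmentation(input='今天天气不错,适合出去游玩'))
# expected: {'output': '今天 天气 不错 , 适合 出去 游玩'}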


+1 -0 modelscope/pipelines/nlp/__init__.py

@@ -1,3 +1,4 @@
from .sentence_similarity_pipeline import * # noqa F403
from .sequence_classification_pipeline import * # noqa F403
from .text_generation_pipeline import * # noqa F403
from .word_segmentation_pipeline import * # noqa F403

+71 -0 modelscope/pipelines/nlp/word_segmentation_pipeline.py

@@ -0,0 +1,71 @@
from typing import Any, Dict, Optional, Union

import numpy as np

from modelscope.models import Model
from modelscope.models.nlp import StructBertForTokenClassification
from modelscope.preprocessors import TokenClassifcationPreprocessor
from modelscope.utils.constant import Tasks
from ..base import Pipeline, Tensor
from ..builder import PIPELINES

__all__ = ['WordSegmentationPipeline']


@PIPELINES.register_module(
    Tasks.word_segmentation,
    module_name=r'structbert-chinese-word-segmentation')
class WordSegmentationPipeline(Pipeline):

    def __init__(self,
                 model: Union[StructBertForTokenClassification, str],
                 preprocessor: Optional[TokenClassifcationPreprocessor] = None,
                 **kwargs):
        """Use `model` and `preprocessor` to create an NLP word segmentation pipeline for prediction.

        Args:
            model (StructBertForTokenClassification): a model instance
            preprocessor (TokenClassifcationPreprocessor): a preprocessor instance
        """
        model = model if isinstance(
            model,
            StructBertForTokenClassification) else Model.from_pretrained(model)
        if preprocessor is None:
            preprocessor = TokenClassifcationPreprocessor(model.model_dir)
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
        self.tokenizer = preprocessor.tokenizer
        self.config = model.config
        self.id2label = self.config.id2label

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Process the prediction results.

        Args:
            inputs (Dict[str, Any]): the model outputs, containing
                'predictions' (label ids per token) and 'text' (the raw input)

        Returns:
            Dict[str, str]: the prediction results
        """
        pred_list = inputs['predictions']
        labels = []
        for pre in pred_list:
            labels.append(self.id2label[pre])
        # drop the labels of the [CLS] and [SEP] special tokens
        labels = labels[1:-1]
        chunks = []
        chunk = ''
        assert len(inputs['text']) == len(labels)
        # 'B' and 'I' labels extend the current word; any other label
        # ('E' or 'S') closes it
        for token, label in zip(inputs['text'], labels):
            chunk += token
            if label[0] not in ('B', 'I'):
                chunks.append(chunk)
                chunk = ''
        if chunk:
            chunks.append(chunk)
        seg_result = ' '.join(chunks)
        rst = {
            'output': seg_result,
        }
        return rst
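
The postprocess loop is a BIES-style decode. A self-contained sketch of the same logic, run against a hypothetical label sequence:

def decode_bies(text, labels):
    # 'B'/'I' extend the current word; anything else ('E'/'S') closes it
    chunks, chunk = [], ''
    for token, label in zip(text, labels):
        chunk += token
        if label[0] not in ('B', 'I'):
            chunks.append(chunk)
            chunk = ''
    if chunk:  # flush a trailing unterminated word
        chunks.append(chunk)
    return ' '.join(chunks)

# hypothetical labels for '今天天气不错': each two-character word tagged B, E
print(decode_bies('今天天气不错', ['B', 'E', 'B', 'E', 'B', 'E']))
# -> 今天 天气 不错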

+13 -0 modelscope/pipelines/outputs.py

@@ -69,6 +69,19 @@ TASK_OUTPUTS = {
    # }
    Tasks.text_generation: ['text'],

    # word segmentation result for single sample
    # {
    #     "output": "今天 天气 不错 , 适合 出去 游玩"
    # }
    Tasks.word_segmentation: ['output'],

    # sentence similarity result for single sample
    # {
    #     "labels": "1",
    #     "scores": 0.9
    # }
    Tasks.sentence_similarity: ['scores', 'labels'],

    # ============ audio tasks ===================

    # ============ multi-modal tasks ===================


+49 -1 modelscope/preprocessors/nlp.py

@@ -12,7 +12,7 @@ from .builder import PREPROCESSORS

__all__ = [
    'Tokenize', 'SequenceClassificationPreprocessor',
    'TextGenerationPreprocessor', 'TokenClassifcationPreprocessor'
]


@@ -171,3 +171,51 @@ class TextGenerationPreprocessor(Preprocessor):
            rst['token_type_ids'].append(feature['token_type_ids'])

        return {k: torch.tensor(v) for k, v in rst.items()}


@PREPROCESSORS.register_module(
    Fields.nlp, module_name=r'bert-token-classification')
class TokenClassifcationPreprocessor(Preprocessor):

    def __init__(self, model_dir: str, *args, **kwargs):
        """Preprocess the data via the vocab.txt from the `model_dir` path.

        Args:
            model_dir (str): model path
        """
        super().__init__(*args, **kwargs)

        from sofa import SbertTokenizer
        self.model_dir: str = model_dir
        self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)

    @type_assert(object, str)
    def __call__(self, data: str) -> Dict[str, Any]:
        """Process the raw input data.

        Args:
            data (str): a sentence
                Example:
                    'you are so handsome.'

        Returns:
            Dict[str, Any]: the preprocessed data
        """
        # preprocess the data for the model input
        text = data.replace(' ', '').strip()
        tokens = []
        # tokenize character by character so labels align with characters
        for token in text:
            token = self.tokenizer.tokenize(token)
            tokens.extend(token)
        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        input_ids = self.tokenizer.build_inputs_with_special_tokens(input_ids)
        attention_mask = [1] * len(input_ids)
        token_type_ids = [0] * len(input_ids)
        return {
            'text': text,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_type_ids
        }
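
The preprocessing is deliberately character level: spaces are stripped and every character is tokenized on its own, so input_ids line up one-to-one with the characters of 'text', plus the two special tokens. A sketch of the expected shapes, reusing the hypothetical model_dir from the earlier sketch and assuming each character maps to a single vocab entry:

preprocessor = TokenClassifcationPreprocessor(model_dir)
features = preprocessor('今天 天气')

# features['text'] == '今天天气'         (the space is removed)
# len(features['input_ids']) == 2 + 4    ([CLS] and [SEP] around 4 characters)
# features['attention_mask'] == [1, 1, 1, 1, 1, 1]
# features['token_type_ids'] == [0, 0, 0, 0, 0, 0]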

+1 -0 modelscope/utils/constant.py

@@ -30,6 +30,7 @@ class Tasks(object):
    image_matting = 'image-matting'

    # nlp tasks
    word_segmentation = 'word-segmentation'
    sentiment_analysis = 'sentiment-analysis'
    sentence_similarity = 'sentence-similarity'
    text_classification = 'text-classification'


+62 -0 tests/pipelines/test_word_segmentation.py

@@ -0,0 +1,62 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import shutil
import unittest

from maas_hub.snapshot_download import snapshot_download

from modelscope.models import Model
from modelscope.models.nlp import StructBertForTokenClassification
from modelscope.pipelines import WordSegmentationPipeline, pipeline
from modelscope.preprocessors import TokenClassifcationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.hub import get_model_cache_dir
from modelscope.utils.test_utils import test_level


class WordSegmentationTest(unittest.TestCase):
    model_id = 'damo/nlp_structbert_word-segmentation_chinese-base'
    sentence = '今天天气不错,适合出去游玩'

    def setUp(self) -> None:
        # switch to False if downloading every time is not desired
        purge_cache = True
        if purge_cache:
            shutil.rmtree(
                get_model_cache_dir(self.model_id), ignore_errors=True)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_by_direct_model_download(self):
        cache_path = snapshot_download(self.model_id)
        tokenizer = TokenClassifcationPreprocessor(cache_path)
        model = StructBertForTokenClassification(
            cache_path, tokenizer=tokenizer)
        pipeline1 = WordSegmentationPipeline(model, preprocessor=tokenizer)
        pipeline2 = pipeline(
            Tasks.word_segmentation, model=model, preprocessor=tokenizer)
        print(f'sentence: {self.sentence}\n'
              f'pipeline1: {pipeline1(input=self.sentence)}')
        print()
        print(f'pipeline2: {pipeline2(input=self.sentence)}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id)
        tokenizer = TokenClassifcationPreprocessor(model.model_dir)
        pipeline_ins = pipeline(
            task=Tasks.word_segmentation, model=model, preprocessor=tokenizer)
        print(pipeline_ins(input=self.sentence))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_name(self):
        pipeline_ins = pipeline(
            task=Tasks.word_segmentation, model=self.model_id)
        print(pipeline_ins(input=self.sentence))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_default_model(self):
        pipeline_ins = pipeline(task=Tasks.word_segmentation)
        print(pipeline_ins(input=self.sentence))


if __name__ == '__main__':
    unittest.main()
