dingkun.ldk yingda.chen 3 years ago
parent
commit
7ed4015bdc
16 changed files with 475 additions and 75 deletions
  1. +7
    -0
      modelscope/metainfo.py
  2. +20
    -23
      modelscope/models/nlp/__init__.py
  3. +0
    -1
      modelscope/models/nlp/heads/sequence_classification_head.py
  4. +42
    -0
      modelscope/models/nlp/heads/token_classification_head.py
  5. +1
    -1
      modelscope/models/nlp/structbert/configuration_sbert.py
  6. +2
    -0
      modelscope/models/nlp/task_models/__init__.py
  7. +83
    -0
      modelscope/models/nlp/task_models/token_classification.py
  8. +1
    -0
      modelscope/models/nlp/token_classification.py
  9. +12
    -18
      modelscope/outputs.py
  10. +3
    -0
      modelscope/pipelines/builder.py
  11. +15
    -15
      modelscope/pipelines/nlp/__init__.py
  12. +92
    -0
      modelscope/pipelines/nlp/token_classification_pipeline.py
  13. +9
    -10
      modelscope/preprocessors/__init__.py
  14. +109
    -1
      modelscope/preprocessors/nlp.py
  15. +24
    -6
      modelscope/utils/hub.py
  16. +55
    -0
      tests/pipelines/test_part_of_speech.py

+ 7
- 0
modelscope/metainfo.py View File

@@ -55,7 +55,9 @@ class Models(object):
space_modeling = 'space-modeling'
star = 'star'
tcrf = 'transformer-crf'
transformer_softmax = 'transformer-softmax'
lcrf = 'lstm-crf'
gcnncrf = 'gcnn-crf'
bart = 'bart'
gpt3 = 'gpt3'
plug = 'plug'
@@ -82,6 +84,7 @@ class Models(object):
class TaskModels(object):
# nlp task
text_classification = 'text-classification'
token_classification = 'token-classification'
information_extraction = 'information-extraction'


@@ -92,6 +95,8 @@ class Heads(object):
bert_mlm = 'bert-mlm'
# roberta mlm
roberta_mlm = 'roberta-mlm'
# token cls
token_classification = 'token-classification'
information_extraction = 'information-extraction'


@@ -167,6 +172,7 @@ class Pipelines(object):
# nlp tasks
sentence_similarity = 'sentence-similarity'
word_segmentation = 'word-segmentation'
part_of_speech = 'part-of-speech'
named_entity_recognition = 'named-entity-recognition'
text_generation = 'text-generation'
sentiment_analysis = 'sentiment-analysis'
@@ -272,6 +278,7 @@ class Preprocessors(object):
sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer'
zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer'
text_error_correction = 'text-error-correction'
sequence_labeling_tokenizer = 'sequence-labeling-tokenizer'
word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor'
fill_mask = 'fill-mask'
faq_question_answering_preprocessor = 'faq-question-answering-preprocessor'


+ 20
- 23
modelscope/models/nlp/__init__.py View File

@@ -5,40 +5,39 @@ from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .backbones import SbertModel
from .heads import SequenceClassificationHead
from .bart_for_text_error_correction import BartForTextErrorCorrection
from .bert_for_sequence_classification import BertForSequenceClassification
from .bert_for_document_segmentation import BertForDocumentSegmentation
from .csanmt_for_translation import CsanmtForTranslation
from .masked_language import (
StructBertForMaskedLM,
VecoForMaskedLM,
BertForMaskedLM,
DebertaV2ForMaskedLM,
)
from .heads import SequenceClassificationHead
from .gpt3 import GPT3ForTextGeneration
from .masked_language import (StructBertForMaskedLM, VecoForMaskedLM,
BertForMaskedLM, DebertaV2ForMaskedLM)
from .nncrf_for_named_entity_recognition import (
TransformerCRFForNamedEntityRecognition,
LSTMCRFForNamedEntityRecognition)
from .token_classification import SbertForTokenClassification
from .palm_v2 import PalmForTextGeneration
from .sbert_for_faq_question_answering import SbertForFaqQuestionAnswering
from .star_text_to_sql import StarForTextToSql
from .sequence_classification import VecoForSequenceClassification, SbertForSequenceClassification
from .space import SpaceForDialogIntent
from .space import SpaceForDialogModeling
from .space import SpaceForDialogStateTracking
from .star_text_to_sql import StarForTextToSql
from .task_models import (InformationExtractionModel,
SingleBackboneTaskModelBase)
from .bart_for_text_error_correction import BartForTextErrorCorrection
from .gpt3 import GPT3ForTextGeneration
from .plug import PlugForTextGeneration
from .sbert_for_faq_question_answering import SbertForFaqQuestionAnswering
SequenceClassificationModel,
SingleBackboneTaskModelBase,
TokenClassificationModel)
from .token_classification import SbertForTokenClassification

else:
_import_structure = {
'star_text_to_sql': ['StarForTextToSql'],
'backbones': ['SbertModel'],
'heads': ['SequenceClassificationHead'],
'csanmt_for_translation': ['CsanmtForTranslation'],
'bart_for_text_error_correction': ['BartForTextErrorCorrection'],
'bert_for_sequence_classification': ['BertForSequenceClassification'],
'bert_for_document_segmentation': ['BertForDocumentSegmentation'],
'csanmt_for_translation': ['CsanmtForTranslation'],
'heads': ['SequenceClassificationHead'],
'gpt3': ['GPT3ForTextGeneration'],
'masked_language': [
'StructBertForMaskedLM', 'VecoForMaskedLM', 'BertForMaskedLM',
'DebertaV2ForMaskedLM'
@@ -48,7 +47,8 @@ else:
'LSTMCRFForNamedEntityRecognition'
],
'palm_v2': ['PalmForTextGeneration'],
'token_classification': ['SbertForTokenClassification'],
'sbert_for_faq_question_answering': ['SbertForFaqQuestionAnswering'],
'star_text_to_sql': ['StarForTextToSql'],
'sequence_classification':
['VecoForSequenceClassification', 'SbertForSequenceClassification'],
'space': [
@@ -57,12 +57,9 @@ else:
],
'task_models': [
'InformationExtractionModel', 'SequenceClassificationModel',
'SingleBackboneTaskModelBase'
'SingleBackboneTaskModelBase', 'TokenClassificationModel'
],
'bart_for_text_error_correction': ['BartForTextErrorCorrection'],
'gpt3': ['GPT3ForTextGeneration'],
'plug': ['PlugForTextGeneration'],
'sbert_for_faq_question_answering': ['SbertForFaqQuestionAnswering'],
'token_classification': ['SbertForTokenClassification'],
}

import sys


+ 0
- 1
modelscope/models/nlp/heads/sequence_classification_head.py View File

@@ -19,7 +19,6 @@ class SequenceClassificationHead(TorchHead):
super().__init__(**kwargs)
config = self.config
self.num_labels = config.num_labels
self.config = config
classifier_dropout = (
config['classifier_dropout'] if config.get('classifier_dropout')
is not None else config['hidden_dropout_prob'])


+ 42
- 0
modelscope/models/nlp/heads/token_classification_head.py View File

@@ -0,0 +1,42 @@
from typing import Dict

import torch
import torch.nn.functional as F
from torch import nn

from modelscope.metainfo import Heads
from modelscope.models.base import TorchHead
from modelscope.models.builder import HEADS
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks


@HEADS.register_module(
Tasks.token_classification, module_name=Heads.token_classification)
class TokenClassificationHead(TorchHead):

def __init__(self, **kwargs):
super().__init__(**kwargs)
config = self.config
self.num_labels = config.num_labels
classifier_dropout = (
config['classifier_dropout'] if config.get('classifier_dropout')
is not None else config['hidden_dropout_prob'])
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config['hidden_size'],
config['num_labels'])

def forward(self, inputs=None):
if isinstance(inputs, dict):
assert inputs.get('sequence_output') is not None
sequence_output = inputs.get('sequence_output')
else:
sequence_output = inputs
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
return {OutputKeys.LOGITS: logits}

def compute_loss(self, outputs: Dict[str, torch.Tensor],
labels) -> Dict[str, torch.Tensor]:
logits = outputs[OutputKeys.LOGITS]
return {OutputKeys.LOSS: F.cross_entropy(logits, labels)}

+ 1
- 1
modelscope/models/nlp/structbert/configuration_sbert.py View File

@@ -85,7 +85,7 @@ class SbertConfig(PretrainedConfig):
If adv_bound not proveded, 2 * sigma will be used as the adv_bound factor
"""

model_type = 'sbert'
model_type = 'structbert'

def __init__(self,
vocab_size=30522,


+ 2
- 0
modelscope/models/nlp/task_models/__init__.py View File

@@ -7,12 +7,14 @@ if TYPE_CHECKING:
from .information_extraction import InformationExtractionModel
from .sequence_classification import SequenceClassificationModel
from .task_model import SingleBackboneTaskModelBase
from .token_classification import TokenClassificationModel

else:
_import_structure = {
'information_extraction': ['InformationExtractionModel'],
'sequence_classification': ['SequenceClassificationModel'],
'task_model': ['SingleBackboneTaskModelBase'],
'token_classification': ['TokenClassificationModel'],
}

import sys


+ 83
- 0
modelscope/models/nlp/task_models/token_classification.py View File

@@ -0,0 +1,83 @@
from typing import Any, Dict

import numpy as np
import torch

from modelscope.metainfo import TaskModels
from modelscope.models.builder import MODELS
from modelscope.models.nlp.task_models.task_model import \
SingleBackboneTaskModelBase
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks
from modelscope.utils.hub import parse_label_mapping
from modelscope.utils.tensor_utils import (torch_nested_detach,
torch_nested_numpify)

__all__ = ['TokenClassificationModel']


@MODELS.register_module(
Tasks.token_classification, module_name=TaskModels.token_classification)
class TokenClassificationModel(SingleBackboneTaskModelBase):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the token classification model from the `model_dir` path.

Args:
model_dir (str): the model path.
"""
super().__init__(model_dir, *args, **kwargs)
if 'base_model_prefix' in kwargs:
self._base_model_prefix = kwargs['base_model_prefix']

backbone_cfg = self.cfg.backbone
head_cfg = self.cfg.head

# get the num_labels
num_labels = kwargs.get('num_labels')
if num_labels is None:
label2id = parse_label_mapping(model_dir)
if label2id is not None and len(label2id) > 0:
num_labels = len(label2id)
self.id2label = {id: label for label, id in label2id.items()}
head_cfg['num_labels'] = num_labels

self.build_backbone(backbone_cfg)
self.build_head(head_cfg)

def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
labels = None
if OutputKeys.LABEL in input:
labels = input.pop(OutputKeys.LABEL)
elif OutputKeys.LABELS in input:
labels = input.pop(OutputKeys.LABELS)

outputs = super().forward(input)
sequence_output, pooled_output = self.extract_backbone_outputs(outputs)
outputs = self.head.forward(sequence_output)
if labels in input:
loss = self.compute_loss(outputs, labels)
outputs.update(loss)
return outputs

def extract_logits(self, outputs):
return outputs[OutputKeys.LOGITS].cpu().detach()

def extract_backbone_outputs(self, outputs):
sequence_output = None
pooled_output = None
if hasattr(self.backbone, 'extract_sequence_outputs'):
sequence_output = self.backbone.extract_sequence_outputs(outputs)
return sequence_output, pooled_output

def compute_loss(self, outputs, labels):
loss = self.head.compute_loss(outputs, labels)
return loss

def postprocess(self, input, **kwargs):
logits = self.extract_logits(input)
pred = torch.argmax(logits[0], dim=-1)
pred = torch_nested_numpify(torch_nested_detach(pred))
logits = torch_nested_numpify(torch_nested_detach(logits))
res = {OutputKeys.PREDICTIONS: pred, OutputKeys.LOGITS: logits}
return res

+ 1
- 0
modelscope/models/nlp/token_classification.py View File

@@ -91,6 +91,7 @@ class TokenClassification(TorchModel):


@MODELS.register_module(Tasks.word_segmentation, module_name=Models.structbert)
@MODELS.register_module(Tasks.part_of_speech, module_name=Models.structbert)
@MODELS.register_module(
Tasks.token_classification, module_name=Models.structbert)
class SbertForTokenClassification(TokenClassification, SbertPreTrainedModel):


+ 12
- 18
modelscope/outputs.py View File

@@ -359,26 +359,20 @@ TASK_OUTPUTS = {
# word segmentation result for single sample
# {
# "output": "今天 天气 不错 , 适合 出去 游玩"
# }
Tasks.word_segmentation: [OutputKeys.OUTPUT],

# part-of-speech result for single sample
# [
# {'word': '诸葛', 'label': 'PROPN'},
# {'word': '亮', 'label': 'PROPN'},
# {'word': '发明', 'label': 'VERB'},
# {'word': '八', 'label': 'NUM'},
# {'word': '阵', 'label': 'NOUN'},
# {'word': '图', 'label': 'PART'},
# {'word': '以', 'label': 'ADV'},
# {'word': '利', 'label': 'VERB'},
# {'word': '立营', 'label': 'VERB'},
# {'word': '练兵', 'label': 'VERB'},
# {'word': '.', 'label': 'PUNCT'}
# "labels": [
# {'word': '今天', 'label': 'PROPN'},
# {'word': '天气', 'label': 'PROPN'},
# {'word': '不错', 'label': 'VERB'},
# {'word': ',', 'label': 'NUM'},
# {'word': '适合', 'label': 'NOUN'},
# {'word': '出去', 'label': 'PART'},
# {'word': '游玩', 'label': 'ADV'},
# ]
# TODO @wenmeng.zwm support list of result check
Tasks.part_of_speech: [OutputKeys.WORD, OutputKeys.LABEL],
# }
Tasks.word_segmentation: [OutputKeys.OUTPUT, OutputKeys.LABELS],
Tasks.part_of_speech: [OutputKeys.OUTPUT, OutputKeys.LABELS],

# TODO @wenmeng.zwm support list of result check
# named entity recognition result for single sample
# {
# "output": [


+ 3
- 0
modelscope/pipelines/builder.py View File

@@ -20,6 +20,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.word_segmentation:
(Pipelines.word_segmentation,
'damo/nlp_structbert_word-segmentation_chinese-base'),
Tasks.token_classification:
(Pipelines.part_of_speech,
'damo/nlp_structbert_part-of-speech_chinese-base'),
Tasks.named_entity_recognition:
(Pipelines.named_entity_recognition,
'damo/nlp_raner_named-entity-recognition_chinese-base-news'),


+ 15
- 15
modelscope/pipelines/nlp/__init__.py View File

@@ -9,21 +9,21 @@ if TYPE_CHECKING:
from .dialog_modeling_pipeline import DialogModelingPipeline
from .dialog_state_tracking_pipeline import DialogStateTrackingPipeline
from .document_segmentation_pipeline import DocumentSegmentationPipeline
from .faq_question_answering_pipeline import FaqQuestionAnsweringPipeline
from .fill_mask_pipeline import FillMaskPipeline
from .information_extraction_pipeline import InformationExtractionPipeline
from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline
from .pair_sentence_classification_pipeline import PairSentenceClassificationPipeline
from .single_sentence_classification_pipeline import SingleSentenceClassificationPipeline
from .sequence_classification_pipeline import SequenceClassificationPipeline
from .summarization_pipeline import SummarizationPipeline
from .text_classification_pipeline import TextClassificationPipeline
from .text_error_correction_pipeline import TextErrorCorrectionPipeline
from .text_generation_pipeline import TextGenerationPipeline
from .token_classification_pipeline import TokenClassificationPipeline
from .translation_pipeline import TranslationPipeline
from .word_segmentation_pipeline import WordSegmentationPipeline
from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline
from .summarization_pipeline import SummarizationPipeline
from .text_classification_pipeline import TextClassificationPipeline
from .text_error_correction_pipeline import TextErrorCorrectionPipeline
from .faq_question_answering_pipeline import FaqQuestionAnsweringPipeline
from .relation_extraction_pipeline import RelationExtractionPipeline

else:
_import_structure = {
@@ -34,25 +34,25 @@ else:
'dialog_modeling_pipeline': ['DialogModelingPipeline'],
'dialog_state_tracking_pipeline': ['DialogStateTrackingPipeline'],
'document_segmentation_pipeline': ['DocumentSegmentationPipeline'],
'faq_question_answering_pipeline': ['FaqQuestionAnsweringPipeline'],
'fill_mask_pipeline': ['FillMaskPipeline'],
'named_entity_recognition_pipeline':
['NamedEntityRecognitionPipeline'],
'information_extraction_pipeline': ['InformationExtractionPipeline'],
'single_sentence_classification_pipeline':
['SingleSentenceClassificationPipeline'],
'pair_sentence_classification_pipeline':
['PairSentenceClassificationPipeline'],
'sequence_classification_pipeline': ['SequenceClassificationPipeline'],
'single_sentence_classification_pipeline':
['SingleSentenceClassificationPipeline'],
'summarization_pipeline': ['SummarizationPipeline'],
'text_classification_pipeline': ['TextClassificationPipeline'],
'text_error_correction_pipeline': ['TextErrorCorrectionPipeline'],
'text_generation_pipeline': ['TextGenerationPipeline'],
'token_classification_pipeline': ['TokenClassificationPipeline'],
'translation_pipeline': ['TranslationPipeline'],
'word_segmentation_pipeline': ['WordSegmentationPipeline'],
'zero_shot_classification_pipeline':
['ZeroShotClassificationPipeline'],
'named_entity_recognition_pipeline':
['NamedEntityRecognitionPipeline'],
'translation_pipeline': ['TranslationPipeline'],
'summarization_pipeline': ['SummarizationPipeline'],
'text_classification_pipeline': ['TextClassificationPipeline'],
'text_error_correction_pipeline': ['TextErrorCorrectionPipeline'],
'faq_question_answering_pipeline': ['FaqQuestionAnsweringPipeline'],
'relation_extraction_pipeline': ['RelationExtractionPipeline']
}

import sys


+ 92
- 0
modelscope/pipelines/nlp/token_classification_pipeline.py View File

@@ -0,0 +1,92 @@
from typing import Any, Dict, Optional, Union

import torch

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline, Tensor
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import (Preprocessor,
TokenClassificationPreprocessor)
from modelscope.utils.constant import Tasks

__all__ = ['TokenClassificationPipeline']


@PIPELINES.register_module(
Tasks.token_classification, module_name=Pipelines.part_of_speech)
class TokenClassificationPipeline(Pipeline):

def __init__(self,
model: Union[Model, str],
preprocessor: Optional[Preprocessor] = None,
**kwargs):
"""use `model` and `preprocessor` to create a token classification pipeline for prediction

Args:
model (str or Model): A model instance or a model local dir or a model id in the model hub.
preprocessor (Preprocessor): a preprocessor instance, must not be None.
"""
assert isinstance(model, str) or isinstance(model, Model), \
'model must be a single str or Model'
model = model if isinstance(model,
Model) else Model.from_pretrained(model)
if preprocessor is None:
preprocessor = TokenClassificationPreprocessor(
model.model_dir,
sequence_length=kwargs.pop('sequence_length', 128))
model.eval()
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
self.id2label = getattr(model, 'id2label')
assert self.id2label is not None, 'Cannot convert id to the original label, please pass in the mapping ' \
'as a parameter or make sure the preprocessor has the attribute.'

def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
text = inputs.pop(OutputKeys.TEXT)
with torch.no_grad():
return {
**self.model(inputs, **forward_params), OutputKeys.TEXT: text
}

def postprocess(self, inputs: Dict[str, Any],
**postprocess_params) -> Dict[str, str]:
"""process the prediction results

Args:
inputs (Dict[str, Any]): _description_

Returns:
Dict[str, str]: the prediction results
"""

pred_list = inputs['predictions']
labels = []
for pre in pred_list:
labels.append(self.id2label[pre])
labels = labels[1:-1]
chunks = []
tags = []
chunk = ''
assert len(inputs['text']) == len(labels)
for token, label in zip(inputs['text'], labels):
if label[0] == 'B' or label[0] == 'I':
chunk += token
else:
chunk += token
chunks.append(chunk)
chunk = ''
tags.append(label.split('-')[-1])
if chunk:
chunks.append(chunk)
tags.append(label.split('-')[-1])
pos_result = []
seg_result = ' '.join(chunks)
for chunk, tag in zip(chunks, tags):
pos_result.append({OutputKeys.WORD: chunk, OutputKeys.LABEL: tag})
outputs = {
OutputKeys.OUTPUT: seg_result,
OutputKeys.LABELS: pos_result
}
return outputs

+ 9
- 10
modelscope/preprocessors/__init__.py View File

@@ -15,15 +15,14 @@ if TYPE_CHECKING:
ImageDenoisePreprocessor)
from .kws import WavToLists
from .multi_modal import (OfaPreprocessor, MPlugPreprocessor)
from .nlp import (Tokenize, SequenceClassificationPreprocessor,
TextGenerationPreprocessor,
TokenClassificationPreprocessor,
SingleSentenceClassificationPreprocessor,
PairSentenceClassificationPreprocessor,
FillMaskPreprocessor, ZeroShotClassificationPreprocessor,
NERPreprocessor, TextErrorCorrectionPreprocessor,
FaqQuestionAnsweringPreprocessor,
RelationExtractionPreprocessor)
from .nlp import (
Tokenize, SequenceClassificationPreprocessor,
TextGenerationPreprocessor, TokenClassificationPreprocessor,
SingleSentenceClassificationPreprocessor,
PairSentenceClassificationPreprocessor, FillMaskPreprocessor,
ZeroShotClassificationPreprocessor, NERPreprocessor,
TextErrorCorrectionPreprocessor, FaqQuestionAnsweringPreprocessor,
SequenceLabelingPreprocessor, RelationExtractionPreprocessor)
from .slp import DocumentSegmentationPreprocessor
from .space import (DialogIntentPredictionPreprocessor,
DialogModelingPreprocessor,
@@ -52,7 +51,7 @@ else:
'PairSentenceClassificationPreprocessor', 'FillMaskPreprocessor',
'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
'TextErrorCorrectionPreprocessor',
'FaqQuestionAnsweringPreprocessor',
'FaqQuestionAnsweringPreprocessor', 'SequenceLabelingPreprocessor',
'RelationExtractionPreprocessor'
],
'slp': ['DocumentSegmentationPreprocessor'],


+ 109
- 1
modelscope/preprocessors/nlp.py View File

@@ -5,9 +5,11 @@ import uuid
from typing import Any, Dict, Iterable, Optional, Tuple, Union

import numpy as np
import torch
from transformers import AutoTokenizer, BertTokenizerFast

from modelscope.metainfo import Models, Preprocessors
from modelscope.models.nlp.structbert import SbertTokenizerFast
from modelscope.outputs import OutputKeys
from modelscope.utils.config import ConfigFields
from modelscope.utils.constant import Fields, InputFields, ModeKeys
@@ -23,7 +25,7 @@ __all__ = [
'SingleSentenceClassificationPreprocessor', 'FillMaskPreprocessor',
'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
'TextErrorCorrectionPreprocessor', 'FaqQuestionAnsweringPreprocessor',
'RelationExtractionPreprocessor'
'SequenceLabelingPreprocessor', 'RelationExtractionPreprocessor'
]


@@ -627,6 +629,112 @@ class NERPreprocessor(Preprocessor):
}


@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.sequence_labeling_tokenizer)
class SequenceLabelingPreprocessor(Preprocessor):
"""The tokenizer preprocessor used in normal NER task.

NOTE: This preprocessor may be merged with the TokenClassificationPreprocessor in the next edition.
"""

def __init__(self, model_dir: str, *args, **kwargs):
"""preprocess the data via the vocab.txt from the `model_dir` path

Args:
model_dir (str): model path
"""

super().__init__(*args, **kwargs)

self.model_dir: str = model_dir
self.sequence_length = kwargs.pop('sequence_length', 512)

if 'lstm' in model_dir or 'gcnn' in model_dir:
self.tokenizer = BertTokenizerFast.from_pretrained(
model_dir, use_fast=False)
elif 'structbert' in model_dir:
self.tokenizer = SbertTokenizerFast.from_pretrained(
model_dir, use_fast=False)
else:
self.tokenizer = AutoTokenizer.from_pretrained(
model_dir, use_fast=False)
self.is_split_into_words = self.tokenizer.init_kwargs.get(
'is_split_into_words', False)

@type_assert(object, str)
def __call__(self, data: str) -> Dict[str, Any]:
"""process the raw input data

Args:
data (str): a sentence
Example:
'you are so handsome.'

Returns:
Dict[str, Any]: the preprocessed data
"""

# preprocess the data for the model input
text = data
if self.is_split_into_words:
input_ids = []
label_mask = []
offset_mapping = []
for offset, token in enumerate(list(data)):
subtoken_ids = self.tokenizer.encode(
token, add_special_tokens=False)
if len(subtoken_ids) == 0:
subtoken_ids = [self.tokenizer.unk_token_id]
input_ids.extend(subtoken_ids)
label_mask.extend([1] + [0] * (len(subtoken_ids) - 1))
offset_mapping.extend([(offset, offset + 1)]
+ [(offset + 1, offset + 1)]
* (len(subtoken_ids) - 1))
if len(input_ids) >= self.sequence_length - 2:
input_ids = input_ids[:self.sequence_length - 2]
label_mask = label_mask[:self.sequence_length - 2]
offset_mapping = offset_mapping[:self.sequence_length - 2]
input_ids = [self.tokenizer.cls_token_id
] + input_ids + [self.tokenizer.sep_token_id]
label_mask = [0] + label_mask + [0]
attention_mask = [1] * len(input_ids)
else:
encodings = self.tokenizer(
text,
add_special_tokens=True,
padding=True,
truncation=True,
max_length=self.sequence_length,
return_offsets_mapping=True)
input_ids = encodings['input_ids']
attention_mask = encodings['attention_mask']
word_ids = encodings.word_ids()
label_mask = []
offset_mapping = []
for i in range(len(word_ids)):
if word_ids[i] is None:
label_mask.append(0)
elif word_ids[i] == word_ids[i - 1]:
label_mask.append(0)
offset_mapping[-1] = (offset_mapping[-1][0],
encodings['offset_mapping'][i][1])
else:
label_mask.append(1)
offset_mapping.append(encodings['offset_mapping'][i])

if not self.is_transformer_based_model:
input_ids = input_ids[1:-1]
attention_mask = attention_mask[1:-1]
label_mask = label_mask[1:-1]
return {
'text': text,
'input_ids': input_ids,
'attention_mask': attention_mask,
'label_mask': label_mask,
'offset_mapping': offset_mapping
}


@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.re_tokenizer)
class RelationExtractionPreprocessor(Preprocessor):


+ 24
- 6
modelscope/utils/hub.py View File

@@ -77,19 +77,26 @@ def auto_load(model: Union[str, List[str]]):
def get_model_type(model_dir):
"""Get the model type from the configuration.

This method will try to get the 'model.type' or 'model.model_type' field from the configuration.json file.
If this file does not exist, the method will try to get the 'model_type' field from the config.json.
This method will try to get the model type from 'model.backbone.type',
'model.type' or 'model.model_type' field in the configuration.json file. If
this file does not exist, the method will try to get the 'model_type' field
from the config.json.

@param model_dir: The local model dir to use.
@return: The model type string, returns None if nothing is found.
@param model_dir: The local model dir to use. @return: The model type
string, returns None if nothing is found.
"""
try:
configuration_file = osp.join(model_dir, ModelFile.CONFIGURATION)
config_file = osp.join(model_dir, 'config.json')
if osp.isfile(configuration_file):
cfg = Config.from_file(configuration_file)
return cfg.model.model_type if hasattr(cfg.model, 'model_type') and not hasattr(cfg.model, 'type') \
else cfg.model.type
if hasattr(cfg.model, 'backbone'):
return cfg.model.backbone.type
elif hasattr(cfg.model,
'model_type') and not hasattr(cfg.model, 'type'):
return cfg.model.model_type
else:
return cfg.model.type
elif osp.isfile(config_file):
cfg = Config.from_file(config_file)
return cfg.model_type if hasattr(cfg, 'model_type') else None
@@ -123,13 +130,24 @@ def parse_label_mapping(model_dir):
if hasattr(config, ConfigFields.model) and hasattr(
config[ConfigFields.model], 'label2id'):
label2id = config[ConfigFields.model].label2id
elif hasattr(config, ConfigFields.model) and hasattr(
config[ConfigFields.model], 'id2label'):
id2label = config[ConfigFields.model].id2label
label2id = {label: id for id, label in id2label.items()}
elif hasattr(config, ConfigFields.preprocessor) and hasattr(
config[ConfigFields.preprocessor], 'label2id'):
label2id = config[ConfigFields.preprocessor].label2id
elif hasattr(config, ConfigFields.preprocessor) and hasattr(
config[ConfigFields.preprocessor], 'id2label'):
id2label = config[ConfigFields.preprocessor].id2label
label2id = {label: id for id, label in id2label.items()}

if label2id is None:
config_path = os.path.join(model_dir, 'config.json')
config = Config.from_file(config_path)
if hasattr(config, 'label2id'):
label2id = config.label2id
elif hasattr(config, 'id2label'):
id2label = config.id2label
label2id = {label: id for id, label in id2label.items()}
return label2id

+ 55
- 0
tests/pipelines/test_part_of_speech.py View File

@@ -0,0 +1,55 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import shutil
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import TokenClassificationModel
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import TokenClassificationPipeline
from modelscope.preprocessors import TokenClassificationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class PartOfSpeechTest(unittest.TestCase):
model_id = 'damo/nlp_structbert_part-of-speech_chinese-base'
sentence = '今天天气不错,适合出去游玩'

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_by_direct_model_download(self):
cache_path = snapshot_download(self.model_id)
tokenizer = TokenClassificationPreprocessor(cache_path)
model = TokenClassificationModel.from_pretrained(cache_path)
pipeline1 = TokenClassificationPipeline(model, preprocessor=tokenizer)
pipeline2 = pipeline(
Tasks.token_classification, model=model, preprocessor=tokenizer)
print(f'sentence: {self.sentence}\n'
f'pipeline1:{pipeline1(input=self.sentence)}')
print()
print(f'pipeline2: {pipeline2(input=self.sentence)}')

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
tokenizer = TokenClassificationPreprocessor(model.model_dir)
pipeline_ins = pipeline(
task=Tasks.token_classification,
model=model,
preprocessor=tokenizer)
print(pipeline_ins(input=self.sentence))

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_model_name(self):
pipeline_ins = pipeline(
task=Tasks.token_classification, model=self.model_id)
print(pipeline_ins(input=self.sentence))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipeline_ins = pipeline(task=Tasks.token_classification)
print(pipeline_ins(input=self.sentence))


if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save