Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9980774 (target branch: master)
@@ -55,7 +55,9 @@ class Models(object):
    space_modeling = 'space-modeling'
    star = 'star'
    tcrf = 'transformer-crf'
    transformer_softmax = 'transformer-softmax'
    lcrf = 'lstm-crf'
    gcnncrf = 'gcnn-crf'
    bart = 'bart'
    gpt3 = 'gpt3'
    plug = 'plug'

@@ -82,6 +84,7 @@ class Models(object):
class TaskModels(object):
    # nlp task
    text_classification = 'text-classification'
    token_classification = 'token-classification'
    information_extraction = 'information-extraction'

@@ -92,6 +95,8 @@ class Heads(object):
    bert_mlm = 'bert-mlm'
    # roberta mlm
    roberta_mlm = 'roberta-mlm'
    # token cls
    token_classification = 'token-classification'
    information_extraction = 'information-extraction'

@@ -167,6 +172,7 @@ class Pipelines(object):
    # nlp tasks
    sentence_similarity = 'sentence-similarity'
    word_segmentation = 'word-segmentation'
    part_of_speech = 'part-of-speech'
    named_entity_recognition = 'named-entity-recognition'
    text_generation = 'text-generation'
    sentiment_analysis = 'sentiment-analysis'

@@ -272,6 +278,7 @@ class Preprocessors(object):
    sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer'
    zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer'
    text_error_correction = 'text-error-correction'
    sequence_labeling_tokenizer = 'sequence-labeling-tokenizer'
    word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor'
    fill_mask = 'fill-mask'
    faq_question_answering_preprocessor = 'faq-question-answering-preprocessor'
@@ -5,40 +5,39 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
    from .backbones import SbertModel
    from .bart_for_text_error_correction import BartForTextErrorCorrection
    from .bert_for_sequence_classification import BertForSequenceClassification
    from .bert_for_document_segmentation import BertForDocumentSegmentation
    from .csanmt_for_translation import CsanmtForTranslation
    from .heads import SequenceClassificationHead
    from .gpt3 import GPT3ForTextGeneration
    from .masked_language import (StructBertForMaskedLM, VecoForMaskedLM,
                                  BertForMaskedLM, DebertaV2ForMaskedLM)
    from .nncrf_for_named_entity_recognition import (
        TransformerCRFForNamedEntityRecognition,
        LSTMCRFForNamedEntityRecognition)
    from .token_classification import SbertForTokenClassification
    from .palm_v2 import PalmForTextGeneration
    from .plug import PlugForTextGeneration
    from .sbert_for_faq_question_answering import SbertForFaqQuestionAnswering
    from .star_text_to_sql import StarForTextToSql
    from .sequence_classification import VecoForSequenceClassification, SbertForSequenceClassification
    from .space import SpaceForDialogIntent
    from .space import SpaceForDialogModeling
    from .space import SpaceForDialogStateTracking
    from .task_models import (InformationExtractionModel,
                              SequenceClassificationModel,
                              SingleBackboneTaskModelBase,
                              TokenClassificationModel)
else:
    _import_structure = {
        'backbones': ['SbertModel'],
        'bart_for_text_error_correction': ['BartForTextErrorCorrection'],
        'bert_for_sequence_classification': ['BertForSequenceClassification'],
        'bert_for_document_segmentation': ['BertForDocumentSegmentation'],
        'csanmt_for_translation': ['CsanmtForTranslation'],
        'heads': ['SequenceClassificationHead'],
        'gpt3': ['GPT3ForTextGeneration'],
        'masked_language': [
            'StructBertForMaskedLM', 'VecoForMaskedLM', 'BertForMaskedLM',
            'DebertaV2ForMaskedLM'

@@ -48,7 +47,8 @@ else:
            'LSTMCRFForNamedEntityRecognition'
        ],
        'palm_v2': ['PalmForTextGeneration'],
        'token_classification': ['SbertForTokenClassification'],
        'plug': ['PlugForTextGeneration'],
        'sbert_for_faq_question_answering': ['SbertForFaqQuestionAnswering'],
        'star_text_to_sql': ['StarForTextToSql'],
        'sequence_classification':
        ['VecoForSequenceClassification', 'SbertForSequenceClassification'],
        'space': [

@@ -57,12 +57,9 @@ else:
        ],
        'task_models': [
            'InformationExtractionModel', 'SequenceClassificationModel',
            'SingleBackboneTaskModelBase', 'TokenClassificationModel'
        ],
    }

    import sys
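The hunk cuts off at `import sys`. For context, the tail of these lazily-imported `__init__.py` files follows the LazyImportModule convention already used across the repo; a sketch (not part of this diff):

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )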
@@ -19,7 +19,6 @@ class SequenceClassificationHead(TorchHead):
        super().__init__(**kwargs)
        config = self.config
        self.num_labels = config.num_labels
        classifier_dropout = (
            config['classifier_dropout'] if config.get('classifier_dropout')
            is not None else config['hidden_dropout_prob'])
@@ -0,0 +1,42 @@
from typing import Dict

import torch
import torch.nn.functional as F
from torch import nn

from modelscope.metainfo import Heads
from modelscope.models.base import TorchHead
from modelscope.models.builder import HEADS
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks


@HEADS.register_module(
    Tasks.token_classification, module_name=Heads.token_classification)
class TokenClassificationHead(TorchHead):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        config = self.config
        self.num_labels = config.num_labels
        classifier_dropout = (
            config['classifier_dropout'] if config.get('classifier_dropout')
            is not None else config['hidden_dropout_prob'])
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config['hidden_size'],
                                    config['num_labels'])

    def forward(self, inputs=None):
        if isinstance(inputs, dict):
            assert inputs.get('sequence_output') is not None
            sequence_output = inputs.get('sequence_output')
        else:
            sequence_output = inputs
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        return {OutputKeys.LOGITS: logits}

    def compute_loss(self, outputs: Dict[str, torch.Tensor],
                     labels) -> Dict[str, torch.Tensor]:
        logits = outputs[OutputKeys.LOGITS]
        # F.cross_entropy expects the class dimension at position 1, so
        # flatten (batch, seq_len, num_labels) logits and the matching labels
        return {
            OutputKeys.LOSS:
            F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
        }
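A minimal shape check for the new head, assuming `TorchHead` exposes its constructor kwargs as `self.config` with both attribute and item access (the `__init__` above relies on exactly that); the sizes are illustrative:

    import torch

    # hypothetical standalone check of the head's tensor contract
    head = TokenClassificationHead(
        hidden_size=768, num_labels=7, hidden_dropout_prob=0.1)
    sequence_output = torch.randn(2, 16, 768)  # (batch, seq_len, hidden_size)
    outputs = head.forward(sequence_output)
    assert outputs[OutputKeys.LOGITS].shape == (2, 16, 7)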
@@ -85,7 +85,7 @@ class SbertConfig(PretrainedConfig):
        If adv_bound is not provided, 2 * sigma will be used as the adv_bound factor
    """

    model_type = 'structbert'

    def __init__(self,
                 vocab_size=30522,
@@ -7,12 +7,14 @@ if TYPE_CHECKING:
    from .information_extraction import InformationExtractionModel
    from .sequence_classification import SequenceClassificationModel
    from .task_model import SingleBackboneTaskModelBase
    from .token_classification import TokenClassificationModel
else:
    _import_structure = {
        'information_extraction': ['InformationExtractionModel'],
        'sequence_classification': ['SequenceClassificationModel'],
        'task_model': ['SingleBackboneTaskModelBase'],
        'token_classification': ['TokenClassificationModel'],
    }

    import sys
@@ -0,0 +1,83 @@
from typing import Any, Dict

import numpy as np
import torch

from modelscope.metainfo import TaskModels
from modelscope.models.builder import MODELS
from modelscope.models.nlp.task_models.task_model import \
    SingleBackboneTaskModelBase
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks
from modelscope.utils.hub import parse_label_mapping
from modelscope.utils.tensor_utils import (torch_nested_detach,
                                           torch_nested_numpify)

__all__ = ['TokenClassificationModel']


@MODELS.register_module(
    Tasks.token_classification, module_name=TaskModels.token_classification)
class TokenClassificationModel(SingleBackboneTaskModelBase):

    def __init__(self, model_dir: str, *args, **kwargs):
        """Initialize the token classification model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        if 'base_model_prefix' in kwargs:
            self._base_model_prefix = kwargs['base_model_prefix']

        backbone_cfg = self.cfg.backbone
        head_cfg = self.cfg.head

        # get the num_labels, either passed in directly or derived from the
        # label2id mapping shipped with the model
        num_labels = kwargs.get('num_labels')
        if num_labels is None:
            label2id = parse_label_mapping(model_dir)
            if label2id is not None and len(label2id) > 0:
                num_labels = len(label2id)
                self.id2label = {id: label for label, id in label2id.items()}
        head_cfg['num_labels'] = num_labels

        self.build_backbone(backbone_cfg)
        self.build_head(head_cfg)

    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
        labels = None
        if OutputKeys.LABEL in input:
            labels = input.pop(OutputKeys.LABEL)
        elif OutputKeys.LABELS in input:
            labels = input.pop(OutputKeys.LABELS)

        outputs = super().forward(input)
        sequence_output, pooled_output = self.extract_backbone_outputs(outputs)
        outputs = self.head.forward(sequence_output)
        # labels were popped from the input above, so test them directly
        if labels is not None:
            loss = self.compute_loss(outputs, labels)
            outputs.update(loss)
        return outputs

    def extract_logits(self, outputs):
        return outputs[OutputKeys.LOGITS].cpu().detach()

    def extract_backbone_outputs(self, outputs):
        sequence_output = None
        pooled_output = None
        if hasattr(self.backbone, 'extract_sequence_outputs'):
            sequence_output = self.backbone.extract_sequence_outputs(outputs)
        return sequence_output, pooled_output

    def compute_loss(self, outputs, labels):
        loss = self.head.compute_loss(outputs, labels)
        return loss

    def postprocess(self, input, **kwargs):
        logits = self.extract_logits(input)
        pred = torch.argmax(logits[0], dim=-1)
        pred = torch_nested_numpify(torch_nested_detach(pred))
        logits = torch_nested_numpify(torch_nested_detach(logits))
        res = {OutputKeys.PREDICTIONS: pred, OutputKeys.LOGITS: logits}
        return res
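For orientation, the `__init__` above expects the model's configuration.json to provide `backbone` and `head` sections plus a label mapping, roughly like this hypothetical fragment (field values are illustrative, not taken from the PR):

    {
        "model": {
            "type": "token-classification",
            "backbone": {
                "type": "structbert"
            },
            "head": {
                "type": "token-classification",
                "hidden_size": 768
            },
            "label2id": {"B-NOUN": 0, "I-NOUN": 1, "O": 2}
        }
    }

`num_labels` is then filled in at runtime from `label2id` when not passed explicitly.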
@@ -91,6 +91,7 @@ class TokenClassification(TorchModel):
@MODELS.register_module(Tasks.word_segmentation, module_name=Models.structbert)
@MODELS.register_module(Tasks.part_of_speech, module_name=Models.structbert)
@MODELS.register_module(
    Tasks.token_classification, module_name=Models.structbert)
class SbertForTokenClassification(TokenClassification, SbertPreTrainedModel):
@@ -359,26 +359,20 @@ TASK_OUTPUTS = {

    # word segmentation result for single sample
    # {
    #     "output": "今天 天气 不错 , 适合 出去 游玩",
    #     "labels": [
    #         {'word': '今天', 'label': 'PROPN'},
    #         {'word': '天气', 'label': 'PROPN'},
    #         {'word': '不错', 'label': 'VERB'},
    #         {'word': ',', 'label': 'NUM'},
    #         {'word': '适合', 'label': 'NOUN'},
    #         {'word': '出去', 'label': 'PART'},
    #         {'word': '游玩', 'label': 'ADV'},
    #     ]
    # }
    Tasks.word_segmentation: [OutputKeys.OUTPUT, OutputKeys.LABELS],
    Tasks.part_of_speech: [OutputKeys.OUTPUT, OutputKeys.LABELS],

    # TODO @wenmeng.zwm support list of result check
    # named entity recognition result for single sample
    # {
    #     "output": [
@@ -20,6 +20,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.word_segmentation:
    (Pipelines.word_segmentation,
     'damo/nlp_structbert_word-segmentation_chinese-base'),
    Tasks.token_classification:
    (Pipelines.part_of_speech,
     'damo/nlp_structbert_part-of-speech_chinese-base'),
    Tasks.named_entity_recognition:
    (Pipelines.named_entity_recognition,
     'damo/nlp_raner_named-entity-recognition_chinese-base-news'),
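With this entry in place, a task-only call resolves to the part-of-speech pipeline and the model id above (this is exactly what `test_run_with_default_model` at the bottom of the review exercises):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # no model argument: the default mapping picks Pipelines.part_of_speech
    # and 'damo/nlp_structbert_part-of-speech_chinese-base'
    pipe = pipeline(task=Tasks.token_classification)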
@@ -9,21 +9,21 @@ if TYPE_CHECKING:
    from .dialog_modeling_pipeline import DialogModelingPipeline
    from .dialog_state_tracking_pipeline import DialogStateTrackingPipeline
    from .document_segmentation_pipeline import DocumentSegmentationPipeline
    from .faq_question_answering_pipeline import FaqQuestionAnsweringPipeline
    from .fill_mask_pipeline import FillMaskPipeline
    from .information_extraction_pipeline import InformationExtractionPipeline
    from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline
    from .pair_sentence_classification_pipeline import PairSentenceClassificationPipeline
    from .single_sentence_classification_pipeline import SingleSentenceClassificationPipeline
    from .sequence_classification_pipeline import SequenceClassificationPipeline
    from .summarization_pipeline import SummarizationPipeline
    from .text_classification_pipeline import TextClassificationPipeline
    from .text_error_correction_pipeline import TextErrorCorrectionPipeline
    from .text_generation_pipeline import TextGenerationPipeline
    from .token_classification_pipeline import TokenClassificationPipeline
    from .translation_pipeline import TranslationPipeline
    from .word_segmentation_pipeline import WordSegmentationPipeline
    from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline
    from .relation_extraction_pipeline import RelationExtractionPipeline
else:
    _import_structure = {

@@ -34,25 +34,25 @@ else:
        'dialog_modeling_pipeline': ['DialogModelingPipeline'],
        'dialog_state_tracking_pipeline': ['DialogStateTrackingPipeline'],
        'document_segmentation_pipeline': ['DocumentSegmentationPipeline'],
        'faq_question_answering_pipeline': ['FaqQuestionAnsweringPipeline'],
        'fill_mask_pipeline': ['FillMaskPipeline'],
        'named_entity_recognition_pipeline':
        ['NamedEntityRecognitionPipeline'],
        'information_extraction_pipeline': ['InformationExtractionPipeline'],
        'single_sentence_classification_pipeline':
        ['SingleSentenceClassificationPipeline'],
        'pair_sentence_classification_pipeline':
        ['PairSentenceClassificationPipeline'],
        'sequence_classification_pipeline': ['SequenceClassificationPipeline'],
        'summarization_pipeline': ['SummarizationPipeline'],
        'text_classification_pipeline': ['TextClassificationPipeline'],
        'text_error_correction_pipeline': ['TextErrorCorrectionPipeline'],
        'text_generation_pipeline': ['TextGenerationPipeline'],
        'token_classification_pipeline': ['TokenClassificationPipeline'],
        'translation_pipeline': ['TranslationPipeline'],
        'word_segmentation_pipeline': ['WordSegmentationPipeline'],
        'zero_shot_classification_pipeline':
        ['ZeroShotClassificationPipeline'],
        'relation_extraction_pipeline': ['RelationExtractionPipeline']
    }

    import sys
@@ -0,0 +1,92 @@
from typing import Any, Dict, Optional, Union

import torch

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline, Tensor
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import (Preprocessor,
                                      TokenClassificationPreprocessor)
from modelscope.utils.constant import Tasks

__all__ = ['TokenClassificationPipeline']


@PIPELINES.register_module(
    Tasks.token_classification, module_name=Pipelines.part_of_speech)
class TokenClassificationPipeline(Pipeline):

    def __init__(self,
                 model: Union[Model, str],
                 preprocessor: Optional[Preprocessor] = None,
                 **kwargs):
        """Use `model` and `preprocessor` to create a token classification pipeline for prediction.

        Args:
            model (str or Model): a model instance, a local model dir, or a model id in the model hub.
            preprocessor (Preprocessor): a preprocessor instance; if None, a default one is built.
        """
        assert isinstance(model, str) or isinstance(model, Model), \
            'model must be a single str or Model'
        model = model if isinstance(model,
                                    Model) else Model.from_pretrained(model)
        if preprocessor is None:
            preprocessor = TokenClassificationPreprocessor(
                model.model_dir,
                sequence_length=kwargs.pop('sequence_length', 128))
        model.eval()
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
        # default to None so the assert below reports the problem instead of
        # getattr raising an unexplained AttributeError
        self.id2label = getattr(model, 'id2label', None)
        assert self.id2label is not None, 'Cannot convert id to the original label, please pass in the mapping ' \
                                          'as a parameter or make sure the preprocessor has the attribute.'

    def forward(self, inputs: Dict[str, Any],
                **forward_params) -> Dict[str, Any]:
        text = inputs.pop(OutputKeys.TEXT)
        with torch.no_grad():
            return {
                **self.model(inputs, **forward_params), OutputKeys.TEXT: text
            }

    def postprocess(self, inputs: Dict[str, Any],
                    **postprocess_params) -> Dict[str, str]:
        """Process the prediction results.

        Args:
            inputs (Dict[str, Any]): the model outputs plus the original text

        Returns:
            Dict[str, str]: the prediction results
        """
        pred_list = inputs['predictions']
        labels = []
        for pre in pred_list:
            labels.append(self.id2label[pre])
        # drop the labels predicted for [CLS] and [SEP]
        labels = labels[1:-1]

        chunks = []
        tags = []
        chunk = ''
        assert len(inputs['text']) == len(labels)
        for token, label in zip(inputs['text'], labels):
            if label[0] == 'B' or label[0] == 'I':
                # the chunk is still open
                chunk += token
            else:
                # any other prefix (e.g. 'E' or 'S') closes the chunk
                chunk += token
                chunks.append(chunk)
                chunk = ''
                tags.append(label.split('-')[-1])

        if chunk:
            chunks.append(chunk)
            tags.append(label.split('-')[-1])

        pos_result = []
        seg_result = ' '.join(chunks)
        for chunk, tag in zip(chunks, tags):
            pos_result.append({OutputKeys.WORD: chunk, OutputKeys.LABEL: tag})

        outputs = {
            OutputKeys.OUTPUT: seg_result,
            OutputKeys.LABELS: pos_result
        }
        return outputs
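To make the chunk decoding in `postprocess` concrete, here is a hypothetical trace (tag names are illustrative; the scheme is BIES-style, where `B`/`I` keep a chunk open and any other prefix closes it):

    # after argmax + id2label, with the [CLS]/[SEP] labels already stripped:
    text = '今天天气不错'
    labels = ['B-TIME', 'E-TIME', 'B-NOUN', 'E-NOUN', 'B-ADJ', 'E-ADJ']

    # the loop closes a chunk at each 'E-*' label, yielding
    #   chunks == ['今天', '天气', '不错']
    #   tags   == ['TIME', 'NOUN', 'ADJ']
    # so the pipeline returns
    #   {'output': '今天 天气 不错',
    #    'labels': [{'word': '今天', 'label': 'TIME'},
    #               {'word': '天气', 'label': 'NOUN'},
    #               {'word': '不错', 'label': 'ADJ'}]}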
@@ -15,15 +15,14 @@ if TYPE_CHECKING:
                        ImageDenoisePreprocessor)
    from .kws import WavToLists
    from .multi_modal import (OfaPreprocessor, MPlugPreprocessor)
    from .nlp import (
        Tokenize, SequenceClassificationPreprocessor,
        TextGenerationPreprocessor, TokenClassificationPreprocessor,
        SingleSentenceClassificationPreprocessor,
        PairSentenceClassificationPreprocessor, FillMaskPreprocessor,
        ZeroShotClassificationPreprocessor, NERPreprocessor,
        TextErrorCorrectionPreprocessor, FaqQuestionAnsweringPreprocessor,
        SequenceLabelingPreprocessor, RelationExtractionPreprocessor)
    from .slp import DocumentSegmentationPreprocessor
    from .space import (DialogIntentPredictionPreprocessor,
                        DialogModelingPreprocessor,

@@ -52,7 +51,7 @@ else:
            'PairSentenceClassificationPreprocessor', 'FillMaskPreprocessor',
            'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
            'TextErrorCorrectionPreprocessor',
            'FaqQuestionAnsweringPreprocessor', 'SequenceLabelingPreprocessor',
            'RelationExtractionPreprocessor'
        ],
        'slp': ['DocumentSegmentationPreprocessor'],
@@ -5,9 +5,11 @@ import uuid
from typing import Any, Dict, Iterable, Optional, Tuple, Union

import numpy as np
import torch
from transformers import AutoTokenizer, BertTokenizerFast

from modelscope.metainfo import Models, Preprocessors
from modelscope.models.nlp.structbert import SbertTokenizerFast
from modelscope.outputs import OutputKeys
from modelscope.utils.config import ConfigFields
from modelscope.utils.constant import Fields, InputFields, ModeKeys

@@ -23,7 +25,7 @@ __all__ = [
    'SingleSentenceClassificationPreprocessor', 'FillMaskPreprocessor',
    'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
    'TextErrorCorrectionPreprocessor', 'FaqQuestionAnsweringPreprocessor',
    'SequenceLabelingPreprocessor', 'RelationExtractionPreprocessor'
]
@@ -627,6 +629,112 @@ class NERPreprocessor(Preprocessor):
        }


@PREPROCESSORS.register_module(
    Fields.nlp, module_name=Preprocessors.sequence_labeling_tokenizer)
class SequenceLabelingPreprocessor(Preprocessor):
    """The tokenizer preprocessor used in the normal NER task.

    NOTE: This preprocessor may be merged with the
    TokenClassificationPreprocessor in the next edition.
    """

    def __init__(self, model_dir: str, *args, **kwargs):
        """Preprocess the data via the vocab.txt from the `model_dir` path.

        Args:
            model_dir (str): model path
        """
        super().__init__(*args, **kwargs)

        self.model_dir: str = model_dir
        self.sequence_length = kwargs.pop('sequence_length', 512)

        if 'lstm' in model_dir or 'gcnn' in model_dir:
            self.tokenizer = BertTokenizerFast.from_pretrained(
                model_dir, use_fast=False)
        elif 'structbert' in model_dir:
            self.tokenizer = SbertTokenizerFast.from_pretrained(
                model_dir, use_fast=False)
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_dir, use_fast=False)
        self.is_split_into_words = self.tokenizer.init_kwargs.get(
            'is_split_into_words', False)
        # assumption: the lstm/gcnn CRF models consume inputs without
        # [CLS]/[SEP]; this flag is read in __call__ below to strip them
        self.is_transformer_based_model = not ('lstm' in model_dir
                                               or 'gcnn' in model_dir)

    @type_assert(object, str)
    def __call__(self, data: str) -> Dict[str, Any]:
        """Process the raw input data.

        Args:
            data (str): a sentence
                Example:
                    'you are so handsome.'

        Returns:
            Dict[str, Any]: the preprocessed data
        """
        # preprocess the data for the model input
        text = data
        if self.is_split_into_words:
            input_ids = []
            label_mask = []
            offset_mapping = []
            for offset, token in enumerate(list(data)):
                subtoken_ids = self.tokenizer.encode(
                    token, add_special_tokens=False)
                if len(subtoken_ids) == 0:
                    subtoken_ids = [self.tokenizer.unk_token_id]
                input_ids.extend(subtoken_ids)
                label_mask.extend([1] + [0] * (len(subtoken_ids) - 1))
                offset_mapping.extend([(offset, offset + 1)]
                                      + [(offset + 1, offset + 1)]
                                      * (len(subtoken_ids) - 1))
            if len(input_ids) >= self.sequence_length - 2:
                input_ids = input_ids[:self.sequence_length - 2]
                label_mask = label_mask[:self.sequence_length - 2]
                offset_mapping = offset_mapping[:self.sequence_length - 2]
            input_ids = [self.tokenizer.cls_token_id
                         ] + input_ids + [self.tokenizer.sep_token_id]
            label_mask = [0] + label_mask + [0]
            attention_mask = [1] * len(input_ids)
        else:
            encodings = self.tokenizer(
                text,
                add_special_tokens=True,
                padding=True,
                truncation=True,
                max_length=self.sequence_length,
                return_offsets_mapping=True)
            input_ids = encodings['input_ids']
            attention_mask = encodings['attention_mask']
            word_ids = encodings.word_ids()
            label_mask = []
            offset_mapping = []
            for i in range(len(word_ids)):
                if word_ids[i] is None:
                    label_mask.append(0)
                elif word_ids[i] == word_ids[i - 1]:
                    # continuation sub-token: merge its span into the
                    # previous word's offset mapping
                    label_mask.append(0)
                    offset_mapping[-1] = (offset_mapping[-1][0],
                                          encodings['offset_mapping'][i][1])
                else:
                    label_mask.append(1)
                    offset_mapping.append(encodings['offset_mapping'][i])

        if not self.is_transformer_based_model:
            input_ids = input_ids[1:-1]
            attention_mask = attention_mask[1:-1]
            label_mask = label_mask[1:-1]

        return {
            'text': text,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'label_mask': label_mask,
            'offset_mapping': offset_mapping
        }
@PREPROCESSORS.register_module(
    Fields.nlp, module_name=Preprocessors.re_tokenizer)
class RelationExtractionPreprocessor(Preprocessor):
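Back to `SequenceLabelingPreprocessor` above: a rough picture of what the char-level (`is_split_into_words`) branch returns for a short input; the token ids below are placeholders, and real values depend on the model's vocab:

    # preprocessor('今天') would yield something shaped like:
    # {
    #     'text': '今天',
    #     'input_ids': [101, 1092, 1921, 102],  # [CLS] 今 天 [SEP], ids illustrative
    #     'attention_mask': [1, 1, 1, 1],
    #     'label_mask': [0, 1, 1, 0],           # special tokens carry no labels
    #     'offset_mapping': [(0, 1), (1, 2)]    # one span per source character
    # }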
@@ -77,19 +77,26 @@ def auto_load(model: Union[str, List[str]]):
def get_model_type(model_dir):
    """Get the model type from the configuration.

    This method will try to get the model type from the 'model.backbone.type',
    'model.type' or 'model.model_type' field in the configuration.json file. If
    this file does not exist, the method will try to get the 'model_type' field
    from the config.json.

    @param model_dir: The local model dir to use.
    @return: The model type string, returns None if nothing is found.
    """
    try:
        configuration_file = osp.join(model_dir, ModelFile.CONFIGURATION)
        config_file = osp.join(model_dir, 'config.json')
        if osp.isfile(configuration_file):
            cfg = Config.from_file(configuration_file)
            if hasattr(cfg.model, 'backbone'):
                return cfg.model.backbone.type
            elif hasattr(cfg.model,
                         'model_type') and not hasattr(cfg.model, 'type'):
                return cfg.model.model_type
            else:
                return cfg.model.type
        elif osp.isfile(config_file):
            cfg = Config.from_file(config_file)
            return cfg.model_type if hasattr(cfg, 'model_type') else None
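The practical effect of the new `backbone` branch, sketched against a hypothetical configuration.json (values illustrative): a file containing {"model": {"type": "token-classification", "backbone": {"type": "structbert"}}} now makes `get_model_type` return 'structbert', the backbone type, rather than the task-model type.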
@@ -123,13 +130,24 @@ def parse_label_mapping(model_dir):
        if hasattr(config, ConfigFields.model) and hasattr(
                config[ConfigFields.model], 'label2id'):
            label2id = config[ConfigFields.model].label2id
        elif hasattr(config, ConfigFields.model) and hasattr(
                config[ConfigFields.model], 'id2label'):
            id2label = config[ConfigFields.model].id2label
            label2id = {label: id for id, label in id2label.items()}
        elif hasattr(config, ConfigFields.preprocessor) and hasattr(
                config[ConfigFields.preprocessor], 'label2id'):
            label2id = config[ConfigFields.preprocessor].label2id
        elif hasattr(config, ConfigFields.preprocessor) and hasattr(
                config[ConfigFields.preprocessor], 'id2label'):
            id2label = config[ConfigFields.preprocessor].id2label
            label2id = {label: id for id, label in id2label.items()}

    if label2id is None:
        config_path = os.path.join(model_dir, 'config.json')
        config = Config.from_file(config_path)
        if hasattr(config, 'label2id'):
            label2id = config.label2id
        elif hasattr(config, 'id2label'):
            id2label = config.id2label
            label2id = {label: id for id, label in id2label.items()}
    return label2id
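In short, the label mapping is now resolved in this order (a restatement of the branches above, not new behavior): configuration.json `model.label2id`, then `model.id2label` (inverted), then `preprocessor.label2id`, then `preprocessor.id2label` (inverted), and finally `label2id` or `id2label` from a plain config.json.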
@@ -0,0 +1,55 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import shutil
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import TokenClassificationModel
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import TokenClassificationPipeline
from modelscope.preprocessors import TokenClassificationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class PartOfSpeechTest(unittest.TestCase):
    model_id = 'damo/nlp_structbert_part-of-speech_chinese-base'
    sentence = '今天天气不错,适合出去游玩'

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_by_direct_model_download(self):
        cache_path = snapshot_download(self.model_id)
        tokenizer = TokenClassificationPreprocessor(cache_path)
        model = TokenClassificationModel.from_pretrained(cache_path)
        pipeline1 = TokenClassificationPipeline(model, preprocessor=tokenizer)
        pipeline2 = pipeline(
            Tasks.token_classification, model=model, preprocessor=tokenizer)
        print(f'sentence: {self.sentence}\n'
              f'pipeline1: {pipeline1(input=self.sentence)}')
        print()
        print(f'pipeline2: {pipeline2(input=self.sentence)}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id)
        tokenizer = TokenClassificationPreprocessor(model.model_dir)
        pipeline_ins = pipeline(
            task=Tasks.token_classification,
            model=model,
            preprocessor=tokenizer)
        print(pipeline_ins(input=self.sentence))

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_model_name(self):
        pipeline_ins = pipeline(
            task=Tasks.token_classification, model=self.model_id)
        print(pipeline_ins(input=self.sentence))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_default_model(self):
        pipeline_ins = pipeline(task=Tasks.token_classification)
        print(pipeline_ins(input=self.sentence))


if __name__ == '__main__':
    unittest.main()