From 7ed4015bdcb60a4580fa8605cc9bc49a60b28e7f Mon Sep 17 00:00:00 2001 From: "dingkun.ldk" Date: Wed, 7 Sep 2022 11:57:30 +0800 Subject: [PATCH] =?UTF-8?q?[to=20#42322933]=E6=94=AF=E6=8C=81=E8=AF=8D?= =?UTF-8?q?=E6=80=A7=E6=A0=87=E6=B3=A8=20=20=20=20=20=20=20=20=20Link:=20h?= =?UTF-8?q?ttps://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/998077?= =?UTF-8?q?4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- modelscope/metainfo.py | 7 ++ modelscope/models/nlp/__init__.py | 43 ++++--- .../nlp/heads/sequence_classification_head.py | 1 - .../nlp/heads/token_classification_head.py | 42 +++++++ .../nlp/structbert/configuration_sbert.py | 2 +- modelscope/models/nlp/task_models/__init__.py | 2 + .../nlp/task_models/token_classification.py | 83 +++++++++++++ modelscope/models/nlp/token_classification.py | 1 + modelscope/outputs.py | 30 ++--- modelscope/pipelines/builder.py | 3 + modelscope/pipelines/nlp/__init__.py | 30 ++--- .../nlp/token_classification_pipeline.py | 92 +++++++++++++++ modelscope/preprocessors/__init__.py | 19 ++- modelscope/preprocessors/nlp.py | 110 +++++++++++++++++- modelscope/utils/hub.py | 30 ++++- tests/pipelines/test_part_of_speech.py | 55 +++++++++ 16 files changed, 475 insertions(+), 75 deletions(-) create mode 100644 modelscope/models/nlp/heads/token_classification_head.py create mode 100644 modelscope/models/nlp/task_models/token_classification.py create mode 100644 modelscope/pipelines/nlp/token_classification_pipeline.py create mode 100644 tests/pipelines/test_part_of_speech.py diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 270c5aaf..994095c3 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -55,7 +55,9 @@ class Models(object): space_modeling = 'space-modeling' star = 'star' tcrf = 'transformer-crf' + transformer_softmax = 'transformer-softmax' lcrf = 'lstm-crf' + gcnncrf = 'gcnn-crf' bart = 'bart' gpt3 = 'gpt3' plug = 'plug' @@ -82,6 +84,7 @@ class Models(object): class TaskModels(object): # nlp task text_classification = 'text-classification' + token_classification = 'token-classification' information_extraction = 'information-extraction' @@ -92,6 +95,8 @@ class Heads(object): bert_mlm = 'bert-mlm' # roberta mlm roberta_mlm = 'roberta-mlm' + # token cls + token_classification = 'token-classification' information_extraction = 'information-extraction' @@ -167,6 +172,7 @@ class Pipelines(object): # nlp tasks sentence_similarity = 'sentence-similarity' word_segmentation = 'word-segmentation' + part_of_speech = 'part-of-speech' named_entity_recognition = 'named-entity-recognition' text_generation = 'text-generation' sentiment_analysis = 'sentiment-analysis' @@ -272,6 +278,7 @@ class Preprocessors(object): sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer' zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer' text_error_correction = 'text-error-correction' + sequence_labeling_tokenizer = 'sequence-labeling-tokenizer' word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor' fill_mask = 'fill-mask' faq_question_answering_preprocessor = 'faq-question-answering-preprocessor' diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py index 9d54834c..40be8665 100644 --- a/modelscope/models/nlp/__init__.py +++ b/modelscope/models/nlp/__init__.py @@ -5,40 +5,39 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: from .backbones import SbertModel - from .heads import SequenceClassificationHead + from .bart_for_text_error_correction import BartForTextErrorCorrection from .bert_for_sequence_classification import BertForSequenceClassification from .bert_for_document_segmentation import BertForDocumentSegmentation from .csanmt_for_translation import CsanmtForTranslation - from .masked_language import ( - StructBertForMaskedLM, - VecoForMaskedLM, - BertForMaskedLM, - DebertaV2ForMaskedLM, - ) + from .heads import SequenceClassificationHead + from .gpt3 import GPT3ForTextGeneration + from .masked_language import (StructBertForMaskedLM, VecoForMaskedLM, + BertForMaskedLM, DebertaV2ForMaskedLM) from .nncrf_for_named_entity_recognition import ( TransformerCRFForNamedEntityRecognition, LSTMCRFForNamedEntityRecognition) - from .token_classification import SbertForTokenClassification + from .palm_v2 import PalmForTextGeneration + from .sbert_for_faq_question_answering import SbertForFaqQuestionAnswering + from .star_text_to_sql import StarForTextToSql from .sequence_classification import VecoForSequenceClassification, SbertForSequenceClassification from .space import SpaceForDialogIntent from .space import SpaceForDialogModeling from .space import SpaceForDialogStateTracking - from .star_text_to_sql import StarForTextToSql from .task_models import (InformationExtractionModel, - SingleBackboneTaskModelBase) - from .bart_for_text_error_correction import BartForTextErrorCorrection - from .gpt3 import GPT3ForTextGeneration - from .plug import PlugForTextGeneration - from .sbert_for_faq_question_answering import SbertForFaqQuestionAnswering + SequenceClassificationModel, + SingleBackboneTaskModelBase, + TokenClassificationModel) + from .token_classification import SbertForTokenClassification else: _import_structure = { - 'star_text_to_sql': ['StarForTextToSql'], 'backbones': ['SbertModel'], - 'heads': ['SequenceClassificationHead'], - 'csanmt_for_translation': ['CsanmtForTranslation'], + 'bart_for_text_error_correction': ['BartForTextErrorCorrection'], 'bert_for_sequence_classification': ['BertForSequenceClassification'], 'bert_for_document_segmentation': ['BertForDocumentSegmentation'], + 'csanmt_for_translation': ['CsanmtForTranslation'], + 'heads': ['SequenceClassificationHead'], + 'gpt3': ['GPT3ForTextGeneration'], 'masked_language': [ 'StructBertForMaskedLM', 'VecoForMaskedLM', 'BertForMaskedLM', 'DebertaV2ForMaskedLM' @@ -48,7 +47,8 @@ else: 'LSTMCRFForNamedEntityRecognition' ], 'palm_v2': ['PalmForTextGeneration'], - 'token_classification': ['SbertForTokenClassification'], + 'sbert_for_faq_question_answering': ['SbertForFaqQuestionAnswering'], + 'star_text_to_sql': ['StarForTextToSql'], 'sequence_classification': ['VecoForSequenceClassification', 'SbertForSequenceClassification'], 'space': [ @@ -57,12 +57,9 @@ else: ], 'task_models': [ 'InformationExtractionModel', 'SequenceClassificationModel', - 'SingleBackboneTaskModelBase' + 'SingleBackboneTaskModelBase', 'TokenClassificationModel' ], - 'bart_for_text_error_correction': ['BartForTextErrorCorrection'], - 'gpt3': ['GPT3ForTextGeneration'], - 'plug': ['PlugForTextGeneration'], - 'sbert_for_faq_question_answering': ['SbertForFaqQuestionAnswering'], + 'token_classification': ['SbertForTokenClassification'], } import sys diff --git a/modelscope/models/nlp/heads/sequence_classification_head.py b/modelscope/models/nlp/heads/sequence_classification_head.py index 92f3a4ec..e608f035 100644 --- a/modelscope/models/nlp/heads/sequence_classification_head.py +++ b/modelscope/models/nlp/heads/sequence_classification_head.py @@ -19,7 +19,6 @@ class SequenceClassificationHead(TorchHead): super().__init__(**kwargs) config = self.config self.num_labels = config.num_labels - self.config = config classifier_dropout = ( config['classifier_dropout'] if config.get('classifier_dropout') is not None else config['hidden_dropout_prob']) diff --git a/modelscope/models/nlp/heads/token_classification_head.py b/modelscope/models/nlp/heads/token_classification_head.py new file mode 100644 index 00000000..481524ae --- /dev/null +++ b/modelscope/models/nlp/heads/token_classification_head.py @@ -0,0 +1,42 @@ +from typing import Dict + +import torch +import torch.nn.functional as F +from torch import nn + +from modelscope.metainfo import Heads +from modelscope.models.base import TorchHead +from modelscope.models.builder import HEADS +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import Tasks + + +@HEADS.register_module( + Tasks.token_classification, module_name=Heads.token_classification) +class TokenClassificationHead(TorchHead): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + config = self.config + self.num_labels = config.num_labels + classifier_dropout = ( + config['classifier_dropout'] if config.get('classifier_dropout') + is not None else config['hidden_dropout_prob']) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config['hidden_size'], + config['num_labels']) + + def forward(self, inputs=None): + if isinstance(inputs, dict): + assert inputs.get('sequence_output') is not None + sequence_output = inputs.get('sequence_output') + else: + sequence_output = inputs + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + return {OutputKeys.LOGITS: logits} + + def compute_loss(self, outputs: Dict[str, torch.Tensor], + labels) -> Dict[str, torch.Tensor]: + logits = outputs[OutputKeys.LOGITS] + return {OutputKeys.LOSS: F.cross_entropy(logits, labels)} diff --git a/modelscope/models/nlp/structbert/configuration_sbert.py b/modelscope/models/nlp/structbert/configuration_sbert.py index 374d4b62..a727a978 100644 --- a/modelscope/models/nlp/structbert/configuration_sbert.py +++ b/modelscope/models/nlp/structbert/configuration_sbert.py @@ -85,7 +85,7 @@ class SbertConfig(PretrainedConfig): If adv_bound not proveded, 2 * sigma will be used as the adv_bound factor """ - model_type = 'sbert' + model_type = 'structbert' def __init__(self, vocab_size=30522, diff --git a/modelscope/models/nlp/task_models/__init__.py b/modelscope/models/nlp/task_models/__init__.py index 49cf0ee4..7493ba74 100644 --- a/modelscope/models/nlp/task_models/__init__.py +++ b/modelscope/models/nlp/task_models/__init__.py @@ -7,12 +7,14 @@ if TYPE_CHECKING: from .information_extraction import InformationExtractionModel from .sequence_classification import SequenceClassificationModel from .task_model import SingleBackboneTaskModelBase + from .token_classification import TokenClassificationModel else: _import_structure = { 'information_extraction': ['InformationExtractionModel'], 'sequence_classification': ['SequenceClassificationModel'], 'task_model': ['SingleBackboneTaskModelBase'], + 'token_classification': ['TokenClassificationModel'], } import sys diff --git a/modelscope/models/nlp/task_models/token_classification.py b/modelscope/models/nlp/task_models/token_classification.py new file mode 100644 index 00000000..29679838 --- /dev/null +++ b/modelscope/models/nlp/task_models/token_classification.py @@ -0,0 +1,83 @@ +from typing import Any, Dict + +import numpy as np +import torch + +from modelscope.metainfo import TaskModels +from modelscope.models.builder import MODELS +from modelscope.models.nlp.task_models.task_model import \ + SingleBackboneTaskModelBase +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import Tasks +from modelscope.utils.hub import parse_label_mapping +from modelscope.utils.tensor_utils import (torch_nested_detach, + torch_nested_numpify) + +__all__ = ['TokenClassificationModel'] + + +@MODELS.register_module( + Tasks.token_classification, module_name=TaskModels.token_classification) +class TokenClassificationModel(SingleBackboneTaskModelBase): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the token classification model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + if 'base_model_prefix' in kwargs: + self._base_model_prefix = kwargs['base_model_prefix'] + + backbone_cfg = self.cfg.backbone + head_cfg = self.cfg.head + + # get the num_labels + num_labels = kwargs.get('num_labels') + if num_labels is None: + label2id = parse_label_mapping(model_dir) + if label2id is not None and len(label2id) > 0: + num_labels = len(label2id) + self.id2label = {id: label for label, id in label2id.items()} + head_cfg['num_labels'] = num_labels + + self.build_backbone(backbone_cfg) + self.build_head(head_cfg) + + def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: + labels = None + if OutputKeys.LABEL in input: + labels = input.pop(OutputKeys.LABEL) + elif OutputKeys.LABELS in input: + labels = input.pop(OutputKeys.LABELS) + + outputs = super().forward(input) + sequence_output, pooled_output = self.extract_backbone_outputs(outputs) + outputs = self.head.forward(sequence_output) + if labels in input: + loss = self.compute_loss(outputs, labels) + outputs.update(loss) + return outputs + + def extract_logits(self, outputs): + return outputs[OutputKeys.LOGITS].cpu().detach() + + def extract_backbone_outputs(self, outputs): + sequence_output = None + pooled_output = None + if hasattr(self.backbone, 'extract_sequence_outputs'): + sequence_output = self.backbone.extract_sequence_outputs(outputs) + return sequence_output, pooled_output + + def compute_loss(self, outputs, labels): + loss = self.head.compute_loss(outputs, labels) + return loss + + def postprocess(self, input, **kwargs): + logits = self.extract_logits(input) + pred = torch.argmax(logits[0], dim=-1) + pred = torch_nested_numpify(torch_nested_detach(pred)) + logits = torch_nested_numpify(torch_nested_detach(logits)) + res = {OutputKeys.PREDICTIONS: pred, OutputKeys.LOGITS: logits} + return res diff --git a/modelscope/models/nlp/token_classification.py b/modelscope/models/nlp/token_classification.py index 59d7d0cf..0be921d0 100644 --- a/modelscope/models/nlp/token_classification.py +++ b/modelscope/models/nlp/token_classification.py @@ -91,6 +91,7 @@ class TokenClassification(TorchModel): @MODELS.register_module(Tasks.word_segmentation, module_name=Models.structbert) +@MODELS.register_module(Tasks.part_of_speech, module_name=Models.structbert) @MODELS.register_module( Tasks.token_classification, module_name=Models.structbert) class SbertForTokenClassification(TokenClassification, SbertPreTrainedModel): diff --git a/modelscope/outputs.py b/modelscope/outputs.py index c6a7a619..6c7500bb 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -359,26 +359,20 @@ TASK_OUTPUTS = { # word segmentation result for single sample # { # "output": "今天 天气 不错 , 适合 出去 游玩" - # } - Tasks.word_segmentation: [OutputKeys.OUTPUT], - - # part-of-speech result for single sample - # [ - # {'word': '诸葛', 'label': 'PROPN'}, - # {'word': '亮', 'label': 'PROPN'}, - # {'word': '发明', 'label': 'VERB'}, - # {'word': '八', 'label': 'NUM'}, - # {'word': '阵', 'label': 'NOUN'}, - # {'word': '图', 'label': 'PART'}, - # {'word': '以', 'label': 'ADV'}, - # {'word': '利', 'label': 'VERB'}, - # {'word': '立营', 'label': 'VERB'}, - # {'word': '练兵', 'label': 'VERB'}, - # {'word': '.', 'label': 'PUNCT'} + # "labels": [ + # {'word': '今天', 'label': 'PROPN'}, + # {'word': '天气', 'label': 'PROPN'}, + # {'word': '不错', 'label': 'VERB'}, + # {'word': ',', 'label': 'NUM'}, + # {'word': '适合', 'label': 'NOUN'}, + # {'word': '出去', 'label': 'PART'}, + # {'word': '游玩', 'label': 'ADV'}, # ] - # TODO @wenmeng.zwm support list of result check - Tasks.part_of_speech: [OutputKeys.WORD, OutputKeys.LABEL], + # } + Tasks.word_segmentation: [OutputKeys.OUTPUT, OutputKeys.LABELS], + Tasks.part_of_speech: [OutputKeys.OUTPUT, OutputKeys.LABELS], + # TODO @wenmeng.zwm support list of result check # named entity recognition result for single sample # { # "output": [ diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 9f265fb8..fa79ca11 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -20,6 +20,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.word_segmentation: (Pipelines.word_segmentation, 'damo/nlp_structbert_word-segmentation_chinese-base'), + Tasks.token_classification: + (Pipelines.part_of_speech, + 'damo/nlp_structbert_part-of-speech_chinese-base'), Tasks.named_entity_recognition: (Pipelines.named_entity_recognition, 'damo/nlp_raner_named-entity-recognition_chinese-base-news'), diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py index 665e016d..9baeefbb 100644 --- a/modelscope/pipelines/nlp/__init__.py +++ b/modelscope/pipelines/nlp/__init__.py @@ -9,21 +9,21 @@ if TYPE_CHECKING: from .dialog_modeling_pipeline import DialogModelingPipeline from .dialog_state_tracking_pipeline import DialogStateTrackingPipeline from .document_segmentation_pipeline import DocumentSegmentationPipeline + from .faq_question_answering_pipeline import FaqQuestionAnsweringPipeline from .fill_mask_pipeline import FillMaskPipeline from .information_extraction_pipeline import InformationExtractionPipeline from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline from .pair_sentence_classification_pipeline import PairSentenceClassificationPipeline from .single_sentence_classification_pipeline import SingleSentenceClassificationPipeline from .sequence_classification_pipeline import SequenceClassificationPipeline + from .summarization_pipeline import SummarizationPipeline + from .text_classification_pipeline import TextClassificationPipeline + from .text_error_correction_pipeline import TextErrorCorrectionPipeline from .text_generation_pipeline import TextGenerationPipeline + from .token_classification_pipeline import TokenClassificationPipeline from .translation_pipeline import TranslationPipeline from .word_segmentation_pipeline import WordSegmentationPipeline from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline - from .summarization_pipeline import SummarizationPipeline - from .text_classification_pipeline import TextClassificationPipeline - from .text_error_correction_pipeline import TextErrorCorrectionPipeline - from .faq_question_answering_pipeline import FaqQuestionAnsweringPipeline - from .relation_extraction_pipeline import RelationExtractionPipeline else: _import_structure = { @@ -34,25 +34,25 @@ else: 'dialog_modeling_pipeline': ['DialogModelingPipeline'], 'dialog_state_tracking_pipeline': ['DialogStateTrackingPipeline'], 'document_segmentation_pipeline': ['DocumentSegmentationPipeline'], + 'faq_question_answering_pipeline': ['FaqQuestionAnsweringPipeline'], 'fill_mask_pipeline': ['FillMaskPipeline'], + 'named_entity_recognition_pipeline': + ['NamedEntityRecognitionPipeline'], 'information_extraction_pipeline': ['InformationExtractionPipeline'], - 'single_sentence_classification_pipeline': - ['SingleSentenceClassificationPipeline'], 'pair_sentence_classification_pipeline': ['PairSentenceClassificationPipeline'], 'sequence_classification_pipeline': ['SequenceClassificationPipeline'], + 'single_sentence_classification_pipeline': + ['SingleSentenceClassificationPipeline'], + 'summarization_pipeline': ['SummarizationPipeline'], + 'text_classification_pipeline': ['TextClassificationPipeline'], + 'text_error_correction_pipeline': ['TextErrorCorrectionPipeline'], 'text_generation_pipeline': ['TextGenerationPipeline'], + 'token_classification_pipeline': ['TokenClassificationPipeline'], + 'translation_pipeline': ['TranslationPipeline'], 'word_segmentation_pipeline': ['WordSegmentationPipeline'], 'zero_shot_classification_pipeline': ['ZeroShotClassificationPipeline'], - 'named_entity_recognition_pipeline': - ['NamedEntityRecognitionPipeline'], - 'translation_pipeline': ['TranslationPipeline'], - 'summarization_pipeline': ['SummarizationPipeline'], - 'text_classification_pipeline': ['TextClassificationPipeline'], - 'text_error_correction_pipeline': ['TextErrorCorrectionPipeline'], - 'faq_question_answering_pipeline': ['FaqQuestionAnsweringPipeline'], - 'relation_extraction_pipeline': ['RelationExtractionPipeline'] } import sys diff --git a/modelscope/pipelines/nlp/token_classification_pipeline.py b/modelscope/pipelines/nlp/token_classification_pipeline.py new file mode 100644 index 00000000..804f8146 --- /dev/null +++ b/modelscope/pipelines/nlp/token_classification_pipeline.py @@ -0,0 +1,92 @@ +from typing import Any, Dict, Optional, Union + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline, Tensor +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import (Preprocessor, + TokenClassificationPreprocessor) +from modelscope.utils.constant import Tasks + +__all__ = ['TokenClassificationPipeline'] + + +@PIPELINES.register_module( + Tasks.token_classification, module_name=Pipelines.part_of_speech) +class TokenClassificationPipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + """use `model` and `preprocessor` to create a token classification pipeline for prediction + + Args: + model (str or Model): A model instance or a model local dir or a model id in the model hub. + preprocessor (Preprocessor): a preprocessor instance, must not be None. + """ + assert isinstance(model, str) or isinstance(model, Model), \ + 'model must be a single str or Model' + model = model if isinstance(model, + Model) else Model.from_pretrained(model) + if preprocessor is None: + preprocessor = TokenClassificationPreprocessor( + model.model_dir, + sequence_length=kwargs.pop('sequence_length', 128)) + model.eval() + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + self.id2label = getattr(model, 'id2label') + assert self.id2label is not None, 'Cannot convert id to the original label, please pass in the mapping ' \ + 'as a parameter or make sure the preprocessor has the attribute.' + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + text = inputs.pop(OutputKeys.TEXT) + with torch.no_grad(): + return { + **self.model(inputs, **forward_params), OutputKeys.TEXT: text + } + + def postprocess(self, inputs: Dict[str, Any], + **postprocess_params) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): _description_ + + Returns: + Dict[str, str]: the prediction results + """ + + pred_list = inputs['predictions'] + labels = [] + for pre in pred_list: + labels.append(self.id2label[pre]) + labels = labels[1:-1] + chunks = [] + tags = [] + chunk = '' + assert len(inputs['text']) == len(labels) + for token, label in zip(inputs['text'], labels): + if label[0] == 'B' or label[0] == 'I': + chunk += token + else: + chunk += token + chunks.append(chunk) + chunk = '' + tags.append(label.split('-')[-1]) + if chunk: + chunks.append(chunk) + tags.append(label.split('-')[-1]) + pos_result = [] + seg_result = ' '.join(chunks) + for chunk, tag in zip(chunks, tags): + pos_result.append({OutputKeys.WORD: chunk, OutputKeys.LABEL: tag}) + outputs = { + OutputKeys.OUTPUT: seg_result, + OutputKeys.LABELS: pos_result + } + return outputs diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py index 9f7d595e..0123b32e 100644 --- a/modelscope/preprocessors/__init__.py +++ b/modelscope/preprocessors/__init__.py @@ -15,15 +15,14 @@ if TYPE_CHECKING: ImageDenoisePreprocessor) from .kws import WavToLists from .multi_modal import (OfaPreprocessor, MPlugPreprocessor) - from .nlp import (Tokenize, SequenceClassificationPreprocessor, - TextGenerationPreprocessor, - TokenClassificationPreprocessor, - SingleSentenceClassificationPreprocessor, - PairSentenceClassificationPreprocessor, - FillMaskPreprocessor, ZeroShotClassificationPreprocessor, - NERPreprocessor, TextErrorCorrectionPreprocessor, - FaqQuestionAnsweringPreprocessor, - RelationExtractionPreprocessor) + from .nlp import ( + Tokenize, SequenceClassificationPreprocessor, + TextGenerationPreprocessor, TokenClassificationPreprocessor, + SingleSentenceClassificationPreprocessor, + PairSentenceClassificationPreprocessor, FillMaskPreprocessor, + ZeroShotClassificationPreprocessor, NERPreprocessor, + TextErrorCorrectionPreprocessor, FaqQuestionAnsweringPreprocessor, + SequenceLabelingPreprocessor, RelationExtractionPreprocessor) from .slp import DocumentSegmentationPreprocessor from .space import (DialogIntentPredictionPreprocessor, DialogModelingPreprocessor, @@ -52,7 +51,7 @@ else: 'PairSentenceClassificationPreprocessor', 'FillMaskPreprocessor', 'ZeroShotClassificationPreprocessor', 'NERPreprocessor', 'TextErrorCorrectionPreprocessor', - 'FaqQuestionAnsweringPreprocessor', + 'FaqQuestionAnsweringPreprocessor', 'SequenceLabelingPreprocessor', 'RelationExtractionPreprocessor' ], 'slp': ['DocumentSegmentationPreprocessor'], diff --git a/modelscope/preprocessors/nlp.py b/modelscope/preprocessors/nlp.py index cfb8c9e8..aaa83ed1 100644 --- a/modelscope/preprocessors/nlp.py +++ b/modelscope/preprocessors/nlp.py @@ -5,9 +5,11 @@ import uuid from typing import Any, Dict, Iterable, Optional, Tuple, Union import numpy as np +import torch from transformers import AutoTokenizer, BertTokenizerFast from modelscope.metainfo import Models, Preprocessors +from modelscope.models.nlp.structbert import SbertTokenizerFast from modelscope.outputs import OutputKeys from modelscope.utils.config import ConfigFields from modelscope.utils.constant import Fields, InputFields, ModeKeys @@ -23,7 +25,7 @@ __all__ = [ 'SingleSentenceClassificationPreprocessor', 'FillMaskPreprocessor', 'ZeroShotClassificationPreprocessor', 'NERPreprocessor', 'TextErrorCorrectionPreprocessor', 'FaqQuestionAnsweringPreprocessor', - 'RelationExtractionPreprocessor' + 'SequenceLabelingPreprocessor', 'RelationExtractionPreprocessor' ] @@ -627,6 +629,112 @@ class NERPreprocessor(Preprocessor): } +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.sequence_labeling_tokenizer) +class SequenceLabelingPreprocessor(Preprocessor): + """The tokenizer preprocessor used in normal NER task. + + NOTE: This preprocessor may be merged with the TokenClassificationPreprocessor in the next edition. + """ + + def __init__(self, model_dir: str, *args, **kwargs): + """preprocess the data via the vocab.txt from the `model_dir` path + + Args: + model_dir (str): model path + """ + + super().__init__(*args, **kwargs) + + self.model_dir: str = model_dir + self.sequence_length = kwargs.pop('sequence_length', 512) + + if 'lstm' in model_dir or 'gcnn' in model_dir: + self.tokenizer = BertTokenizerFast.from_pretrained( + model_dir, use_fast=False) + elif 'structbert' in model_dir: + self.tokenizer = SbertTokenizerFast.from_pretrained( + model_dir, use_fast=False) + else: + self.tokenizer = AutoTokenizer.from_pretrained( + model_dir, use_fast=False) + self.is_split_into_words = self.tokenizer.init_kwargs.get( + 'is_split_into_words', False) + + @type_assert(object, str) + def __call__(self, data: str) -> Dict[str, Any]: + """process the raw input data + + Args: + data (str): a sentence + Example: + 'you are so handsome.' + + Returns: + Dict[str, Any]: the preprocessed data + """ + + # preprocess the data for the model input + text = data + if self.is_split_into_words: + input_ids = [] + label_mask = [] + offset_mapping = [] + for offset, token in enumerate(list(data)): + subtoken_ids = self.tokenizer.encode( + token, add_special_tokens=False) + if len(subtoken_ids) == 0: + subtoken_ids = [self.tokenizer.unk_token_id] + input_ids.extend(subtoken_ids) + label_mask.extend([1] + [0] * (len(subtoken_ids) - 1)) + offset_mapping.extend([(offset, offset + 1)] + + [(offset + 1, offset + 1)] + * (len(subtoken_ids) - 1)) + if len(input_ids) >= self.sequence_length - 2: + input_ids = input_ids[:self.sequence_length - 2] + label_mask = label_mask[:self.sequence_length - 2] + offset_mapping = offset_mapping[:self.sequence_length - 2] + input_ids = [self.tokenizer.cls_token_id + ] + input_ids + [self.tokenizer.sep_token_id] + label_mask = [0] + label_mask + [0] + attention_mask = [1] * len(input_ids) + else: + encodings = self.tokenizer( + text, + add_special_tokens=True, + padding=True, + truncation=True, + max_length=self.sequence_length, + return_offsets_mapping=True) + input_ids = encodings['input_ids'] + attention_mask = encodings['attention_mask'] + word_ids = encodings.word_ids() + label_mask = [] + offset_mapping = [] + for i in range(len(word_ids)): + if word_ids[i] is None: + label_mask.append(0) + elif word_ids[i] == word_ids[i - 1]: + label_mask.append(0) + offset_mapping[-1] = (offset_mapping[-1][0], + encodings['offset_mapping'][i][1]) + else: + label_mask.append(1) + offset_mapping.append(encodings['offset_mapping'][i]) + + if not self.is_transformer_based_model: + input_ids = input_ids[1:-1] + attention_mask = attention_mask[1:-1] + label_mask = label_mask[1:-1] + return { + 'text': text, + 'input_ids': input_ids, + 'attention_mask': attention_mask, + 'label_mask': label_mask, + 'offset_mapping': offset_mapping + } + + @PREPROCESSORS.register_module( Fields.nlp, module_name=Preprocessors.re_tokenizer) class RelationExtractionPreprocessor(Preprocessor): diff --git a/modelscope/utils/hub.py b/modelscope/utils/hub.py index f79097fe..cf114b5e 100644 --- a/modelscope/utils/hub.py +++ b/modelscope/utils/hub.py @@ -77,19 +77,26 @@ def auto_load(model: Union[str, List[str]]): def get_model_type(model_dir): """Get the model type from the configuration. - This method will try to get the 'model.type' or 'model.model_type' field from the configuration.json file. - If this file does not exist, the method will try to get the 'model_type' field from the config.json. + This method will try to get the model type from 'model.backbone.type', + 'model.type' or 'model.model_type' field in the configuration.json file. If + this file does not exist, the method will try to get the 'model_type' field + from the config.json. - @param model_dir: The local model dir to use. - @return: The model type string, returns None if nothing is found. + @param model_dir: The local model dir to use. @return: The model type + string, returns None if nothing is found. """ try: configuration_file = osp.join(model_dir, ModelFile.CONFIGURATION) config_file = osp.join(model_dir, 'config.json') if osp.isfile(configuration_file): cfg = Config.from_file(configuration_file) - return cfg.model.model_type if hasattr(cfg.model, 'model_type') and not hasattr(cfg.model, 'type') \ - else cfg.model.type + if hasattr(cfg.model, 'backbone'): + return cfg.model.backbone.type + elif hasattr(cfg.model, + 'model_type') and not hasattr(cfg.model, 'type'): + return cfg.model.model_type + else: + return cfg.model.type elif osp.isfile(config_file): cfg = Config.from_file(config_file) return cfg.model_type if hasattr(cfg, 'model_type') else None @@ -123,13 +130,24 @@ def parse_label_mapping(model_dir): if hasattr(config, ConfigFields.model) and hasattr( config[ConfigFields.model], 'label2id'): label2id = config[ConfigFields.model].label2id + elif hasattr(config, ConfigFields.model) and hasattr( + config[ConfigFields.model], 'id2label'): + id2label = config[ConfigFields.model].id2label + label2id = {label: id for id, label in id2label.items()} elif hasattr(config, ConfigFields.preprocessor) and hasattr( config[ConfigFields.preprocessor], 'label2id'): label2id = config[ConfigFields.preprocessor].label2id + elif hasattr(config, ConfigFields.preprocessor) and hasattr( + config[ConfigFields.preprocessor], 'id2label'): + id2label = config[ConfigFields.preprocessor].id2label + label2id = {label: id for id, label in id2label.items()} if label2id is None: config_path = os.path.join(model_dir, 'config.json') config = Config.from_file(config_path) if hasattr(config, 'label2id'): label2id = config.label2id + elif hasattr(config, 'id2label'): + id2label = config.id2label + label2id = {label: id for id, label in id2label.items()} return label2id diff --git a/tests/pipelines/test_part_of_speech.py b/tests/pipelines/test_part_of_speech.py new file mode 100644 index 00000000..25f4491c --- /dev/null +++ b/tests/pipelines/test_part_of_speech.py @@ -0,0 +1,55 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import shutil +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import TokenClassificationModel +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import TokenClassificationPipeline +from modelscope.preprocessors import TokenClassificationPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class PartOfSpeechTest(unittest.TestCase): + model_id = 'damo/nlp_structbert_part-of-speech_chinese-base' + sentence = '今天天气不错,适合出去游玩' + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + cache_path = snapshot_download(self.model_id) + tokenizer = TokenClassificationPreprocessor(cache_path) + model = TokenClassificationModel.from_pretrained(cache_path) + pipeline1 = TokenClassificationPipeline(model, preprocessor=tokenizer) + pipeline2 = pipeline( + Tasks.token_classification, model=model, preprocessor=tokenizer) + print(f'sentence: {self.sentence}\n' + f'pipeline1:{pipeline1(input=self.sentence)}') + print() + print(f'pipeline2: {pipeline2(input=self.sentence)}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + tokenizer = TokenClassificationPreprocessor(model.model_dir) + pipeline_ins = pipeline( + task=Tasks.token_classification, + model=model, + preprocessor=tokenizer) + print(pipeline_ins(input=self.sentence)) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.token_classification, model=self.model_id) + print(pipeline_ins(input=self.sentence)) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipeline_ins = pipeline(task=Tasks.token_classification) + print(pipeline_ins(input=self.sentence)) + + +if __name__ == '__main__': + unittest.main()