@@ -71,6 +71,7 @@ class Models(object): | |||
space_T_en = 'space-T-en' | |||
space_T_cn = 'space-T-cn' | |||
tcrf = 'transformer-crf' | |||
token_classification_for_ner = 'token-classification-for-ner' | |||
tcrf_wseg = 'transformer-crf-for-word-segmentation' | |||
transformer_softmax = 'transformer-softmax' | |||
lcrf = 'lstm-crf' | |||
@@ -40,11 +40,13 @@ if TYPE_CHECKING: | |||
FeatureExtractionModel, | |||
InformationExtractionModel, | |||
LSTMCRFForNamedEntityRecognition, | |||
LSTMCRFForWordSegmentation, | |||
SequenceClassificationModel, | |||
SingleBackboneTaskModelBase, | |||
TaskModelForTextGeneration, | |||
TokenClassificationModel, | |||
TransformerCRFForNamedEntityRecognition, | |||
TransformerCRFForWordSegmentation, | |||
) | |||
from .veco import (VecoConfig, VecoForMaskedLM, | |||
VecoForSequenceClassification, | |||
@@ -14,6 +14,8 @@ from modelscope.utils.constant import Tasks | |||
@HEADS.register_module( | |||
Tasks.token_classification, module_name=Heads.token_classification) | |||
@HEADS.register_module( | |||
Tasks.named_entity_recognition, module_name=Heads.token_classification) | |||
@HEADS.register_module( | |||
Tasks.part_of_speech, module_name=Heads.token_classification) | |||
class TokenClassificationHead(TorchHead): | |||
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .mglm_for_text_summarization import mGlmForSummarization | |||
from .mglm_for_text_summarization import MGLMForTextSummarization | |||
else: | |||
_import_structure = { | |||
'mglm_for_text_summarization': ['MGLMForTextSummarization'], | |||
@@ -9,10 +9,8 @@ if TYPE_CHECKING: | |||
from .fill_mask import FillMaskModel | |||
from .nncrf_for_named_entity_recognition import ( | |||
LSTMCRFForNamedEntityRecognition, | |||
TransformerCRFForNamedEntityRecognition, | |||
) | |||
from .nncrf_for_word_segmentation import ( | |||
LSTMCRFForWordSegmentation, | |||
TransformerCRFForNamedEntityRecognition, | |||
TransformerCRFForWordSegmentation, | |||
) | |||
from .sequence_classification import SequenceClassificationModel | |||
@@ -26,11 +24,11 @@ else: | |||
'feature_extraction': ['FeatureExtractionModel'], | |||
'fill_mask': ['FillMaskModel'], | |||
'nncrf_for_named_entity_recognition': [ | |||
'LSTMCRFForNamedEntityRecognition', | |||
'LSTMCRFForWordSegmentation', | |||
'TransformerCRFForNamedEntityRecognition', | |||
'LSTMCRFForNamedEntityRecognition' | |||
'TransformerCRFForWordSegmentation', | |||
], | |||
'nncrf_for_word_segmentation': | |||
['TransformerCRFForWordSegmentation', 'LSTMCRFForWordSegmentation'], | |||
'sequence_classification': ['SequenceClassificationModel'], | |||
'task_model': ['SingleBackboneTaskModelBase'], | |||
'token_classification': ['TokenClassificationModel'], | |||
@@ -167,6 +167,14 @@ class TransformerCRFForNamedEntityRecognition( | |||
return model | |||
@MODELS.register_module(Tasks.word_segmentation, module_name=Models.tcrf_wseg) | |||
class TransformerCRFForWordSegmentation(TransformerCRFForNamedEntityRecognition | |||
): | |||
"""This model wraps the TransformerCRF model to register into model sets. | |||
""" | |||
pass | |||
@MODELS.register_module( | |||
Tasks.named_entity_recognition, module_name=Models.lcrf) | |||
class LSTMCRFForNamedEntityRecognition( | |||
@@ -185,6 +193,11 @@ class LSTMCRFForNamedEntityRecognition( | |||
return model | |||
@MODELS.register_module(Tasks.word_segmentation, module_name=Models.lcrf_wseg) | |||
class LSTMCRFForWordSegmentation(LSTMCRFForNamedEntityRecognition): | |||
pass | |||
class TransformerCRF(nn.Module): | |||
"""A transformer based model to NER tasks. | |||
@@ -1,639 +0,0 @@ | |||
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. All rights reserved. | |||
# The CRF implementation borrows mostly from AllenNLP CRF module (https://github.com/allenai/allennlp) | |||
# and pytorch-crf (https://github.com/kmkurn/pytorch-crf) with some modifications. | |||
import os | |||
from typing import Any, Dict, List, Optional | |||
import torch | |||
import torch.nn as nn | |||
from transformers import AutoConfig, AutoModel | |||
from modelscope.metainfo import Models | |||
from modelscope.models import TorchModel | |||
from modelscope.models.builder import MODELS | |||
from modelscope.outputs import TokenClassifierWithPredictionsOutput | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
__all__ = ['TransformerCRFForWordSegmentation', 'LSTMCRFForWordSegmentation'] | |||
class SequenceLabelingForWordSegmentation(TorchModel): | |||
def __init__(self, model_dir, *args, **kwargs): | |||
super().__init__(model_dir, *args, **kwargs) | |||
self.model = self.init_model(model_dir, *args, **kwargs) | |||
model_ckpt = os.path.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE) | |||
self.model.load_state_dict( | |||
torch.load(model_ckpt, map_location=torch.device('cpu'))) | |||
def init_model(self, model_dir, *args, **kwargs): | |||
raise NotImplementedError | |||
def train(self): | |||
return self.model.train() | |||
def eval(self): | |||
return self.model.eval() | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
input_tensor = { | |||
'input_ids': input['input_ids'], | |||
'attention_mask': input['attention_mask'], | |||
'label_mask': input['label_mask'], | |||
} | |||
output = { | |||
'offset_mapping': input['offset_mapping'], | |||
**input_tensor, | |||
**self.model(input_tensor) | |||
} | |||
return output | |||
def postprocess(self, input: Dict[str, Any], **kwargs): | |||
predicts = self.model.decode(input) | |||
offset_len = len(input['offset_mapping']) | |||
predictions = torch.narrow( | |||
predicts, 1, 0, | |||
offset_len) # index_select only move loc, not resize | |||
return TokenClassifierWithPredictionsOutput( | |||
loss=None, | |||
logits=None, | |||
hidden_states=None, | |||
attentions=None, | |||
offset_mapping=input['offset_mapping'], | |||
predictions=predictions, | |||
) | |||
@MODELS.register_module(Tasks.word_segmentation, module_name=Models.tcrf_wseg) | |||
class TransformerCRFForWordSegmentation(SequenceLabelingForWordSegmentation): | |||
"""This model wraps the TransformerCRF model to register into model sets. | |||
""" | |||
def init_model(self, model_dir, *args, **kwargs): | |||
self.config = AutoConfig.from_pretrained(model_dir) | |||
num_labels = self.config.num_labels | |||
model = TransformerCRF(model_dir, num_labels) | |||
return model | |||
@MODELS.register_module(Tasks.word_segmentation, module_name=Models.lcrf_wseg) | |||
class LSTMCRFForWordSegmentation(SequenceLabelingForWordSegmentation): | |||
"""This model wraps the LSTMCRF model to register into model sets. | |||
""" | |||
def init_model(self, model_dir, *args, **kwargs): | |||
self.config = AutoConfig.from_pretrained(model_dir) | |||
vocab_size = self.config.vocab_size | |||
embed_width = self.config.embed_width | |||
num_labels = self.config.num_labels | |||
lstm_hidden_size = self.config.lstm_hidden_size | |||
model = LSTMCRF(vocab_size, embed_width, num_labels, lstm_hidden_size) | |||
return model | |||
class TransformerCRF(nn.Module): | |||
"""A transformer based model to NER tasks. | |||
This model will use transformers' backbones as its backbone. | |||
""" | |||
def __init__(self, model_dir, num_labels, **kwargs): | |||
super(TransformerCRF, self).__init__() | |||
self.encoder = AutoModel.from_pretrained(model_dir) | |||
self.linear = nn.Linear(self.encoder.config.hidden_size, num_labels) | |||
self.crf = CRF(num_labels, batch_first=True) | |||
def forward(self, inputs): | |||
embed = self.encoder( | |||
inputs['input_ids'], attention_mask=inputs['attention_mask'])[0] | |||
logits = self.linear(embed) | |||
if 'label_mask' in inputs: | |||
mask = inputs['label_mask'] | |||
masked_lengths = mask.sum(-1).long() | |||
masked_logits = torch.zeros_like(logits) | |||
for i in range(len(mask)): | |||
masked_logits[ | |||
i, :masked_lengths[i], :] = logits[i].masked_select( | |||
mask[i].unsqueeze(-1)).view(masked_lengths[i], -1) | |||
logits = masked_logits | |||
outputs = {'logits': logits} | |||
return outputs | |||
def decode(self, inputs): | |||
seq_lens = inputs['label_mask'].sum(-1).long() | |||
mask = torch.arange( | |||
inputs['label_mask'].shape[1], | |||
device=seq_lens.device)[None, :] < seq_lens[:, None] | |||
predicts = self.crf.decode(inputs['logits'], mask=mask).squeeze(0) | |||
return predicts | |||
class LSTMCRF(nn.Module): | |||
""" | |||
A standard bilstm-crf model for fast prediction. | |||
""" | |||
def __init__(self, | |||
vocab_size, | |||
embed_width, | |||
num_labels, | |||
lstm_hidden_size=100, | |||
**kwargs): | |||
super(LSTMCRF, self).__init__() | |||
self.embedding = Embedding(vocab_size, embed_width) | |||
self.lstm = nn.LSTM( | |||
embed_width, | |||
lstm_hidden_size, | |||
num_layers=1, | |||
bidirectional=True, | |||
batch_first=True) | |||
self.ffn = nn.Linear(lstm_hidden_size * 2, num_labels) | |||
self.crf = CRF(num_labels, batch_first=True) | |||
def forward(self, inputs): | |||
embedding = self.embedding(inputs['input_ids']) | |||
lstm_output, _ = self.lstm(embedding) | |||
logits = self.ffn(lstm_output) | |||
if 'label_mask' in inputs: | |||
mask = inputs['label_mask'] | |||
masked_lengths = mask.sum(-1).long() | |||
masked_logits = torch.zeros_like(logits) | |||
for i in range(len(mask)): | |||
masked_logits[ | |||
i, :masked_lengths[i], :] = logits[i].masked_select( | |||
mask[i].unsqueeze(-1)).view(masked_lengths[i], -1) | |||
logits = masked_logits | |||
outputs = {'logits': logits} | |||
return outputs | |||
def decode(self, inputs): | |||
seq_lens = inputs['label_mask'].sum(-1).long() | |||
mask = torch.arange( | |||
inputs['label_mask'].shape[1], | |||
device=seq_lens.device)[None, :] < seq_lens[:, None] | |||
predicts = self.crf.decode(inputs['logits'], mask=mask).squeeze(0) | |||
outputs = {'predicts': predicts} | |||
return outputs | |||
class CRF(nn.Module): | |||
"""Conditional random field. | |||
This module implements a conditional random field [LMP01]_. The forward computation | |||
of this class computes the log likelihood of the given sequence of tags and | |||
emission score tensor. This class also has `~CRF.decode` method which finds | |||
the best tag sequence given an emission score tensor using `Viterbi algorithm`_. | |||
Args: | |||
num_tags: Number of tags. | |||
batch_first: Whether the first dimension corresponds to the size of a minibatch. | |||
Attributes: | |||
start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size | |||
``(num_tags,)``. | |||
end_transitions (`~torch.nn.Parameter`): End transition score tensor of size | |||
``(num_tags,)``. | |||
transitions (`~torch.nn.Parameter`): Transition score tensor of size | |||
``(num_tags, num_tags)``. | |||
.. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001). | |||
"Conditional random fields: Probabilistic models for segmenting and | |||
labeling sequence data". *Proc. 18th International Conf. on Machine | |||
Learning*. Morgan Kaufmann. pp. 282–289. | |||
.. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm | |||
""" | |||
def __init__(self, num_tags: int, batch_first: bool = False) -> None: | |||
if num_tags <= 0: | |||
raise ValueError(f'invalid number of tags: {num_tags}') | |||
super().__init__() | |||
self.num_tags = num_tags | |||
self.batch_first = batch_first | |||
self.start_transitions = nn.Parameter(torch.empty(num_tags)) | |||
self.end_transitions = nn.Parameter(torch.empty(num_tags)) | |||
self.transitions = nn.Parameter(torch.empty(num_tags, num_tags)) | |||
self.reset_parameters() | |||
def reset_parameters(self) -> None: | |||
"""Initialize the transition parameters. | |||
The parameters will be initialized randomly from a uniform distribution | |||
between -0.1 and 0.1. | |||
""" | |||
nn.init.uniform_(self.start_transitions, -0.1, 0.1) | |||
nn.init.uniform_(self.end_transitions, -0.1, 0.1) | |||
nn.init.uniform_(self.transitions, -0.1, 0.1) | |||
def __repr__(self) -> str: | |||
return f'{self.__class__.__name__}(num_tags={self.num_tags})' | |||
def forward(self, | |||
emissions: torch.Tensor, | |||
tags: torch.LongTensor, | |||
mask: Optional[torch.ByteTensor] = None, | |||
reduction: str = 'mean') -> torch.Tensor: | |||
"""Compute the conditional log likelihood of a sequence of tags given emission scores. | |||
Args: | |||
emissions (`~torch.Tensor`): Emission score tensor of size | |||
``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, | |||
``(batch_size, seq_length, num_tags)`` otherwise. | |||
tags (`~torch.LongTensor`): Sequence of tags tensor of size | |||
``(seq_length, batch_size)`` if ``batch_first`` is ``False``, | |||
``(batch_size, seq_length)`` otherwise. | |||
mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` | |||
if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. | |||
reduction: Specifies the reduction to apply to the output: | |||
``none|sum|mean|token_mean``. ``none``: no reduction will be applied. | |||
``sum``: the output will be summed over batches. ``mean``: the output will be | |||
averaged over batches. ``token_mean``: the output will be averaged over tokens. | |||
Returns: | |||
`~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if | |||
reduction is ``none``, ``()`` otherwise. | |||
""" | |||
if reduction not in ('none', 'sum', 'mean', 'token_mean'): | |||
raise ValueError(f'invalid reduction: {reduction}') | |||
if mask is None: | |||
mask = torch.ones_like(tags, dtype=torch.uint8, device=tags.device) | |||
if mask.dtype != torch.uint8: | |||
mask = mask.byte() | |||
self._validate(emissions, tags=tags, mask=mask) | |||
if self.batch_first: | |||
emissions = emissions.transpose(0, 1) | |||
tags = tags.transpose(0, 1) | |||
mask = mask.transpose(0, 1) | |||
# shape: (batch_size,) | |||
numerator = self._compute_score(emissions, tags, mask) | |||
# shape: (batch_size,) | |||
denominator = self._compute_normalizer(emissions, mask) | |||
# shape: (batch_size,) | |||
llh = numerator - denominator | |||
if reduction == 'none': | |||
return llh | |||
if reduction == 'sum': | |||
return llh.sum() | |||
if reduction == 'mean': | |||
return llh.mean() | |||
return llh.sum() / mask.float().sum() | |||
def decode(self, | |||
emissions: torch.Tensor, | |||
mask: Optional[torch.ByteTensor] = None, | |||
nbest: Optional[int] = None, | |||
pad_tag: Optional[int] = None) -> List[List[List[int]]]: | |||
"""Find the most likely tag sequence using Viterbi algorithm. | |||
Args: | |||
emissions (`~torch.Tensor`): Emission score tensor of size | |||
``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, | |||
``(batch_size, seq_length, num_tags)`` otherwise. | |||
mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` | |||
if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. | |||
nbest (`int`): Number of most probable paths for each sequence | |||
pad_tag (`int`): Tag at padded positions. Often input varies in length and | |||
the length will be padded to the maximum length in the batch. Tags at | |||
the padded positions will be assigned with a padding tag, i.e. `pad_tag` | |||
Returns: | |||
A PyTorch tensor of the best tag sequence for each batch of shape | |||
(nbest, batch_size, seq_length) | |||
""" | |||
if nbest is None: | |||
nbest = 1 | |||
if mask is None: | |||
mask = torch.ones( | |||
emissions.shape[:2], | |||
dtype=torch.uint8, | |||
device=emissions.device) | |||
if mask.dtype != torch.uint8: | |||
mask = mask.byte() | |||
self._validate(emissions, mask=mask) | |||
if self.batch_first: | |||
emissions = emissions.transpose(0, 1) | |||
mask = mask.transpose(0, 1) | |||
if nbest == 1: | |||
return self._viterbi_decode(emissions, mask, pad_tag).unsqueeze(0) | |||
return self._viterbi_decode_nbest(emissions, mask, nbest, pad_tag) | |||
def _validate(self, | |||
emissions: torch.Tensor, | |||
tags: Optional[torch.LongTensor] = None, | |||
mask: Optional[torch.ByteTensor] = None) -> None: | |||
if emissions.dim() != 3: | |||
raise ValueError( | |||
f'emissions must have dimension of 3, got {emissions.dim()}') | |||
if emissions.size(2) != self.num_tags: | |||
raise ValueError( | |||
f'expected last dimension of emissions is {self.num_tags}, ' | |||
f'got {emissions.size(2)}') | |||
if tags is not None: | |||
if emissions.shape[:2] != tags.shape: | |||
raise ValueError( | |||
'the first two dimensions of emissions and tags must match, ' | |||
f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}' | |||
) | |||
if mask is not None: | |||
if emissions.shape[:2] != mask.shape: | |||
raise ValueError( | |||
'the first two dimensions of emissions and mask must match, ' | |||
f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}' | |||
) | |||
no_empty_seq = not self.batch_first and mask[0].all() | |||
no_empty_seq_bf = self.batch_first and mask[:, 0].all() | |||
if not no_empty_seq and not no_empty_seq_bf: | |||
raise ValueError('mask of the first timestep must all be on') | |||
def _compute_score(self, emissions: torch.Tensor, tags: torch.LongTensor, | |||
mask: torch.ByteTensor) -> torch.Tensor: | |||
# emissions: (seq_length, batch_size, num_tags) | |||
# tags: (seq_length, batch_size) | |||
# mask: (seq_length, batch_size) | |||
seq_length, batch_size = tags.shape | |||
mask = mask.float() | |||
# Start transition score and first emission | |||
# shape: (batch_size,) | |||
score = self.start_transitions[tags[0]] | |||
score += emissions[0, torch.arange(batch_size), tags[0]] | |||
for i in range(1, seq_length): | |||
# Transition score to next tag, only added if next timestep is valid (mask == 1) | |||
# shape: (batch_size,) | |||
score += self.transitions[tags[i - 1], tags[i]] * mask[i] | |||
# Emission score for next tag, only added if next timestep is valid (mask == 1) | |||
# shape: (batch_size,) | |||
score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i] | |||
# End transition score | |||
# shape: (batch_size,) | |||
seq_ends = mask.long().sum(dim=0) - 1 | |||
# shape: (batch_size,) | |||
last_tags = tags[seq_ends, torch.arange(batch_size)] | |||
# shape: (batch_size,) | |||
score += self.end_transitions[last_tags] | |||
return score | |||
def _compute_normalizer(self, emissions: torch.Tensor, | |||
mask: torch.ByteTensor) -> torch.Tensor: | |||
# emissions: (seq_length, batch_size, num_tags) | |||
# mask: (seq_length, batch_size) | |||
seq_length = emissions.size(0) | |||
# Start transition score and first emission; score has size of | |||
# (batch_size, num_tags) where for each batch, the j-th column stores | |||
# the score that the first timestep has tag j | |||
# shape: (batch_size, num_tags) | |||
score = self.start_transitions + emissions[0] | |||
for i in range(1, seq_length): | |||
# Broadcast score for every possible next tag | |||
# shape: (batch_size, num_tags, 1) | |||
broadcast_score = score.unsqueeze(2) | |||
# Broadcast emission score for every possible current tag | |||
# shape: (batch_size, 1, num_tags) | |||
broadcast_emissions = emissions[i].unsqueeze(1) | |||
# Compute the score tensor of size (batch_size, num_tags, num_tags) where | |||
# for each sample, entry at row i and column j stores the sum of scores of all | |||
# possible tag sequences so far that end with transitioning from tag i to tag j | |||
# and emitting | |||
# shape: (batch_size, num_tags, num_tags) | |||
next_score = broadcast_score + self.transitions + broadcast_emissions | |||
# Sum over all possible current tags, but we're in score space, so a sum | |||
# becomes a log-sum-exp: for each sample, entry i stores the sum of scores of | |||
# all possible tag sequences so far, that end in tag i | |||
# shape: (batch_size, num_tags) | |||
next_score = torch.logsumexp(next_score, dim=1) | |||
# Set score to the next score if this timestep is valid (mask == 1) | |||
# shape: (batch_size, num_tags) | |||
score = torch.where(mask[i].unsqueeze(1), next_score, score) | |||
# End transition score | |||
# shape: (batch_size, num_tags) | |||
score += self.end_transitions | |||
# Sum (log-sum-exp) over all possible tags | |||
# shape: (batch_size,) | |||
return torch.logsumexp(score, dim=1) | |||
def _viterbi_decode(self, | |||
emissions: torch.FloatTensor, | |||
mask: torch.ByteTensor, | |||
pad_tag: Optional[int] = None) -> List[List[int]]: | |||
# emissions: (seq_length, batch_size, num_tags) | |||
# mask: (seq_length, batch_size) | |||
# return: (batch_size, seq_length) | |||
if pad_tag is None: | |||
pad_tag = 0 | |||
device = emissions.device | |||
seq_length, batch_size = mask.shape | |||
# Start transition and first emission | |||
# shape: (batch_size, num_tags) | |||
score = self.start_transitions + emissions[0] | |||
history_idx = torch.zeros((seq_length, batch_size, self.num_tags), | |||
dtype=torch.long, | |||
device=device) | |||
oor_idx = torch.zeros((batch_size, self.num_tags), | |||
dtype=torch.long, | |||
device=device) | |||
oor_tag = torch.full((seq_length, batch_size), | |||
pad_tag, | |||
dtype=torch.long, | |||
device=device) | |||
# - score is a tensor of size (batch_size, num_tags) where for every batch, | |||
# value at column j stores the score of the best tag sequence so far that ends | |||
# with tag j | |||
# - history_idx saves where the best tags candidate transitioned from; this is used | |||
# when we trace back the best tag sequence | |||
# - oor_idx saves the best tags candidate transitioned from at the positions | |||
# where mask is 0, i.e. out of range (oor) | |||
# Viterbi algorithm recursive case: we compute the score of the best tag sequence | |||
# for every possible next tag | |||
for i in range(1, seq_length): | |||
# Broadcast viterbi score for every possible next tag | |||
# shape: (batch_size, num_tags, 1) | |||
broadcast_score = score.unsqueeze(2) | |||
# Broadcast emission score for every possible current tag | |||
# shape: (batch_size, 1, num_tags) | |||
broadcast_emission = emissions[i].unsqueeze(1) | |||
# Compute the score tensor of size (batch_size, num_tags, num_tags) where | |||
# for each sample, entry at row i and column j stores the score of the best | |||
# tag sequence so far that ends with transitioning from tag i to tag j and emitting | |||
# shape: (batch_size, num_tags, num_tags) | |||
next_score = broadcast_score + self.transitions + broadcast_emission | |||
# Find the maximum score over all possible current tag | |||
# shape: (batch_size, num_tags) | |||
next_score, indices = next_score.max(dim=1) | |||
# Set score to the next score if this timestep is valid (mask == 1) | |||
# and save the index that produces the next score | |||
# shape: (batch_size, num_tags) | |||
score = torch.where(mask[i].unsqueeze(-1), next_score, score) | |||
indices = torch.where(mask[i].unsqueeze(-1), indices, oor_idx) | |||
history_idx[i - 1] = indices | |||
# End transition score | |||
# shape: (batch_size, num_tags) | |||
end_score = score + self.end_transitions | |||
_, end_tag = end_score.max(dim=1) | |||
# shape: (batch_size,) | |||
seq_ends = mask.long().sum(dim=0) - 1 | |||
# insert the best tag at each sequence end (last position with mask == 1) | |||
history_idx = history_idx.transpose(1, 0).contiguous() | |||
history_idx.scatter_( | |||
1, | |||
seq_ends.view(-1, 1, 1).expand(-1, 1, self.num_tags), | |||
end_tag.view(-1, 1, 1).expand(-1, 1, self.num_tags)) | |||
history_idx = history_idx.transpose(1, 0).contiguous() | |||
# The most probable path for each sequence | |||
best_tags_arr = torch.zeros((seq_length, batch_size), | |||
dtype=torch.long, | |||
device=device) | |||
best_tags = torch.zeros(batch_size, 1, dtype=torch.long, device=device) | |||
for idx in range(seq_length - 1, -1, -1): | |||
best_tags = torch.gather(history_idx[idx], 1, best_tags) | |||
best_tags_arr[idx] = best_tags.data.view(batch_size) | |||
return torch.where(mask, best_tags_arr, oor_tag).transpose(0, 1) | |||
def _viterbi_decode_nbest( | |||
self, | |||
emissions: torch.FloatTensor, | |||
mask: torch.ByteTensor, | |||
nbest: int, | |||
pad_tag: Optional[int] = None) -> List[List[List[int]]]: | |||
# emissions: (seq_length, batch_size, num_tags) | |||
# mask: (seq_length, batch_size) | |||
# return: (nbest, batch_size, seq_length) | |||
if pad_tag is None: | |||
pad_tag = 0 | |||
device = emissions.device | |||
seq_length, batch_size = mask.shape | |||
# Start transition and first emission | |||
# shape: (batch_size, num_tags) | |||
score = self.start_transitions + emissions[0] | |||
history_idx = torch.zeros( | |||
(seq_length, batch_size, self.num_tags, nbest), | |||
dtype=torch.long, | |||
device=device) | |||
oor_idx = torch.zeros((batch_size, self.num_tags, nbest), | |||
dtype=torch.long, | |||
device=device) | |||
oor_tag = torch.full((seq_length, batch_size, nbest), | |||
pad_tag, | |||
dtype=torch.long, | |||
device=device) | |||
# + score is a tensor of size (batch_size, num_tags) where for every batch, | |||
# value at column j stores the score of the best tag sequence so far that ends | |||
# with tag j | |||
# + history_idx saves where the best tags candidate transitioned from; this is used | |||
# when we trace back the best tag sequence | |||
# - oor_idx saves the best tags candidate transitioned from at the positions | |||
# where mask is 0, i.e. out of range (oor) | |||
# Viterbi algorithm recursive case: we compute the score of the best tag sequence | |||
# for every possible next tag | |||
for i in range(1, seq_length): | |||
if i == 1: | |||
broadcast_score = score.unsqueeze(-1) | |||
broadcast_emission = emissions[i].unsqueeze(1) | |||
# shape: (batch_size, num_tags, num_tags) | |||
next_score = broadcast_score + self.transitions + broadcast_emission | |||
else: | |||
broadcast_score = score.unsqueeze(-1) | |||
broadcast_emission = emissions[i].unsqueeze(1).unsqueeze(2) | |||
# shape: (batch_size, num_tags, nbest, num_tags) | |||
next_score = broadcast_score + self.transitions.unsqueeze( | |||
1) + broadcast_emission | |||
# Find the top `nbest` maximum score over all possible current tag | |||
# shape: (batch_size, nbest, num_tags) | |||
next_score, indices = next_score.view(batch_size, -1, | |||
self.num_tags).topk( | |||
nbest, dim=1) | |||
if i == 1: | |||
score = score.unsqueeze(-1).expand(-1, -1, nbest) | |||
indices = indices * nbest | |||
# convert to shape: (batch_size, num_tags, nbest) | |||
next_score = next_score.transpose(2, 1) | |||
indices = indices.transpose(2, 1) | |||
# Set score to the next score if this timestep is valid (mask == 1) | |||
# and save the index that produces the next score | |||
# shape: (batch_size, num_tags, nbest) | |||
score = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1), | |||
next_score, score) | |||
indices = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1), indices, | |||
oor_idx) | |||
history_idx[i - 1] = indices | |||
# End transition score shape: (batch_size, num_tags, nbest) | |||
end_score = score + self.end_transitions.unsqueeze(-1) | |||
_, end_tag = end_score.view(batch_size, -1).topk(nbest, dim=1) | |||
# shape: (batch_size,) | |||
seq_ends = mask.long().sum(dim=0) - 1 | |||
# insert the best tag at each sequence end (last position with mask == 1) | |||
history_idx = history_idx.transpose(1, 0).contiguous() | |||
history_idx.scatter_( | |||
1, | |||
seq_ends.view(-1, 1, 1, 1).expand(-1, 1, self.num_tags, nbest), | |||
end_tag.view(-1, 1, 1, nbest).expand(-1, 1, self.num_tags, nbest)) | |||
history_idx = history_idx.transpose(1, 0).contiguous() | |||
# The most probable path for each sequence | |||
best_tags_arr = torch.zeros((seq_length, batch_size, nbest), | |||
dtype=torch.long, | |||
device=device) | |||
best_tags = torch.arange(nbest, dtype=torch.long, device=device) \ | |||
.view(1, -1).expand(batch_size, -1) | |||
for idx in range(seq_length - 1, -1, -1): | |||
best_tags = torch.gather(history_idx[idx].view(batch_size, -1), 1, | |||
best_tags) | |||
best_tags_arr[idx] = best_tags.data.view(batch_size, -1) // nbest | |||
return torch.where(mask.unsqueeze(-1), best_tags_arr, | |||
oor_tag).permute(2, 1, 0) | |||
class Embedding(nn.Module): | |||
def __init__(self, vocab_size, embed_width): | |||
super(Embedding, self).__init__() | |||
self.embedding = nn.Embedding(vocab_size, embed_width) | |||
def forward(self, input_ids): | |||
return self.embedding(input_ids) |
@@ -4,7 +4,7 @@ from typing import Any, Dict | |||
import numpy as np | |||
import torch | |||
from modelscope.metainfo import TaskModels | |||
from modelscope.metainfo import Models, TaskModels | |||
from modelscope.models.builder import MODELS | |||
from modelscope.models.nlp.task_models.task_model import \ | |||
SingleBackboneTaskModelBase | |||
@@ -21,6 +21,9 @@ __all__ = ['TokenClassificationModel'] | |||
Tasks.token_classification, module_name=TaskModels.token_classification) | |||
@MODELS.register_module( | |||
Tasks.part_of_speech, module_name=TaskModels.token_classification) | |||
@MODELS.register_module( | |||
Tasks.named_entity_recognition, | |||
module_name=Models.token_classification_for_ner) | |||
class TokenClassificationModel(SingleBackboneTaskModelBase): | |||
def __init__(self, model_dir: str, *args, **kwargs): | |||
@@ -59,6 +62,9 @@ class TokenClassificationModel(SingleBackboneTaskModelBase): | |||
if labels in input: | |||
loss = self.compute_loss(outputs, labels) | |||
# apply label mask to logits | |||
logits = logits[input['label_mask']].unsqueeze(0) | |||
return TokenClassifierOutput( | |||
loss=loss, | |||
logits=logits, | |||
@@ -490,7 +490,10 @@ TASK_OUTPUTS = { | |||
# word segmentation result for single sample | |||
# { | |||
# "output": "今天 天气 不错 , 适合 出去 游玩" | |||
# "output": ["今天", "天气", "不错", ",", "适合", "出去", "游玩"] | |||
# } | |||
# { | |||
# 'output': ['รถ', 'คัน', 'เก่า', 'ก็', 'ยัง', 'เก็บ', 'เอา'] | |||
# } | |||
Tasks.word_segmentation: [OutputKeys.OUTPUT], | |||
@@ -29,11 +29,9 @@ if TYPE_CHECKING: | |||
from .text2text_generation_pipeline import Text2TextGenerationPipeline | |||
from .token_classification_pipeline import TokenClassificationPipeline | |||
from .translation_pipeline import TranslationPipeline | |||
from .word_segmentation_pipeline import WordSegmentationPipeline | |||
from .word_segmentation_pipeline import WordSegmentationPipeline, WordSegmentationThaiPipeline | |||
from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline | |||
from .mglm_text_summarization_pipeline import MGLMTextSummarizationPipeline | |||
from .multilingual_word_segmentation_pipeline import MultilingualWordSegmentationPipeline, \ | |||
WordSegmentationThaiPipeline | |||
else: | |||
_import_structure = { | |||
@@ -69,14 +67,11 @@ else: | |||
'translation_pipeline': ['TranslationPipeline'], | |||
'translation_quality_estimation_pipeline': | |||
['TranslationQualityEstimationPipeline'], | |||
'word_segmentation_pipeline': ['WordSegmentationPipeline'], | |||
'word_segmentation_pipeline': | |||
['WordSegmentationPipeline', 'WordSegmentationThaiPipeline'], | |||
'zero_shot_classification_pipeline': | |||
['ZeroShotClassificationPipeline'], | |||
'mglm_text_summarization_pipeline': ['MGLMTextSummarizationPipeline'], | |||
'multilingual_word_segmentation_pipeline': [ | |||
'MultilingualWordSegmentationPipeline', | |||
'WordSegmentationThaiPipeline' | |||
], | |||
} | |||
import sys | |||
@@ -1,125 +0,0 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Any, Dict, Optional, Union | |||
import torch | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.models import Model | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import (Preprocessor, | |||
TokenClassificationPreprocessor, | |||
WordSegmentationPreprocessorThai) | |||
from modelscope.utils.constant import Tasks | |||
__all__ = [ | |||
'MultilingualWordSegmentationPipeline', 'WordSegmentationThaiPipeline' | |||
] | |||
@PIPELINES.register_module( | |||
Tasks.word_segmentation, | |||
module_name=Pipelines.multilingual_word_segmentation) | |||
class MultilingualWordSegmentationPipeline(Pipeline): | |||
def __init__(self, | |||
model: Union[Model, str], | |||
preprocessor: Optional[Preprocessor] = None, | |||
**kwargs): | |||
"""Use `model` and `preprocessor` to create a nlp word segmentation pipeline for prediction | |||
Args: | |||
model (str or Model): Supply either a local model dir which supported word segmentation task, or a | |||
model id from the model hub, or a torch model instance. | |||
preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for | |||
the model if supplied. | |||
sequence_length: Max sequence length in the user's custom scenario. 512 will be used as a default value. | |||
To view other examples plese check the tests/pipelines/test_multilingual_word_segmentation.py. | |||
""" | |||
model = model if isinstance(model, | |||
Model) else Model.from_pretrained(model) | |||
if preprocessor is None: | |||
preprocessor = TokenClassificationPreprocessor( | |||
model.model_dir, | |||
sequence_length=kwargs.pop('sequence_length', 512)) | |||
model.eval() | |||
super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
self.tokenizer = preprocessor.tokenizer | |||
self.config = model.config | |||
assert len(self.config.id2label) > 0 | |||
self.id2label = self.config.id2label | |||
def forward(self, inputs: Dict[str, Any], | |||
**forward_params) -> Dict[str, Any]: | |||
text = inputs.pop(OutputKeys.TEXT) | |||
with torch.no_grad(): | |||
return { | |||
**super().forward(inputs, **forward_params), OutputKeys.TEXT: | |||
text | |||
} | |||
def postprocess(self, inputs: Dict[str, Any], | |||
**postprocess_params) -> Dict[str, str]: | |||
text = inputs['text'] | |||
offset_mapping = [x.cpu().tolist() for x in inputs['offset_mapping']] | |||
labels = [ | |||
self.id2label[x] | |||
for x in inputs['predictions'].squeeze(0).cpu().numpy() | |||
] | |||
entities = [] | |||
entity = {} | |||
for label, offsets in zip(labels, offset_mapping): | |||
if label[0] in 'BS': | |||
if entity: | |||
entity['span'] = text[entity['start']:entity['end']] | |||
entities.append(entity) | |||
entity = { | |||
'type': label[2:], | |||
'start': offsets[0], | |||
'end': offsets[1] | |||
} | |||
if label[0] in 'IES': | |||
if entity: | |||
entity['end'] = offsets[1] | |||
if label[0] in 'ES': | |||
if entity: | |||
entity['span'] = text[entity['start']:entity['end']] | |||
entities.append(entity) | |||
entity = {} | |||
if entity: | |||
entity['span'] = text[entity['start']:entity['end']] | |||
entities.append(entity) | |||
word_segments = [entity['span'] for entity in entities] | |||
outputs = {OutputKeys.OUTPUT: word_segments, OutputKeys.LABELS: []} | |||
return outputs | |||
@PIPELINES.register_module( | |||
Tasks.word_segmentation, module_name=Pipelines.word_segmentation_thai) | |||
class WordSegmentationThaiPipeline(MultilingualWordSegmentationPipeline): | |||
def __init__(self, | |||
model: Union[Model, str], | |||
preprocessor: Optional[Preprocessor] = None, | |||
**kwargs): | |||
model = model if isinstance(model, | |||
Model) else Model.from_pretrained(model) | |||
if preprocessor is None: | |||
preprocessor = WordSegmentationPreprocessorThai( | |||
model.model_dir, | |||
sequence_length=kwargs.pop('sequence_length', 512)) | |||
super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
def postprocess(self, inputs: Dict[str, Any], | |||
**postprocess_params) -> Dict[str, str]: | |||
outputs = super().postprocess(inputs, **postprocess_params) | |||
word_segments = outputs[OutputKeys.OUTPUT] | |||
word_segments = [seg.replace(' ', '') for seg in word_segments] | |||
return {OutputKeys.OUTPUT: word_segments, OutputKeys.LABELS: []} |
@@ -9,6 +9,7 @@ from modelscope.models import Model | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.pipelines.nlp import TokenClassificationPipeline | |||
from modelscope.preprocessors import (NERPreprocessorThai, NERPreprocessorViet, | |||
Preprocessor, | |||
TokenClassificationPreprocessor) | |||
@@ -25,7 +26,7 @@ __all__ = [ | |||
@PIPELINES.register_module( | |||
Tasks.named_entity_recognition, | |||
module_name=Pipelines.named_entity_recognition) | |||
class NamedEntityRecognitionPipeline(Pipeline): | |||
class NamedEntityRecognitionPipeline(TokenClassificationPipeline): | |||
def __init__(self, | |||
model: Union[Model, str], | |||
@@ -55,97 +56,12 @@ class NamedEntityRecognitionPipeline(Pipeline): | |||
if preprocessor is None: | |||
preprocessor = TokenClassificationPreprocessor( | |||
model.model_dir, | |||
sequence_length=kwargs.pop('sequence_length', 512)) | |||
sequence_length=kwargs.pop('sequence_length', 128)) | |||
model.eval() | |||
super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
self.tokenizer = preprocessor.tokenizer | |||
self.config = model.config | |||
assert len(self.config.id2label) > 0 | |||
self.id2label = self.config.id2label | |||
def forward(self, inputs: Dict[str, Any], | |||
**forward_params) -> Dict[str, Any]: | |||
text = inputs.pop(OutputKeys.TEXT) | |||
with torch.no_grad(): | |||
return { | |||
**self.model(**inputs, **forward_params), OutputKeys.TEXT: text | |||
} | |||
def postprocess(self, inputs: Dict[str, Any], | |||
**postprocess_params) -> Dict[str, str]: | |||
"""process the prediction results | |||
Args: | |||
inputs (Dict[str, Any]): should be tensors from model | |||
Returns: | |||
Dict[str, str]: the prediction results | |||
""" | |||
text = inputs['text'] | |||
if OutputKeys.PREDICTIONS not in inputs: | |||
logits = inputs[OutputKeys.LOGITS] | |||
predictions = torch.argmax(logits[0], dim=-1) | |||
else: | |||
predictions = inputs[OutputKeys.PREDICTIONS].squeeze( | |||
0).cpu().numpy() | |||
predictions = torch_nested_numpify(torch_nested_detach(predictions)) | |||
offset_mapping = [x.cpu().tolist() for x in inputs['offset_mapping']] | |||
labels = [self.id2label[x] for x in predictions] | |||
if len(labels) > len(offset_mapping): | |||
labels = labels[1:-1] | |||
chunks = [] | |||
chunk = {} | |||
for label, offsets in zip(labels, offset_mapping): | |||
if label[0] in 'BS': | |||
if chunk: | |||
chunk['span'] = text[chunk['start']:chunk['end']] | |||
chunks.append(chunk) | |||
chunk = { | |||
'type': label[2:], | |||
'start': offsets[0], | |||
'end': offsets[1] | |||
} | |||
if label[0] in 'I': | |||
if not chunk: | |||
chunk = { | |||
'type': label[2:], | |||
'start': offsets[0], | |||
'end': offsets[1] | |||
} | |||
if label[0] in 'E': | |||
if not chunk: | |||
chunk = { | |||
'type': label[2:], | |||
'start': offsets[0], | |||
'end': offsets[1] | |||
} | |||
if label[0] in 'IES': | |||
if chunk: | |||
chunk['end'] = offsets[1] | |||
if label[0] in 'ES': | |||
if chunk: | |||
chunk['span'] = text[chunk['start']:chunk['end']] | |||
chunks.append(chunk) | |||
chunk = {} | |||
if chunk: | |||
chunk['span'] = text[chunk['start']:chunk['end']] | |||
chunks.append(chunk) | |||
# for cws outputs | |||
if len(chunks) > 0 and chunks[0]['type'] == 'cws': | |||
spans = [ | |||
chunk['span'] for chunk in chunks if chunk['span'].strip() | |||
] | |||
seg_result = ' '.join(spans) | |||
outputs = {OutputKeys.OUTPUT: seg_result} | |||
# for ner outputs | |||
else: | |||
outputs = {OutputKeys.OUTPUT: chunks} | |||
return outputs | |||
self.id2label = kwargs.get('id2label') | |||
if self.id2label is None and hasattr(self.preprocessor, 'id2label'): | |||
self.id2label = self.preprocessor.id2label | |||
@PIPELINES.register_module( | |||
@@ -117,7 +117,12 @@ class TextClassificationPipeline(Pipeline): | |||
probs = np.take_along_axis(probs, top_indices, axis=-1).tolist() | |||
def map_to_label(id): | |||
return self.id2label[id] | |||
if id in self.id2label: | |||
return self.id2label[id] | |||
elif str(id) in self.id2label: | |||
return self.id2label[str(id)] | |||
else: | |||
raise Exception('id not found in id2label') | |||
v_func = np.vectorize(map_to_label) | |||
return { | |||
@@ -64,6 +64,31 @@ class TokenClassificationPipeline(Pipeline): | |||
**postprocess_params) -> Dict[str, str]: | |||
"""process the prediction results | |||
Args: | |||
inputs (Dict[str, Any]): should be tensors from model | |||
Returns: | |||
Dict[str, str]: the prediction results | |||
""" | |||
chunks = self._chunk_process(inputs, **postprocess_params) | |||
# for cws outputs | |||
if len(chunks) > 0 and chunks[0]['type'].lower() == 'cws': | |||
spans = [ | |||
chunk['span'] for chunk in chunks if chunk['span'].strip() | |||
] | |||
seg_result = [span for span in spans] | |||
outputs = {OutputKeys.OUTPUT: seg_result} | |||
# for ner outputs | |||
else: | |||
outputs = {OutputKeys.OUTPUT: chunks} | |||
return outputs | |||
def _chunk_process(self, inputs: Dict[str, Any], | |||
**postprocess_params) -> Dict[str, str]: | |||
"""process the prediction results and output as chunks | |||
Args: | |||
inputs (Dict[str, Any]): should be tensors from model | |||
@@ -71,7 +96,7 @@ class TokenClassificationPipeline(Pipeline): | |||
Dict[str, str]: the prediction results | |||
""" | |||
text = inputs['text'] | |||
if not hasattr(inputs, 'predictions'): | |||
if OutputKeys.PREDICTIONS not in inputs: | |||
logits = inputs[OutputKeys.LOGITS] | |||
predictions = torch.argmax(logits[0], dim=-1) | |||
else: | |||
@@ -123,15 +148,4 @@ class TokenClassificationPipeline(Pipeline): | |||
chunk['span'] = text[chunk['start']:chunk['end']] | |||
chunks.append(chunk) | |||
# for cws outputs | |||
if len(chunks) > 0 and chunks[0]['type'] == 'cws': | |||
spans = [ | |||
chunk['span'] for chunk in chunks if chunk['span'].strip() | |||
] | |||
seg_result = ' '.join(spans) | |||
outputs = {OutputKeys.OUTPUT: seg_result} | |||
# for ner outputs | |||
else: | |||
outputs = {OutputKeys.OUTPUT: chunks} | |||
return outputs | |||
return chunks |
@@ -9,18 +9,20 @@ from modelscope.models import Model | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.pipelines.nlp import TokenClassificationPipeline | |||
from modelscope.preprocessors import (Preprocessor, | |||
TokenClassificationPreprocessor) | |||
TokenClassificationPreprocessor, | |||
WordSegmentationPreprocessorThai) | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.tensor_utils import (torch_nested_detach, | |||
torch_nested_numpify) | |||
__all__ = ['WordSegmentationPipeline'] | |||
__all__ = ['WordSegmentationPipeline', 'WordSegmentationThaiPipeline'] | |||
@PIPELINES.register_module( | |||
Tasks.word_segmentation, module_name=Pipelines.word_segmentation) | |||
class WordSegmentationPipeline(Pipeline): | |||
class WordSegmentationPipeline(TokenClassificationPipeline): | |||
def __init__(self, | |||
model: Union[Model, str], | |||
@@ -58,89 +60,38 @@ class WordSegmentationPipeline(Pipeline): | |||
self.id2label = kwargs.get('id2label') | |||
if self.id2label is None and hasattr(self.preprocessor, 'id2label'): | |||
self.id2label = self.preprocessor.id2label | |||
assert self.id2label is not None, 'Cannot convert id to the original label, please pass in the mapping ' \ | |||
'as a parameter or make sure the preprocessor has the attribute.' | |||
def forward(self, inputs: Dict[str, Any], | |||
**forward_params) -> Dict[str, Any]: | |||
text = inputs.pop(OutputKeys.TEXT) | |||
with torch.no_grad(): | |||
return { | |||
**self.model(**inputs, **forward_params), OutputKeys.TEXT: text | |||
} | |||
@PIPELINES.register_module( | |||
Tasks.word_segmentation, | |||
module_name=Pipelines.multilingual_word_segmentation) | |||
class MultilingualWordSegmentationPipeline(WordSegmentationPipeline): | |||
def postprocess(self, inputs: Dict[str, Any], | |||
**postprocess_params) -> Dict[str, str]: | |||
"""process the prediction results | |||
chunks = self._chunk_process(inputs, **postprocess_params) | |||
word_segments = [entity['span'] for entity in chunks] | |||
return {OutputKeys.OUTPUT: word_segments} | |||
Args: | |||
inputs (Dict[str, Any]): should be tensors from model | |||
Returns: | |||
Dict[str, str]: the prediction results | |||
""" | |||
text = inputs['text'] | |||
if not hasattr(inputs, 'predictions'): | |||
logits = inputs[OutputKeys.LOGITS] | |||
predictions = torch.argmax(logits[0], dim=-1) | |||
else: | |||
predictions = inputs[OutputKeys.PREDICTIONS].squeeze( | |||
0).cpu().numpy() | |||
predictions = torch_nested_numpify(torch_nested_detach(predictions)) | |||
offset_mapping = [x.cpu().tolist() for x in inputs['offset_mapping']] | |||
labels = [self.id2label[x] for x in predictions] | |||
if len(labels) > len(offset_mapping): | |||
labels = labels[1:-1] | |||
chunks = [] | |||
chunk = {} | |||
for label, offsets in zip(labels, offset_mapping): | |||
if label[0] in 'BS': | |||
if chunk: | |||
chunk['span'] = text[chunk['start']:chunk['end']] | |||
chunks.append(chunk) | |||
chunk = { | |||
'type': label[2:], | |||
'start': offsets[0], | |||
'end': offsets[1] | |||
} | |||
if label[0] in 'I': | |||
if not chunk: | |||
chunk = { | |||
'type': label[2:], | |||
'start': offsets[0], | |||
'end': offsets[1] | |||
} | |||
if label[0] in 'E': | |||
if not chunk: | |||
chunk = { | |||
'type': label[2:], | |||
'start': offsets[0], | |||
'end': offsets[1] | |||
} | |||
if label[0] in 'IES': | |||
if chunk: | |||
chunk['end'] = offsets[1] | |||
if label[0] in 'ES': | |||
if chunk: | |||
chunk['span'] = text[chunk['start']:chunk['end']] | |||
chunks.append(chunk) | |||
chunk = {} | |||
if chunk: | |||
chunk['span'] = text[chunk['start']:chunk['end']] | |||
chunks.append(chunk) | |||
# for cws outputs | |||
if len(chunks) > 0 and chunks[0]['type'] == 'cws': | |||
spans = [ | |||
chunk['span'] for chunk in chunks if chunk['span'].strip() | |||
] | |||
seg_result = ' '.join(spans) | |||
outputs = {OutputKeys.OUTPUT: seg_result} | |||
# for ner outputs | |||
else: | |||
outputs = {OutputKeys.OUTPUT: chunks} | |||
return outputs | |||
@PIPELINES.register_module( | |||
Tasks.word_segmentation, module_name=Pipelines.word_segmentation_thai) | |||
class WordSegmentationThaiPipeline(MultilingualWordSegmentationPipeline): | |||
def __init__(self, | |||
model: Union[Model, str], | |||
preprocessor: Optional[Preprocessor] = None, | |||
**kwargs): | |||
model = model if isinstance(model, | |||
Model) else Model.from_pretrained(model) | |||
if preprocessor is None: | |||
preprocessor = WordSegmentationPreprocessorThai( | |||
model.model_dir, | |||
sequence_length=kwargs.pop('sequence_length', 512)) | |||
super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
def postprocess(self, inputs: Dict[str, Any], | |||
**postprocess_params) -> Dict[str, str]: | |||
chunks = self._chunk_process(inputs, **postprocess_params) | |||
word_segments = [entity['span'].replace(' ', '') for entity in chunks] | |||
return {OutputKeys.OUTPUT: word_segments} |
@@ -154,4 +154,6 @@ def parse_label_mapping(model_dir): | |||
elif hasattr(config, 'id2label'): | |||
id2label = config.id2label | |||
label2id = {label: id for id, label in id2label.items()} | |||
if label2id is not None: | |||
label2id = {label: int(id) for label, id in label2id.items()} | |||
return label2id |
@@ -0,0 +1,45 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import unittest | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.models import Model | |||
from modelscope.models.nlp import SbertForSequenceClassification | |||
from modelscope.pipelines import pipeline | |||
from modelscope.pipelines.nlp import TextClassificationPipeline | |||
from modelscope.preprocessors import SequenceClassificationPreprocessor | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||
from modelscope.utils.regress_test_utils import IgnoreKeyFn, MsRegressTool | |||
from modelscope.utils.test_utils import test_level | |||
class AddrSimilarityTest(unittest.TestCase, DemoCompatibilityCheck): | |||
sentence1 = '阿里巴巴西溪园区' | |||
sentence2 = '文一西路阿里巴巴' | |||
model_id = 'damo/nlp_structbert_address-matching_chinese_base' | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_with_model_from_modelhub(self): | |||
model = Model.from_pretrained(self.model_id) | |||
preprocessor = SequenceClassificationPreprocessor(model.model_dir) | |||
pipeline_ins = pipeline( | |||
task=Tasks.text_classification, | |||
model=model, | |||
preprocessor=preprocessor) | |||
print(pipeline_ins(input=(self.sentence1, self.sentence2))) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_with_model_name(self): | |||
pipeline_ins = pipeline( | |||
task=Tasks.text_classification, model=self.model_id) | |||
print(pipeline_ins(input=(self.sentence1, self.sentence2))) | |||
@unittest.skip('demo compatibility test is only enabled on a needed-basis') | |||
def test_demo_compatibility(self): | |||
self.compatibility_check() | |||
if __name__ == '__main__': | |||
unittest.main() |
@@ -23,9 +23,11 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
chinese_model_id = 'damo/nlp_raner_named-entity-recognition_chinese-large-generic' | |||
tcrf_model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news' | |||
lcrf_model_id = 'damo/nlp_lstm_named-entity-recognition_chinese-news' | |||
addr_model_id = 'damo/nlp_structbert_address-parsing_chinese_base' | |||
sentence = '这与温岭市新河镇的一个神秘的传说有关。' | |||
sentence_en = 'pizza shovel' | |||
sentence_zh = '他 继 续 与 貝 塞 斯 達 遊 戲 工 作 室 在 接 下 来 辐 射 4 游 戏 。' | |||
addr = '浙江省杭州市余杭区文一西路969号亲橙里' | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_tcrf_by_direct_model_download(self): | |||
@@ -71,6 +73,23 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
preprocessor=tokenizer) | |||
print(pipeline_ins(input=self.sentence)) | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_addrst_with_model_from_modelhub(self): | |||
model = Model.from_pretrained( | |||
'damo/nlp_structbert_address-parsing_chinese_base') | |||
tokenizer = TokenClassificationPreprocessor(model.model_dir) | |||
pipeline_ins = pipeline( | |||
task=Tasks.named_entity_recognition, | |||
model=model, | |||
preprocessor=tokenizer) | |||
print(pipeline_ins(input=self.addr)) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_addrst_with_model_name(self): | |||
pipeline_ins = pipeline( | |||
task=Tasks.named_entity_recognition, model=self.addr_model_id) | |||
print(pipeline_ins(input=self.addr)) | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_lcrf_with_model_from_modelhub(self): | |||
model = Model.from_pretrained(self.lcrf_model_id) | |||