@@ -3,3 +3,4 @@ from .nli_model import *  # noqa F403
 from .palm_for_text_generation import *  # noqa F403
 from .sbert_for_sentence_similarity import *  # noqa F403
 from .sbert_for_token_classification import *  # noqa F403
+from .sentiment_classification_model import *  # noqa F403
@@ -0,0 +1,79 @@
+from typing import Any, Dict
+
+import numpy as np
+import torch
+from sofa import SbertModel
+from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel
+from torch import nn
+
+from modelscope.utils.constant import Tasks
+from ..base import Model
+from ..builder import MODELS
+
+__all__ = ['SbertForSentimentClassification']
+
+
+class SbertTextClassifier(SbertPreTrainedModel):
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.config = config
+        self.encoder = SbertModel(config, add_pooling_layer=True)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+    def forward(self, input_ids=None, token_type_ids=None):
+        outputs = self.encoder(
+            input_ids,
+            token_type_ids=token_type_ids,
+            return_dict=None,
+        )
+        # classify on the pooled [CLS] representation
+        pooled_output = outputs[1]
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+        return logits
+
+
+@MODELS.register_module(
+    Tasks.sentiment_classification,
+    module_name=r'sbert-sentiment-classification')
+class SbertForSentimentClassification(Model):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """initialize the sentiment classification model from the `model_dir` path.
+
+        Args:
+            model_dir (str): the model path.
+        """
+        super().__init__(model_dir, *args, **kwargs)
+        self.model_dir = model_dir
+        self.model = SbertTextClassifier.from_pretrained(
+            model_dir, num_labels=2)
+        self.model.eval()
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
+        """return the result by the model
+
+        Args:
+            input (Dict[str, Any]): the preprocessed data
+
+        Returns:
+            Dict[str, np.ndarray]: results, for example:
+                {
+                    'predictions': array([1]),  # label: 0-negative, 1-positive
+                    'probabilities': array([[0.11491239, 0.8850876]], dtype=float32),
+                    'logits': array([[-0.53860897, 1.5029076]], dtype=float32)  # raw scores
+                }
+        """
+        input_ids = torch.tensor(input['input_ids'], dtype=torch.long)
+        token_type_ids = torch.tensor(
+            input['token_type_ids'], dtype=torch.long)
+        with torch.no_grad():
+            logits = self.model(input_ids, token_type_ids)
+        probs = logits.softmax(-1).numpy()
+        pred = logits.argmax(-1).numpy()
+        logits = logits.numpy()
+        res = {'predictions': pred, 'probabilities': probs, 'logits': logits}
+        return res
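
A minimal usage sketch for the model class above; the checkpoint path and token ids below are placeholders (any directory holding SBERT weights loadable by `SbertTextClassifier.from_pretrained` would do), and the input dict mirrors what the preprocessor added later in this diff produces:

from modelscope.models.nlp import SbertForSentimentClassification

# '/path/to/checkpoint' is hypothetical; substitute a real model dir.
model = SbertForSentimentClassification('/path/to/checkpoint')

# forward() expects nested lists (a batch of one); the ids below are
# illustrative, not real vocabulary entries.
outputs = model.forward({
    'input_ids': [[101, 2769, 1599, 3614, 102]],
    'token_type_ids': [[0, 0, 0, 0, 0]],
})
print(outputs['predictions'])    # e.g. array([1]): 0-negative, 1-positive
print(outputs['probabilities'])  # softmax over the two labels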
@@ -22,8 +22,11 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting'),
     Tasks.nli: ('nlp_structbert_nli_chinese-base',
                 'damo/nlp_structbert_nli_chinese-base'),
-    Tasks.text_classification:
-    ('bert-sentiment-analysis', 'damo/bert-base-sst2'),
+    Tasks.sentiment_classification:
+    ('sbert-sentiment-classification',
+     'damo/nlp_structbert_sentiment-classification_chinese-base'),
+    Tasks.text_classification: ('bert-sentiment-analysis',
+                                'damo/bert-base-sst2'),
     Tasks.text_generation: ('palm2.0',
                             'damo/nlp_palm2.0_text-generation_chinese-base'),
     Tasks.image_captioning: ('ofa', 'damo/ofa_image-caption_coco_large_en'),
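
With this entry in DEFAULT_MODEL_FOR_PIPELINE, the task name alone is enough to construct a working pipeline; a short sketch (the checkpoint is fetched from the hub on first use):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# No model argument: resolves to
# 'damo/nlp_structbert_sentiment-classification_chinese-base'
# through the default mapping above.
pipe = pipeline(task=Tasks.sentiment_classification)
print(pipe(input='启动的时候很大声音'))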
@@ -1,5 +1,6 @@
 from .nli_pipeline import *  # noqa F403
 from .sentence_similarity_pipeline import *  # noqa F403
+from .sentiment_classification_pipeline import *  # noqa F403
 from .sequence_classification_pipeline import *  # noqa F403
 from .text_generation_pipeline import *  # noqa F403
 from .word_segmentation_pipeline import *  # noqa F403
@@ -0,0 +1,82 @@
+import os
+import uuid
+from typing import Any, Dict, Union
+
+import json
+import numpy as np
+
+from modelscope.models.nlp import SbertForSentimentClassification
+from modelscope.preprocessors import SentimentClassificationPreprocessor
+from modelscope.utils.constant import Tasks
+from ...models import Model
+from ..base import Pipeline
+from ..builder import PIPELINES
+
+__all__ = ['SentimentClassificationPipeline']
+
+
+@PIPELINES.register_module(
+    Tasks.sentiment_classification,
+    module_name=r'sbert-sentiment-classification')
+class SentimentClassificationPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[SbertForSentimentClassification, str],
+                 preprocessor: SentimentClassificationPreprocessor = None,
+                 **kwargs):
+        """use `model` and `preprocessor` to create an NLP sentiment classification pipeline for prediction
+
+        Args:
+            model (SbertForSentimentClassification or str): a model instance or a model id/path
+            preprocessor (SentimentClassificationPreprocessor): a preprocessor instance
+        """
+        assert isinstance(model, str) or isinstance(model, SbertForSentimentClassification), \
+            'model must be a single str or SbertForSentimentClassification'
+        sc_model = model if isinstance(
+            model,
+            SbertForSentimentClassification) else Model.from_pretrained(model)
+        if preprocessor is None:
+            preprocessor = SentimentClassificationPreprocessor(
+                sc_model.model_dir,
+                first_sequence='first_sequence',
+                second_sequence='second_sequence')
+        super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)
+        self.label_path = os.path.join(sc_model.model_dir,
+                                       'label_mapping.json')
+        with open(self.label_path) as f:
+            self.label_mapping = json.load(f)
+        # invert the name -> id mapping for lookup by predicted id
+        self.label_id_to_name = {
+            idx: name
+            for name, idx in self.label_mapping.items()
+        }
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+        """process the prediction results
+
+        Args:
+            inputs (Dict[str, Any]): the forward outputs, containing
+                'probabilities' and 'logits'
+
+        Returns:
+            Dict[str, str]: the prediction results
+        """
+        probs = inputs['probabilities']
+        logits = inputs['logits']
+        # rank label ids by descending probability
+        preds = np.argsort(-probs, axis=-1)[0]
+        b = 0  # only a batch of one is handled here
+        new_result = list()
+        for pred in preds:
+            new_result.append({
+                'pred': self.label_id_to_name[pred],
+                'prob': float(probs[b][pred]),
+                'logit': float(logits[b][pred])
+            })
+        return {
+            'id': inputs['id'][b] if 'id' in inputs else str(uuid.uuid4()),
+            'output': new_result,
+            'predictions': new_result[0]['pred'],
+            'probabilities': ','.join([str(t) for t in probs[b]]),
+            'logits': ','.join([str(t) for t in logits[b]])
+        }
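
The `label_id_to_name` inversion and the ranking in `postprocess` can be seen in isolation; a self-contained sketch where the label names and probabilities are made up for illustration (the released model's label_mapping.json may use different names):

import json

import numpy as np

# Assumed label_mapping.json layout (name -> id); names are illustrative.
label_mapping = json.loads('{"negative": 0, "positive": 1}')
label_id_to_name = {idx: name for name, idx in label_mapping.items()}

probs = np.array([[0.11491239, 0.8850876]])
preds = np.argsort(-probs, axis=-1)[0]  # label ids, best first
print([(label_id_to_name[p], float(probs[0][p])) for p in preds])
# [('positive', 0.8850876), ('negative', 0.11491239)]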
@@ -7,5 +7,4 @@ from .common import Compose
 from .image import LoadImage, load_image
 from .multi_model import OfaImageCaptionPreprocessor
 from .nlp import *  # noqa F403
-from .nlp import NLIPreprocessor, TextGenerationPreprocessor
 from .text_to_speech import *  # noqa F403
@@ -13,7 +13,7 @@ from .builder import PREPROCESSORS
 __all__ = [
     'Tokenize', 'SequenceClassificationPreprocessor',
     'TextGenerationPreprocessor', 'TokenClassifcationPreprocessor',
-    'NLIPreprocessor'
+    'NLIPreprocessor', 'SentimentClassificationPreprocessor'
 ]
@@ -65,7 +65,6 @@ class NLIPreprocessor(Preprocessor):
             sentence2 (str): a sentence
                 Example:
                     'you are so beautiful.'
-
         Returns:
             Dict[str, Any]: the preprocessed data
         """
@@ -102,6 +101,61 @@ class NLIPreprocessor(Preprocessor):
         return rst
 
 
+@PREPROCESSORS.register_module(
+    Fields.nlp, module_name=r'sbert-sentiment-classification')
+class SentimentClassificationPreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """preprocess the data via the vocab.txt from the `model_dir` path
+
+        Args:
+            model_dir (str): model path
+        """
+        super().__init__(*args, **kwargs)
+        from sofa import SbertTokenizer
+        self.model_dir: str = model_dir
+        self.first_sequence: str = kwargs.pop('first_sequence',
+                                              'first_sequence')
+        self.second_sequence = kwargs.pop('second_sequence', 'second_sequence')
+        self.sequence_length = kwargs.pop('sequence_length', 128)
+        self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)
+
+    @type_assert(object, str)
+    def __call__(self, data: str) -> Dict[str, Any]:
+        """process the raw input data
+
+        Args:
+            data (str): a sentence
+                Example:
+                    'you are so handsome.'
+
+        Returns:
+            Dict[str, Any]: the preprocessed data
+        """
+        new_data = {self.first_sequence: data}
+        # assemble the lists the model forward expects
+        rst = {
+            'id': [],
+            'input_ids': [],
+            'attention_mask': [],
+            'token_type_ids': []
+        }
+        max_seq_length = self.sequence_length
+        text_a = new_data[self.first_sequence]
+        text_b = new_data.get(self.second_sequence, None)
+        feature = self.tokenizer(
+            text_a,
+            text_b,
+            padding='max_length',
+            truncation=True,
+            max_length=max_seq_length)
+        rst['id'].append(new_data.get('id', str(uuid.uuid4())))
+        rst['input_ids'].append(feature['input_ids'])
+        rst['attention_mask'].append(feature['attention_mask'])
+        rst['token_type_ids'].append(feature['token_type_ids'])
+        return rst
+
+
 @PREPROCESSORS.register_module(
     Fields.nlp, module_name=r'bert-sequence-classification')
 class SequenceClassificationPreprocessor(Preprocessor):
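
A sketch of what the new preprocessor returns for a single sentence, assuming a local model directory that ships the files `SbertTokenizer.from_pretrained` needs (the path is hypothetical and the token ids are placeholders; real values depend on the shipped vocab.txt):

from modelscope.preprocessors import SentimentClassificationPreprocessor

preprocessor = SentimentClassificationPreprocessor(
    '/path/to/checkpoint', sequence_length=16)
features = preprocessor('你好')
# A dict of single-element lists, each sequence padded to length 16:
# {'id': ['<uuid4>'],
#  'input_ids': [[101, ..., 102, 0, 0, ...]],
#  'attention_mask': [[1, ..., 1, 0, 0, ...]],
#  'token_type_ids': [[0, 0, ...]]}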
@@ -33,6 +33,7 @@ class Tasks(object):
     # nlp tasks
     word_segmentation = 'word-segmentation'
     nli = 'nli'
+    sentiment_classification = 'sentiment-classification'
     sentiment_analysis = 'sentiment-analysis'
     sentence_similarity = 'sentence-similarity'
     text_classification = 'text-classification'
@@ -0,0 +1,54 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import unittest | |||||
from maas_hub.snapshot_download import snapshot_download | |||||
from modelscope.models import Model | |||||
from modelscope.models.nlp import SbertForSentimentClassification | |||||
from modelscope.pipelines import SentimentClassificationPipeline, pipeline | |||||
from modelscope.preprocessors import SentimentClassificationPreprocessor | |||||
from modelscope.utils.constant import Tasks | |||||
class SentimentClassificationTest(unittest.TestCase): | |||||
model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' | |||||
sentence1 = '启动的时候很大声音,然后就会听到1.2秒的卡察的声音,类似齿轮摩擦的声音' | |||||
def test_run_from_local(self): | |||||
cache_path = snapshot_download(self.model_id) | |||||
tokenizer = SentimentClassificationPreprocessor(cache_path) | |||||
model = SbertForSentimentClassification( | |||||
cache_path, tokenizer=tokenizer) | |||||
pipeline1 = SentimentClassificationPipeline( | |||||
model, preprocessor=tokenizer) | |||||
pipeline2 = pipeline( | |||||
Tasks.sentiment_classification, | |||||
model=model, | |||||
preprocessor=tokenizer) | |||||
print(f'sentence1: {self.sentence1}\n' | |||||
f'pipeline1:{pipeline1(input=self.sentence1)}') | |||||
print() | |||||
print(f'sentence1: {self.sentence1}\n' | |||||
f'pipeline1: {pipeline2(input=self.sentence1)}') | |||||
def test_run_with_model_from_modelhub(self): | |||||
model = Model.from_pretrained(self.model_id) | |||||
tokenizer = SentimentClassificationPreprocessor(model.model_dir) | |||||
pipeline_ins = pipeline( | |||||
task=Tasks.sentiment_classification, | |||||
model=model, | |||||
preprocessor=tokenizer) | |||||
print(pipeline_ins(input=self.sentence1)) | |||||
def test_run_with_model_name(self): | |||||
pipeline_ins = pipeline( | |||||
task=Tasks.sentiment_classification, model=self.model_id) | |||||
print(pipeline_ins(input=self.sentence1)) | |||||
def test_run_with_default_model(self): | |||||
pipeline_ins = pipeline(task=Tasks.sentiment_classification) | |||||
print(pipeline_ins(input=self.sentence1)) | |||||
if __name__ == '__main__': | |||||
unittest.main() |