MaaS: add FAQ question-answering model. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9797053 (target branch: master)
@@ -49,8 +49,8 @@ def handle_http_response(response, logger, cookies, model_id):
    except HTTPError:
        if cookies is None:  # code in [403] and
            logger.error(
                f'Authentication token does not exist, failed to access model {model_id} which may not exist or may be private. \
                Please login first.')
                f'Authentication token does not exist, failed to access model {model_id} which may not exist or may be \
                private. Please login first.')
        raise
@@ -138,6 +138,7 @@ class Pipelines(object):
    dialog_state_tracking = 'dialog-state-tracking'
    zero_shot_classification = 'zero-shot-classification'
    text_error_correction = 'text-error-correction'
    faq_question_answering = 'faq-question-answering'
    conversational_text_to_sql = 'conversational-text-to-sql'

    # audio tasks
@@ -220,6 +221,7 @@ class Preprocessors(object):
    text_error_correction = 'text-error-correction'
    word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor'
    fill_mask = 'fill-mask'
    faq_question_answering_preprocessor = 'faq-question-answering-preprocessor'
    conversational_text_to_sql = 'conversational-text-to-sql'

    # audio preprocessor
@@ -6,7 +6,8 @@ import torch
def point_form(boxes: torch.Tensor) -> torch.Tensor:
    """Convert prior_boxes to (x_min, y_min, x_max, y_max) representation for comparison to point form ground truth data.
    """Convert prior_boxes to (x_min, y_min, x_max, y_max) representation for comparison to point form \
    ground truth data.

    Args:
        boxes: center-size default boxes from priorbox layers.
@@ -21,6 +21,7 @@ if TYPE_CHECKING:
    from .task_models.task_model import SingleBackboneTaskModelBase
    from .bart_for_text_error_correction import BartForTextErrorCorrection
    from .gpt3 import GPT3ForTextGeneration
    from .sbert_for_faq_question_answering import SbertForFaqQuestionAnswering
else:
    _import_structure = {
@@ -44,6 +45,7 @@ else:
        'task_model': ['SingleBackboneTaskModelBase'],
        'bart_for_text_error_correction': ['BartForTextErrorCorrection'],
        'gpt3': ['GPT3ForTextGeneration'],
        'sbert_for_faq_question_answering': ['SbertForFaqQuestionAnswering']
    }

    import sys
@@ -0,0 +1,249 @@
import math
import os
from collections import namedtuple
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from modelscope.metainfo import Models
from modelscope.models.builder import MODELS
from modelscope.models.nlp.structbert import SbertConfig, SbertModel
from modelscope.models.nlp.task_models.task_model import BaseTaskModel
from modelscope.utils.config import Config, ConfigFields
from modelscope.utils.constant import ModelFile, Tasks

__all__ = ['SbertForFaqQuestionAnswering']


class SbertForFaqQuestionAnsweringBase(BaseTaskModel):
    """Base class for FAQ models."""

    def __init__(self, model_dir, *args, **kwargs):
        super(SbertForFaqQuestionAnsweringBase,
              self).__init__(model_dir, *args, **kwargs)
        backbone_cfg = SbertConfig.from_pretrained(model_dir)
        self.bert = SbertModel(backbone_cfg)

        model_config = Config.from_file(
            os.path.join(model_dir,
                         ModelFile.CONFIGURATION)).get(ConfigFields.model, {})
        metric = model_config.get('metric', 'cosine')
        pooling_method = model_config.get('pooling', 'avg')

        Arg = namedtuple('args', [
            'metrics', 'proj_hidden_size', 'hidden_size', 'dropout', 'pooling'
        ])
        args = Arg(
            metrics=metric,
            proj_hidden_size=self.bert.config.hidden_size,
            hidden_size=self.bert.config.hidden_size,
            dropout=0.0,
            pooling=pooling_method)
        self.metrics_layer = MetricsLayer(args)
        self.pooling = PoolingLayer(args)

    def _get_onehot_labels(self, labels, support_size, num_cls):
        labels_ = labels.view(support_size, 1)
        target_oh = torch.zeros(support_size, num_cls).to(labels)
        target_oh.scatter_(dim=1, index=labels_, value=1)
        return target_oh.view(support_size, num_cls).float()

    def forward_sentence_embedding(self, inputs: Dict[str, Tensor]):
        input_ids = inputs['input_ids']
        input_mask = inputs['attention_mask']
        if not isinstance(input_ids, Tensor):
            input_ids = torch.IntTensor(input_ids)
        if not isinstance(input_mask, Tensor):
            input_mask = torch.IntTensor(input_mask)
        rst = self.bert(input_ids, input_mask)
        last_hidden_states = rst.last_hidden_state
        if len(input_mask.shape) == 2:
            input_mask = input_mask.unsqueeze(-1)
        pooled_representation = self.pooling(last_hidden_states, input_mask)
        return pooled_representation


@MODELS.register_module(
    Tasks.faq_question_answering, module_name=Models.structbert)
class SbertForFaqQuestionAnswering(SbertForFaqQuestionAnsweringBase):
    _backbone_prefix = ''

    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        assert not self.training
        query = input['query']
        support = input['support']
        if isinstance(query, list):
            query = torch.stack(query)
        if isinstance(support, list):
            support = torch.stack(support)
        n_query = query.shape[0]
        n_support = support.shape[0]
        query_mask = torch.ne(query, 0).view([n_query, -1])
        support_mask = torch.ne(support, 0).view([n_support, -1])

        support_labels = input['support_labels']
        num_cls = torch.max(support_labels) + 1
        onehot_labels = self._get_onehot_labels(support_labels, n_support,
                                                num_cls)

        input_ids = torch.cat([query, support])
        input_mask = torch.cat([query_mask, support_mask], dim=0)
        pooled_representation = self.forward_sentence_embedding({
            'input_ids': input_ids,
            'attention_mask': input_mask
        })
        z_query = pooled_representation[:n_query]
        z_support = pooled_representation[n_query:]
        cls_n_support = torch.sum(onehot_labels, dim=-2) + 1e-5
        protos = torch.matmul(onehot_labels.transpose(0, 1),
                              z_support) / cls_n_support.unsqueeze(-1)
        scores = self.metrics_layer(z_query, protos).view([n_query, num_cls])
        if self.metrics_layer.name == 'relation':
            scores = torch.sigmoid(scores)
        return {'scores': scores}


activations = {
    'relu': F.relu,
    'tanh': torch.tanh,
    'linear': lambda x: x,
}

activation_coeffs = {
    'relu': math.sqrt(2),
    'tanh': 5 / 3,
    'linear': 1.,
}


class LinearProjection(nn.Module):

    def __init__(self,
                 in_features,
                 out_features,
                 activation='linear',
                 bias=True):
        super().__init__()
        self.activation = activations[activation]
        activation_coeff = activation_coeffs[activation]
        linear = nn.Linear(in_features, out_features, bias=bias)
        nn.init.normal_(
            linear.weight, std=math.sqrt(1. / in_features) * activation_coeff)
        if bias:
            nn.init.zeros_(linear.bias)
        self.model = nn.utils.weight_norm(linear)

    def forward(self, x):
        return self.activation(self.model(x))


class RelationModule(nn.Module):

    def __init__(self, args):
        super(RelationModule, self).__init__()
        input_size = args.proj_hidden_size * 4
        self.prediction = torch.nn.Sequential(
            LinearProjection(
                input_size, args.proj_hidden_size * 4, activation='relu'),
            nn.Dropout(args.dropout),
            LinearProjection(args.proj_hidden_size * 4, 1))

    def forward(self, query, protos):
        n_cls = protos.shape[0]
        n_query = query.shape[0]
        protos = protos.unsqueeze(0).repeat(n_query, 1, 1)
        query = query.unsqueeze(1).repeat(1, n_cls, 1)
        input_feat = torch.cat(
            [query, protos, (protos - query).abs(), query * protos], dim=-1)
        dists = self.prediction(input_feat)  # [n_query, n_cls, 1]
        return dists.squeeze(-1)


class MetricsLayer(nn.Module):

    def __init__(self, args):
        super(MetricsLayer, self).__init__()
        self.args = args
        assert args.metrics in ('relation', 'cosine')
        if args.metrics == 'relation':
            self.relation_net = RelationModule(args)

    @property
    def name(self):
        return self.args.metrics

    def forward(self, query, protos):
""" query : [bsz, n_query, dim] | |||
support : [bsz, n_query, n_cls, dim] | [bsz, n_cls, dim] | |||
""" | |||
        if self.args.metrics == 'cosine':
            supervised_dists = self.cosine_similarity(query, protos)
            if self.training:
                supervised_dists *= 5
        elif self.args.metrics in ('relation', ):
            supervised_dists = self.relation_net(query, protos)
        else:
            raise NotImplementedError
        return supervised_dists
    def cosine_similarity(self, x, y):
        # x: [n_query, dim]
        # y: [n_cls, dim]
        n_query = x.shape[0]
        n_cls = y.shape[0]
        dim = x.shape[-1]
        x = x.unsqueeze(1).expand([n_query, n_cls, dim])
        y = y.unsqueeze(0).expand([n_query, n_cls, dim])
        return F.cosine_similarity(x, y, -1)


class AveragePooling(nn.Module):

    def forward(self, x, mask, dim=1):
        return torch.sum(x * mask.float(), dim=dim) / torch.sum(
            mask.float(), dim=dim)


class AttnPooling(nn.Module):

    def __init__(self, input_size, hidden_size=None, output_size=None):
        super().__init__()
        self.input_proj = nn.Sequential(
            LinearProjection(input_size, hidden_size), nn.Tanh(),
            LinearProjection(hidden_size, 1, bias=False))
        self.output_proj = LinearProjection(
            input_size, output_size) if output_size else lambda x: x

    def forward(self, x, mask):
        score = self.input_proj(x)
        score = score * mask.float() + -1e4 * (1. - mask.float())
        score = F.softmax(score, dim=1)
        features = self.output_proj(x)
        return torch.matmul(score.transpose(1, 2), features).squeeze(1)


class PoolingLayer(nn.Module):

    def __init__(self, args):
        super(PoolingLayer, self).__init__()
        if args.pooling == 'attn':
            self.pooling = AttnPooling(args.proj_hidden_size,
                                       args.proj_hidden_size,
                                       args.proj_hidden_size)
        elif args.pooling == 'avg':
            self.pooling = AveragePooling()
        else:
            raise NotImplementedError(args.pooling)

    def forward(self, x, mask):
        return self.pooling(x, mask)
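
Review note: the forward pass above implements few-shot prototypical scoring — per-class prototypes are built from the support embeddings via a one-hot matmul, and each query embedding is then scored against every prototype. A minimal standalone sketch of that computation (plain PyTorch, illustrative sizes and random tensors, not part of this PR):

import torch
import torch.nn.functional as F

# Hypothetical sizes: 2 queries, 4 support sentences, 3 classes, dim 8.
z_query = torch.randn(2, 8)
z_support = torch.randn(4, 8)
support_labels = torch.tensor([0, 0, 1, 2])

num_cls = int(support_labels.max()) + 1
onehot = F.one_hot(support_labels, num_cls).float()           # [4, 3]
cls_n_support = onehot.sum(dim=0) + 1e-5                      # support count per class
protos = onehot.T @ z_support / cls_n_support.unsqueeze(-1)   # [3, 8] class prototypes

# Cosine scoring of each query against each prototype ('cosine' metric branch).
scores = F.cosine_similarity(
    z_query.unsqueeze(1), protos.unsqueeze(0), dim=-1)        # [2, 3]
print(scores)

The 1e-5 in the denominator mirrors the model code and guards against empty classes; the 'relation' branch instead feeds [query, proto, |proto - query|, query * proto] features through the RelationModule MLP.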
@@ -7,6 +7,7 @@ class OutputKeys(object):
    LOSS = 'loss'
    LOGITS = 'logits'
    SCORES = 'scores'
    SCORE = 'score'
    LABEL = 'label'
    LABELS = 'labels'
    INPUT_IDS = 'input_ids'
@@ -504,6 +505,16 @@ TASK_OUTPUTS = {
    # }
    Tasks.visual_entailment: [OutputKeys.SCORES, OutputKeys.LABELS],

    # {
    #     'output': [
    #         [{'label': '6527856', 'score': 0.9942756295204163}, {'label': '1000012000', 'score': 0.0379515215754509},
    #          {'label': '13421097', 'score': 2.2825044965202324e-08}],
    #         [{'label': '1000012000', 'score': 0.910681426525116}, {'label': '6527856', 'score': 0.0005046309670433402},
    #          {'label': '13421097', 'score': 2.75914817393641e-06}],
    #         [{'label': '1000012000', 'score': 0.910681426525116}, {'label': '6527856', 'score': 0.0005046309670433402},
    #          {'label': '13421097', 'score': 2.75914817393641e-06}]]
    # }
    Tasks.faq_question_answering: [OutputKeys.OUTPUT],

    # image person reid result for single sample
    # {
    #     "img_embedding": np.array with shape [1, D],
@@ -129,6 +129,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                            'damo/cv_convnextTiny_ocr-recognition-general_damo'),
    Tasks.skin_retouching: (Pipelines.skin_retouching,
                            'damo/cv_unet_skin-retouching'),
    Tasks.faq_question_answering:
    (Pipelines.faq_question_answering,
     'damo/nlp_structbert_faq-question-answering_chinese-base'),
    Tasks.crowd_counting: (Pipelines.crowd_counting,
                           'damo/cv_hrnet_crowd-counting_dcanet'),
    Tasks.video_single_object_tracking:

@@ -218,7 +221,6 @@ def pipeline(task: str = None,
        f'model should be either None, str, List[str], Model, or List[Model], but got {type(model)}'
    model = normalize_model_input(model, model_revision)
    if pipeline_name is None:
        # get default pipeline for this task
        if isinstance(model, str) \
@@ -20,6 +20,7 @@ if TYPE_CHECKING:
    from .summarization_pipeline import SummarizationPipeline
    from .text_classification_pipeline import TextClassificationPipeline
    from .text_error_correction_pipeline import TextErrorCorrectionPipeline
    from .faq_question_answering_pipeline import FaqQuestionAnsweringPipeline
else:
    _import_structure = {

@@ -44,7 +45,8 @@ else:
        'translation_pipeline': ['TranslationPipeline'],
        'summarization_pipeline': ['SummarizationPipeline'],
        'text_classification_pipeline': ['TextClassificationPipeline'],
        'text_error_correction_pipeline': ['TextErrorCorrectionPipeline']
        'text_error_correction_pipeline': ['TextErrorCorrectionPipeline'],
        'faq_question_answering_pipeline': ['FaqQuestionAnsweringPipeline']
    }

    import sys
@@ -0,0 +1,76 @@
from typing import Any, Dict, Union

import torch

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.models.nlp import SbertForFaqQuestionAnswering
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import FaqQuestionAnsweringPreprocessor
from modelscope.utils.constant import Tasks

__all__ = ['FaqQuestionAnsweringPipeline']


@PIPELINES.register_module(
    Tasks.faq_question_answering, module_name=Pipelines.faq_question_answering)
class FaqQuestionAnsweringPipeline(Pipeline):

    def __init__(self,
                 model: Union[str, SbertForFaqQuestionAnswering],
                 preprocessor: FaqQuestionAnsweringPreprocessor = None,
                 **kwargs):
        model = model if isinstance(
            model,
            SbertForFaqQuestionAnswering) else Model.from_pretrained(model)
        model.eval()
        if preprocessor is None:
            preprocessor = FaqQuestionAnsweringPreprocessor(
                model.model_dir, **kwargs)
        self.preprocessor = preprocessor
        super(FaqQuestionAnsweringPipeline, self).__init__(
            model=model, preprocessor=preprocessor, **kwargs)

    def _sanitize_parameters(self, **pipeline_parameters):
        return pipeline_parameters, pipeline_parameters, pipeline_parameters

    def get_sentence_embedding(self, inputs, max_len=None):
        inputs = self.preprocessor.batch_encode(inputs, max_length=max_len)
        sentence_vecs = self.model.forward_sentence_embedding(inputs)
        sentence_vecs = sentence_vecs.detach().tolist()
        return sentence_vecs
    def forward(self, inputs: Union[list, Dict[str, Any]],
                **forward_params) -> Dict[str, Any]:
        with torch.no_grad():
            return self.model(inputs)

    def postprocess(self, inputs: Union[list, Dict[str, Any]],
                    **postprocess_params) -> Dict[str, Any]:
        scores = inputs['scores']
        labels = []
        for item in scores:
            tmplabels = [
                self.preprocessor.get_label(label_id)
                for label_id in range(len(item))
            ]
            labels.append(tmplabels)

        predictions = []
        for tmp_scores, tmp_labels in zip(scores.tolist(), labels):
            prediction = []
            for score, label in zip(tmp_scores, tmp_labels):
                prediction.append({
                    OutputKeys.LABEL: label,
                    OutputKeys.SCORE: score
                })
            predictions.append(
                list(
                    sorted(
                        prediction,
                        key=lambda d: d[OutputKeys.SCORE],
                        reverse=True)))
        return {OutputKeys.OUTPUT: predictions}
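
Review note: end-to-end, the new pipeline takes a query_set plus a labeled support_set and returns, per query, a score-sorted list of {'label', 'score'} dicts. A usage sketch (the model id and the input format are taken from the tests at the end of this review; the labels are the sample intent ids used there):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

p = pipeline(
    task=Tasks.faq_question_answering,
    model='damo/nlp_structbert_faq-question-answering_chinese-base')

result = p({
    'query_set': ['如何使用优惠券'],
    'support_set': [
        {'text': '怎么使用优惠券', 'label': '6527856'},
        {'text': '这个可以一起领吗', 'label': '1000012000'},
    ]
})
# result['output'][0] is a list like [{'label': ..., 'score': ...}, ...],
# sorted by descending score.
print(result['output'][0])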
@@ -21,7 +21,8 @@ if TYPE_CHECKING:
        SingleSentenceClassificationPreprocessor,
        PairSentenceClassificationPreprocessor,
        FillMaskPreprocessor, ZeroShotClassificationPreprocessor,
        NERPreprocessor, TextErrorCorrectionPreprocessor)
        NERPreprocessor, TextErrorCorrectionPreprocessor,
        FaqQuestionAnsweringPreprocessor)
    from .space import (DialogIntentPredictionPreprocessor,
                        DialogModelingPreprocessor,
                        DialogStateTrackingPreprocessor)

@@ -48,7 +49,8 @@ else:
            'SingleSentenceClassificationPreprocessor',
            'PairSentenceClassificationPreprocessor', 'FillMaskPreprocessor',
            'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
            'TextErrorCorrectionPreprocessor'
            'TextErrorCorrectionPreprocessor',
            'FaqQuestionAnsweringPreprocessor'
        ],
        'space': [
            'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor',
@@ -5,10 +5,12 @@ import uuid
from typing import Any, Dict, Iterable, Optional, Tuple, Union

import numpy as np
import torch
from transformers import AutoTokenizer

from modelscope.metainfo import Models, Preprocessors
from modelscope.outputs import OutputKeys
from modelscope.utils.config import ConfigFields
from modelscope.utils.constant import Fields, InputFields, ModeKeys
from modelscope.utils.hub import get_model_type, parse_label_mapping
from modelscope.utils.type_assert import type_assert

@@ -21,7 +23,7 @@ __all__ = [
    'PairSentenceClassificationPreprocessor',
    'SingleSentenceClassificationPreprocessor', 'FillMaskPreprocessor',
    'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
    'TextErrorCorrectionPreprocessor'
    'TextErrorCorrectionPreprocessor', 'FaqQuestionAnsweringPreprocessor'
]

@@ -645,3 +647,86 @@ class TextErrorCorrectionPreprocessor(Preprocessor):
        sample = dict()
        sample['net_input'] = {'src_tokens': inputs, 'src_lengths': lengths}
        return sample


@PREPROCESSORS.register_module(
    Fields.nlp, module_name=Preprocessors.faq_question_answering_preprocessor)
class FaqQuestionAnsweringPreprocessor(Preprocessor):

    def __init__(self, model_dir: str, *args, **kwargs):
        super(FaqQuestionAnsweringPreprocessor, self).__init__(
            model_dir, pair=False, mode=ModeKeys.INFERENCE, **kwargs)
        import os
        from transformers import BertTokenizer
        from modelscope.utils.config import Config
        from modelscope.utils.constant import ModelFile
        self.tokenizer = BertTokenizer.from_pretrained(model_dir)
        preprocessor_config = Config.from_file(
            os.path.join(model_dir, ModelFile.CONFIGURATION)).get(
                ConfigFields.preprocessor, {})
        self.MAX_LEN = preprocessor_config.get('max_seq_length', 50)
        self.label_dict = None

    def pad(self, samples, max_len):
        result = []
        for sample in samples:
            pad_len = max_len - len(sample[:max_len])
            result.append(sample[:max_len]
                          + [self.tokenizer.pad_token_id] * pad_len)
        return result

    def set_label_dict(self, label_dict):
        self.label_dict = label_dict

    def get_label(self, label_id):
        assert self.label_dict is not None and label_id < len(self.label_dict)
        return self.label_dict[label_id]

    def encode_plus(self, text):
        return [
            self.tokenizer.cls_token_id
        ] + self.tokenizer.convert_tokens_to_ids(
            self.tokenizer.tokenize(text)) + [self.tokenizer.sep_token_id]

    @type_assert(object, Dict)
    def __call__(self, data: Dict[str, Any],
                 **preprocessor_param) -> Dict[str, Any]:
        TMP_MAX_LEN = preprocessor_param.get('max_seq_length', self.MAX_LEN)
        queryset = data['query_set']
        if not isinstance(queryset, list):
            queryset = [queryset]
        supportset = data['support_set']
        supportset = sorted(supportset, key=lambda d: d['label'])

        queryset_tokenized = [self.encode_plus(text) for text in queryset]
        supportset_tokenized = [
            self.encode_plus(item['text']) for item in supportset
        ]

        max_len = max(
            [len(seq) for seq in queryset_tokenized + supportset_tokenized])
        max_len = min(TMP_MAX_LEN, max_len)
        queryset_padded = self.pad(queryset_tokenized, max_len)
        supportset_padded = self.pad(supportset_tokenized, max_len)

        supportset_labels_ori = [item['label'] for item in supportset]
        label_dict = []
        for label in supportset_labels_ori:
            if label not in label_dict:
                label_dict.append(label)
        self.set_label_dict(label_dict)
        supportset_labels_ids = [
            label_dict.index(label) for label in supportset_labels_ori
        ]
        return {
            'query': queryset_padded,
            'support': supportset_padded,
            'support_labels': supportset_labels_ids
        }

    def batch_encode(self, sentence_list: list, max_length=None):
        if not max_length:
            max_length = self.MAX_LEN
        return self.tokenizer.batch_encode_plus(
            sentence_list, padding=True, max_length=max_length)
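
Review note: a small illustration of the __call__ contract above. The support set is sorted by label before tokenization, so label_dict ordering is lexicographic over the label strings; get_label later maps score indices back through this list. Token ids below are made up:

# Hypothetical input: one query, two support intents.
data = {
    'query_set': ['在哪里领券'],
    'support_set': [
        {'text': '怎么使用优惠券', 'label': '6527856'},
        {'text': '付款时送的优惠券哪里领', 'label': '1000012000'},
    ]
}
# After sorting, '1000012000' precedes '6527856', so
# label_dict == ['1000012000', '6527856'] and support_labels == [0, 1].
# The preprocessor then returns roughly:
# {
#     'query':   [[101, ..., 102, 0, 0]],            # [CLS] ... [SEP], padded to batch max_len
#     'support': [[101, ..., 102], [101, ..., 102]],
#     'support_labels': [0, 1],
# }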
@@ -95,6 +95,7 @@ class NLPTasks(object):
    zero_shot_classification = 'zero-shot-classification'
    backbone = 'backbone'
    text_error_correction = 'text-error-correction'
    faq_question_answering = 'faq-question-answering'
    conversational_text_to_sql = 'conversational-text-to-sql'
@@ -0,0 +1,85 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import numpy as np

from modelscope.hub.api import HubApi
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import SbertForFaqQuestionAnswering
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import FaqQuestionAnsweringPipeline
from modelscope.preprocessors import FaqQuestionAnsweringPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class FaqQuestionAnsweringTest(unittest.TestCase):
    model_id = 'damo/nlp_structbert_faq-question-answering_chinese-base'
    param = {
        'query_set': ['如何使用优惠券', '在哪里领券', '在哪里领券'],
        'support_set': [{
            'text': '卖品代金券怎么用',
            'label': '6527856'
        }, {
            'text': '怎么使用优惠券',
            'label': '6527856'
        }, {
            'text': '这个可以一起领吗',
            'label': '1000012000'
        }, {
            'text': '付款时送的优惠券哪里领',
            'label': '1000012000'
        }, {
            'text': '购物等级怎么长',
            'label': '13421097'
        }, {
            'text': '购物等级二心',
            'label': '13421097'
        }]
    }

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_direct_file_download(self):
        cache_path = snapshot_download(self.model_id)
        preprocessor = FaqQuestionAnsweringPreprocessor(cache_path)
        model = SbertForFaqQuestionAnswering(cache_path)
        model.load_checkpoint(cache_path)
        pipeline_ins = FaqQuestionAnsweringPipeline(
            model, preprocessor=preprocessor)
        result = pipeline_ins(self.param)
        print(result)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id)
        preprocessor = FaqQuestionAnsweringPreprocessor(model.model_dir)
        pipeline_ins = pipeline(
            task=Tasks.faq_question_answering,
            model=model,
            preprocessor=preprocessor)
        result = pipeline_ins(self.param)
        print(result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_name(self):
        pipeline_ins = pipeline(
            task=Tasks.faq_question_answering, model=self.model_id)
        result = pipeline_ins(self.param)
        print(result)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_default_model(self):
        pipeline_ins = pipeline(task=Tasks.faq_question_answering)
        print(pipeline_ins(self.param, max_seq_length=20))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_sentence_embedding(self):
        pipeline_ins = pipeline(task=Tasks.faq_question_answering)
        sentence_vec = pipeline_ins.get_sentence_embedding(
            ['今天星期六', '明天星期几明天星期几'])
        print(np.shape(sentence_vec))


if __name__ == '__main__':
    unittest.main()