MaaS: add an FAQ question-answering model. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9797053 (target branch: master)
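
Summary: this change adds a few-shot FAQ question-answering capability to MaaS-lib. It introduces a StructBERT-based model (SbertForFaqQuestionAnswering) that embeds a query set together with a labeled support set, builds per-class prototypes from the support embeddings, and scores each query against them with cosine similarity or an optional relation module; a matching preprocessor and pipeline; the task, pipeline, preprocessor, and output-key registrations; a default-model mapping; and unit tests. Two unrelated long-line lint fixes (hub error logging and a CV docstring) ride along.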
@@ -49,8 +49,8 @@ def handle_http_response(response, logger, cookies, model_id):
     except HTTPError:
         if cookies is None:  # code in [403] and
             logger.error(
-                f'Authentication token does not exist, failed to access model {model_id} which may not exist or may be private. \
-                Please login first.')
+                f'Authentication token does not exist, failed to access model {model_id} which may not exist or may be \
+                private. Please login first.')
         raise
@@ -138,6 +138,7 @@ class Pipelines(object):
     dialog_state_tracking = 'dialog-state-tracking'
     zero_shot_classification = 'zero-shot-classification'
     text_error_correction = 'text-error-correction'
+    faq_question_answering = 'faq-question-answering'
     conversational_text_to_sql = 'conversational-text-to-sql'

     # audio tasks
@@ -220,6 +221,7 @@ class Preprocessors(object):
     text_error_correction = 'text-error-correction'
     word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor'
     fill_mask = 'fill-mask'
+    faq_question_answering_preprocessor = 'faq-question-answering-preprocessor'
     conversational_text_to_sql = 'conversational-text-to-sql'

     # audio preprocessor
@@ -6,7 +6,8 @@ import torch

 def point_form(boxes: torch.Tensor) -> torch.Tensor:
-    """Convert prior_boxes to (x_min, y_min, x_max, y_max) representation for comparison to point form ground truth data.
+    """Convert prior_boxes to (x_min, y_min, x_max, y_max) representation for comparison to point form \
+    ground truth data.

     Args:
         boxes: center-size default boxes from priorbox layers.
@@ -21,6 +21,7 @@ if TYPE_CHECKING:
     from .task_models.task_model import SingleBackboneTaskModelBase
     from .bart_for_text_error_correction import BartForTextErrorCorrection
     from .gpt3 import GPT3ForTextGeneration
+    from .sbert_for_faq_question_answering import SbertForFaqQuestionAnswering
 else:
     _import_structure = {
@@ -44,6 +45,7 @@ else:
         'task_model': ['SingleBackboneTaskModelBase'],
         'bart_for_text_error_correction': ['BartForTextErrorCorrection'],
         'gpt3': ['GPT3ForTextGeneration'],
+        'sbert_for_faq_question_answering': ['SbertForFaqQuestionAnswering']
     }

     import sys
@@ -0,0 +1,249 @@
import math
import os
from collections import namedtuple
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from modelscope.metainfo import Models
from modelscope.models.builder import MODELS
from modelscope.models.nlp.structbert import SbertConfig, SbertModel
from modelscope.models.nlp.task_models.task_model import BaseTaskModel
from modelscope.utils.config import Config, ConfigFields
from modelscope.utils.constant import ModelFile, Tasks

__all__ = ['SbertForFaqQuestionAnswering']


class SbertForFaqQuestionAnsweringBase(BaseTaskModel):
    """base class for faq models"""

    def __init__(self, model_dir, *args, **kwargs):
        super(SbertForFaqQuestionAnsweringBase,
              self).__init__(model_dir, *args, **kwargs)
        backbone_cfg = SbertConfig.from_pretrained(model_dir)
        self.bert = SbertModel(backbone_cfg)

        model_config = Config.from_file(
            os.path.join(model_dir,
                         ModelFile.CONFIGURATION)).get(ConfigFields.model, {})
        metric = model_config.get('metric', 'cosine')
        pooling_method = model_config.get('pooling', 'avg')

        Arg = namedtuple('args', [
            'metrics', 'proj_hidden_size', 'hidden_size', 'dropout', 'pooling'
        ])
        args = Arg(
            metrics=metric,
            proj_hidden_size=self.bert.config.hidden_size,
            hidden_size=self.bert.config.hidden_size,
            dropout=0.0,
            pooling=pooling_method)
        self.metrics_layer = MetricsLayer(args)
        self.pooling = PoolingLayer(args)

    def _get_onehot_labels(self, labels, support_size, num_cls):
        labels_ = labels.view(support_size, 1)
        target_oh = torch.zeros(support_size, num_cls).to(labels)
        target_oh.scatter_(dim=1, index=labels_, value=1)
        return target_oh.view(support_size, num_cls).float()

    def forward_sentence_embedding(self, inputs: Dict[str, Tensor]):
        input_ids = inputs['input_ids']
        input_mask = inputs['attention_mask']
        if not isinstance(input_ids, Tensor):
            input_ids = torch.IntTensor(input_ids)
        if not isinstance(input_mask, Tensor):
            input_mask = torch.IntTensor(input_mask)
        rst = self.bert(input_ids, input_mask)
        last_hidden_states = rst.last_hidden_state
        if len(input_mask.shape) == 2:
            input_mask = input_mask.unsqueeze(-1)
        pooled_representation = self.pooling(last_hidden_states, input_mask)
        return pooled_representation


@MODELS.register_module(
    Tasks.faq_question_answering, module_name=Models.structbert)
class SbertForFaqQuestionAnswering(SbertForFaqQuestionAnsweringBase):
    _backbone_prefix = ''

    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        assert not self.training
        query = input['query']
        support = input['support']
        if isinstance(query, list):
            query = torch.stack(query)
        if isinstance(support, list):
            support = torch.stack(support)
        n_query = query.shape[0]
        n_support = support.shape[0]
        query_mask = torch.ne(query, 0).view([n_query, -1])
        support_mask = torch.ne(support, 0).view([n_support, -1])

        support_labels = input['support_labels']
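        # The preprocessor assigns support labels contiguous ids starting at
        # 0, so the number of classes is simply the largest id plus one.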
        num_cls = torch.max(support_labels) + 1
        onehot_labels = self._get_onehot_labels(support_labels, n_support,
                                                num_cls)

        input_ids = torch.cat([query, support])
        input_mask = torch.cat([query_mask, support_mask], dim=0)
        pooled_representation = self.forward_sentence_embedding({
            'input_ids': input_ids,
            'attention_mask': input_mask
        })
        z_query = pooled_representation[:n_query]
        z_support = pooled_representation[n_query:]
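        # Prototypical-network style scoring: each class is represented by
        # the mean of its support embeddings (the 1e-5 guards against
        # division by zero), and every query is scored against these
        # prototypes.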
        cls_n_support = torch.sum(onehot_labels, dim=-2) + 1e-5
        protos = torch.matmul(onehot_labels.transpose(0, 1),
                              z_support) / cls_n_support.unsqueeze(-1)
        scores = self.metrics_layer(z_query, protos).view([n_query, num_cls])
        if self.metrics_layer.name == 'relation':
            scores = torch.sigmoid(scores)
        return {'scores': scores}


activations = {
    'relu': F.relu,
    'tanh': torch.tanh,
    'linear': lambda x: x,
}

activation_coeffs = {
    'relu': math.sqrt(2),
    'tanh': 5 / 3,
    'linear': 1.,
}


class LinearProjection(nn.Module):

    def __init__(self,
                 in_features,
                 out_features,
                 activation='linear',
                 bias=True):
        super().__init__()
        self.activation = activations[activation]
        activation_coeff = activation_coeffs[activation]
        linear = nn.Linear(in_features, out_features, bias=bias)
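        # Initialization std is scaled by the activation's gain (sqrt(2) for
        # relu, 5/3 for tanh, per the table above), then the layer is
        # weight-normalized.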
        nn.init.normal_(
            linear.weight, std=math.sqrt(1. / in_features) * activation_coeff)
        if bias:
            nn.init.zeros_(linear.bias)
        self.model = nn.utils.weight_norm(linear)

    def forward(self, x):
        return self.activation(self.model(x))


class RelationModule(nn.Module):

    def __init__(self, args):
        super(RelationModule, self).__init__()
        input_size = args.proj_hidden_size * 4
        self.prediction = torch.nn.Sequential(
            LinearProjection(
                input_size, args.proj_hidden_size * 4, activation='relu'),
            nn.Dropout(args.dropout),
            LinearProjection(args.proj_hidden_size * 4, 1))

    def forward(self, query, protos):
        n_cls = protos.shape[0]
        n_query = query.shape[0]
        protos = protos.unsqueeze(0).repeat(n_query, 1, 1)
        query = query.unsqueeze(1).repeat(1, n_cls, 1)
        input_feat = torch.cat(
            [query, protos, (protos - query).abs(), query * protos], dim=-1)
        dists = self.prediction(input_feat)  # [n_query, n_cls, 1]
        return dists.squeeze(-1)


class MetricsLayer(nn.Module):

    def __init__(self, args):
        super(MetricsLayer, self).__init__()
        self.args = args
        assert args.metrics in ('relation', 'cosine')
        if args.metrics == 'relation':
            self.relation_net = RelationModule(args)

    @property
    def name(self):
        return self.args.metrics

    def forward(self, query, protos):
        """query : [n_query, dim]
        protos : [n_cls, dim]
        """
        if self.args.metrics == 'cosine':
            supervised_dists = self.cosine_similarity(query, protos)
            if self.training:
                supervised_dists *= 5
        elif self.args.metrics in ('relation', ):
            supervised_dists = self.relation_net(query, protos)
        else:
            raise NotImplementedError
        return supervised_dists
    def cosine_similarity(self, x, y):
        # x: [n_query, dim]
        # y: [n_cls, dim]
        n_query = x.shape[0]
        n_cls = y.shape[0]
        dim = x.shape[-1]
        x = x.unsqueeze(1).expand([n_query, n_cls, dim])
        y = y.unsqueeze(0).expand([n_query, n_cls, dim])
        return F.cosine_similarity(x, y, -1)

class AveragePooling(nn.Module):

    def forward(self, x, mask, dim=1):
        return torch.sum(x * mask.float(), dim=dim) / torch.sum(
            mask.float(), dim=dim)


class AttnPooling(nn.Module):

    def __init__(self, input_size, hidden_size=None, output_size=None):
        super().__init__()
        self.input_proj = nn.Sequential(
            LinearProjection(input_size, hidden_size), nn.Tanh(),
            LinearProjection(hidden_size, 1, bias=False))
        self.output_proj = LinearProjection(
            input_size, output_size) if output_size else lambda x: x

    def forward(self, x, mask):
        score = self.input_proj(x)
        score = score * mask.float() + -1e4 * (1. - mask.float())
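        # Padding positions were pushed toward -1e4 above, so they receive
        # (near-)zero attention weight after the softmax below.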
        score = F.softmax(score, dim=1)
        features = self.output_proj(x)
        return torch.matmul(score.transpose(1, 2), features).squeeze(1)


class PoolingLayer(nn.Module):

    def __init__(self, args):
        super(PoolingLayer, self).__init__()
        if args.pooling == 'attn':
            self.pooling = AttnPooling(args.proj_hidden_size,
                                       args.proj_hidden_size,
                                       args.proj_hidden_size)
        elif args.pooling == 'avg':
            self.pooling = AveragePooling()
        else:
            raise NotImplementedError(args.pooling)

    def forward(self, x, mask):
        return self.pooling(x, mask)
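
Reviewer note: a minimal, self-contained sketch of the scoring scheme implemented above, namely class prototypes built as per-class means of the support embeddings and scored by cosine similarity. The tensors are toy stand-ins, not model outputs:

    import torch
    import torch.nn.functional as F

    # two classes, four support embeddings (labels 0, 0, 1, 1), dim = 3
    z_support = torch.tensor([[1.0, 0.0, 0.0], [0.8, 0.2, 0.0],
                              [0.0, 1.0, 0.0], [0.0, 0.9, 0.1]])
    onehot = F.one_hot(torch.tensor([0, 0, 1, 1]), num_classes=2).float()  # [4, 2]

    # prototypes: per-class mean of the support embeddings -> [2, 3]
    protos = onehot.t().matmul(z_support) / (onehot.sum(dim=0).unsqueeze(-1) + 1e-5)

    z_query = torch.tensor([[0.9, 0.1, 0.0]])  # a single query -> [1, 3]
    scores = F.cosine_similarity(z_query.unsqueeze(1), protos.unsqueeze(0), dim=-1)
    print(scores)  # the class-0 prototype scores highest for this query
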
@@ -7,6 +7,7 @@ class OutputKeys(object):
     LOSS = 'loss'
     LOGITS = 'logits'
     SCORES = 'scores'
+    SCORE = 'score'
     LABEL = 'label'
     LABELS = 'labels'
     INPUT_IDS = 'input_ids'
@@ -504,6 +505,16 @@ TASK_OUTPUTS = {
     # }
     Tasks.visual_entailment: [OutputKeys.SCORES, OutputKeys.LABELS],

+    # {
+    #     'output': [
+    #         [{'label': '6527856', 'score': 0.9942756295204163}, {'label': '1000012000', 'score': 0.0379515215754509},
+    #          {'label': '13421097', 'score': 2.2825044965202324e-08}],
+    #         [{'label': '1000012000', 'score': 0.910681426525116}, {'label': '6527856', 'score': 0.0005046309670433402},
+    #          {'label': '13421097', 'score': 2.75914817393641e-06}],
+    #         [{'label': '1000012000', 'score': 0.910681426525116}, {'label': '6527856', 'score': 0.0005046309670433402},
+    #          {'label': '13421097', 'score': 2.75914817393641e-06}]]
+    # }
+    Tasks.faq_question_answering: [OutputKeys.OUTPUT],

     # image person reid result for single sample
     # {
     #     "img_embedding": np.array with shape [1, D],
@@ -129,6 +129,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                             'damo/cv_convnextTiny_ocr-recognition-general_damo'),
     Tasks.skin_retouching: (Pipelines.skin_retouching,
                             'damo/cv_unet_skin-retouching'),
+    Tasks.faq_question_answering:
+    (Pipelines.faq_question_answering,
+     'damo/nlp_structbert_faq-question-answering_chinese-base'),
     Tasks.crowd_counting: (Pipelines.crowd_counting,
                            'damo/cv_hrnet_crowd-counting_dcanet'),
     Tasks.video_single_object_tracking:
@@ -218,7 +221,6 @@ def pipeline(task: str = None,
         f'model should be either None, str, List[str], Model, or List[Model], but got {type(model)}'
     model = normalize_model_input(model, model_revision)
-
     if pipeline_name is None:
         # get default pipeline for this task
         if isinstance(model, str) \
@@ -20,6 +20,7 @@ if TYPE_CHECKING:
     from .summarization_pipeline import SummarizationPipeline
     from .text_classification_pipeline import TextClassificationPipeline
     from .text_error_correction_pipeline import TextErrorCorrectionPipeline
+    from .faq_question_answering_pipeline import FaqQuestionAnsweringPipeline
 else:
     _import_structure = {
@@ -44,7 +45,8 @@ else:
         'translation_pipeline': ['TranslationPipeline'],
         'summarization_pipeline': ['SummarizationPipeline'],
         'text_classification_pipeline': ['TextClassificationPipeline'],
-        'text_error_correction_pipeline': ['TextErrorCorrectionPipeline']
+        'text_error_correction_pipeline': ['TextErrorCorrectionPipeline'],
+        'faq_question_answering_pipeline': ['FaqQuestionAnsweringPipeline']
     }

     import sys
@@ -0,0 +1,76 @@
from typing import Any, Dict, Union

import torch

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.models.nlp import SbertForFaqQuestionAnswering
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import FaqQuestionAnsweringPreprocessor
from modelscope.utils.constant import Tasks

__all__ = ['FaqQuestionAnsweringPipeline']


@PIPELINES.register_module(
    Tasks.faq_question_answering, module_name=Pipelines.faq_question_answering)
class FaqQuestionAnsweringPipeline(Pipeline):

    def __init__(self,
                 model: Union[str, SbertForFaqQuestionAnswering],
                 preprocessor: FaqQuestionAnsweringPreprocessor = None,
                 **kwargs):
        model = model if isinstance(
            model,
            SbertForFaqQuestionAnswering) else Model.from_pretrained(model)
        model.eval()
        if preprocessor is None:
            preprocessor = FaqQuestionAnsweringPreprocessor(
                model.model_dir, **kwargs)
        self.preprocessor = preprocessor
        super(FaqQuestionAnsweringPipeline, self).__init__(
            model=model, preprocessor=preprocessor, **kwargs)

    def _sanitize_parameters(self, **pipeline_parameters):
        return pipeline_parameters, pipeline_parameters, pipeline_parameters

    def get_sentence_embedding(self, inputs, max_len=None):
        inputs = self.preprocessor.batch_encode(inputs, max_length=max_len)
        sentence_vecs = self.model.forward_sentence_embedding(inputs)
        sentence_vecs = sentence_vecs.detach().tolist()
        return sentence_vecs
    def forward(self, inputs: Union[list, Dict[str, Any]],
                **forward_params) -> Dict[str, Any]:
        with torch.no_grad():
            return self.model(inputs)
    def postprocess(self, inputs: Union[list, Dict[str, Any]],
                    **postprocess_params) -> Dict[str, Any]:
        scores = inputs['scores']
        labels = []
        for item in scores:
            tmplabels = [
                self.preprocessor.get_label(label_id)
                for label_id in range(len(item))
            ]
            labels.append(tmplabels)

        predictions = []
        for tmp_scores, tmp_labels in zip(scores.tolist(), labels):
            prediction = []
            for score, label in zip(tmp_scores, tmp_labels):
                prediction.append({
                    OutputKeys.LABEL: label,
                    OutputKeys.SCORE: score
                })
            predictions.append(
                sorted(
                    prediction,
                    key=lambda d: d[OutputKeys.SCORE],
                    reverse=True))
        return {OutputKeys.OUTPUT: predictions}
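
Reviewer note: the triple return from _sanitize_parameters presumably routes the same kwargs to the preprocess, forward, and postprocess stages (mirroring the transformers pipeline convention); that would be how a per-call max_seq_length reaches the preprocessor, as exercised in test_run_with_default_model below. A minimal usage sketch, with the model id and support data taken from the tests; the output line shows the schema only, not a recorded run:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    p = pipeline(
        task=Tasks.faq_question_answering,
        model='damo/nlp_structbert_faq-question-answering_chinese-base')
    result = p(
        {
            'query_set': ['如何使用优惠券'],
            'support_set': [{'text': '怎么使用优惠券', 'label': '6527856'},
                            {'text': '付款时送的优惠券哪里领', 'label': '1000012000'}]
        },
        max_seq_length=20)
    # result == {'output': [[{'label': '6527856', 'score': ...},
    #                        {'label': '1000012000', 'score': ...}]]}, sorted by score
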
@@ -21,7 +21,8 @@ if TYPE_CHECKING:
                       SingleSentenceClassificationPreprocessor,
                       PairSentenceClassificationPreprocessor,
                       FillMaskPreprocessor, ZeroShotClassificationPreprocessor,
-                      NERPreprocessor, TextErrorCorrectionPreprocessor)
+                      NERPreprocessor, TextErrorCorrectionPreprocessor,
+                      FaqQuestionAnsweringPreprocessor)
     from .space import (DialogIntentPredictionPreprocessor,
                         DialogModelingPreprocessor,
                         DialogStateTrackingPreprocessor)
@@ -48,7 +49,8 @@ else:
             'SingleSentenceClassificationPreprocessor',
             'PairSentenceClassificationPreprocessor', 'FillMaskPreprocessor',
             'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
-            'TextErrorCorrectionPreprocessor'
+            'TextErrorCorrectionPreprocessor',
+            'FaqQuestionAnsweringPreprocessor'
         ],
         'space': [
             'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor',
@@ -5,10 +5,12 @@ import uuid
 from typing import Any, Dict, Iterable, Optional, Tuple, Union

 import numpy as np
+import torch
 from transformers import AutoTokenizer

 from modelscope.metainfo import Models, Preprocessors
 from modelscope.outputs import OutputKeys
+from modelscope.utils.config import ConfigFields
 from modelscope.utils.constant import Fields, InputFields, ModeKeys
 from modelscope.utils.hub import get_model_type, parse_label_mapping
 from modelscope.utils.type_assert import type_assert
@@ -21,7 +23,7 @@ __all__ = [
     'PairSentenceClassificationPreprocessor',
     'SingleSentenceClassificationPreprocessor', 'FillMaskPreprocessor',
     'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
-    'TextErrorCorrectionPreprocessor'
+    'TextErrorCorrectionPreprocessor', 'FaqQuestionAnsweringPreprocessor'
 ]
@@ -645,3 +647,86 @@ class TextErrorCorrectionPreprocessor(Preprocessor):
         sample = dict()
         sample['net_input'] = {'src_tokens': inputs, 'src_lengths': lengths}
         return sample
+
+
+@PREPROCESSORS.register_module(
+    Fields.nlp, module_name=Preprocessors.faq_question_answering_preprocessor)
+class FaqQuestionAnsweringPreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        super(FaqQuestionAnsweringPreprocessor, self).__init__(
+            model_dir, pair=False, mode=ModeKeys.INFERENCE, **kwargs)
+        import os
+        from transformers import BertTokenizer
+
+        from modelscope.utils.config import Config
+        from modelscope.utils.constant import ModelFile
+        self.tokenizer = BertTokenizer.from_pretrained(model_dir)
+        preprocessor_config = Config.from_file(
+            os.path.join(model_dir, ModelFile.CONFIGURATION)).get(
+                ConfigFields.preprocessor, {})
+        self.MAX_LEN = preprocessor_config.get('max_seq_length', 50)
+        self.label_dict = None
+
+    def pad(self, samples, max_len):
+        result = []
+        for sample in samples:
+            pad_len = max_len - len(sample[:max_len])
+            result.append(sample[:max_len]
+                          + [self.tokenizer.pad_token_id] * pad_len)
+        return result
+
+    def set_label_dict(self, label_dict):
+        self.label_dict = label_dict
+
+    def get_label(self, label_id):
+        assert self.label_dict is not None and label_id < len(self.label_dict)
+        return self.label_dict[label_id]
+
+    def encode_plus(self, text):
+        return [
+            self.tokenizer.cls_token_id
+        ] + self.tokenizer.convert_tokens_to_ids(
+            self.tokenizer.tokenize(text)) + [self.tokenizer.sep_token_id]
+
+    @type_assert(object, Dict)
+    def __call__(self, data: Dict[str, Any],
+                 **preprocessor_param) -> Dict[str, Any]:
+        TMP_MAX_LEN = preprocessor_param.get('max_seq_length', self.MAX_LEN)
+        queryset = data['query_set']
+        if not isinstance(queryset, list):
+            queryset = [queryset]
+        supportset = data['support_set']
+        supportset = sorted(supportset, key=lambda d: d['label'])
+
+        queryset_tokenized = [self.encode_plus(text) for text in queryset]
+        supportset_tokenized = [
+            self.encode_plus(item['text']) for item in supportset
+        ]
+
+        max_len = max(
+            [len(seq) for seq in queryset_tokenized + supportset_tokenized])
+        max_len = min(TMP_MAX_LEN, max_len)
+        queryset_padded = self.pad(queryset_tokenized, max_len)
+        supportset_padded = self.pad(supportset_tokenized, max_len)
+
+        supportset_labels_ori = [item['label'] for item in supportset]
+        label_dict = []
+        for label in supportset_labels_ori:
+            if label not in label_dict:
+                label_dict.append(label)
+        self.set_label_dict(label_dict)
+        supportset_labels_ids = [
+            label_dict.index(label) for label in supportset_labels_ori
+        ]
+        return {
+            'query': queryset_padded,
+            'support': supportset_padded,
+            'support_labels': supportset_labels_ids
+        }
+
+    def batch_encode(self, sentence_list: list, max_length=None):
+        if not max_length:
+            max_length = self.MAX_LEN
+        return self.tokenizer.batch_encode_plus(
+            sentence_list, padding=True, max_length=max_length)
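
Reviewer note: two behaviors worth calling out above: the support set is re-sorted by its label string before ids are assigned, so the returned 'support' rows follow label order rather than input order, and set_label_dict stores per-call state that get_label reads back during postprocessing (interleaved calls on one instance would therefore be unsafe). A sketch of the contract with toy inputs, not a recorded run; model_dir stands for any downloaded snapshot of the FAQ model:

    pre = FaqQuestionAnsweringPreprocessor(model_dir)
    features = pre({
        'query_set': ['在哪里领券'],
        'support_set': [{'text': '怎么使用优惠券', 'label': '6527856'},
                        {'text': '付款时送的优惠券哪里领', 'label': '1000012000'}]
    })
    # features['query'] / features['support']: [CLS] ... [SEP] token-id lists,
    # padded/truncated to min(max_seq_length, longest sequence)
    # features['support_labels'] == [0, 1]: '1000012000' sorts before '6527856',
    # so it gets id 0 and its support row comes first
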
@@ -95,6 +95,7 @@ class NLPTasks(object):
     zero_shot_classification = 'zero-shot-classification'
     backbone = 'backbone'
     text_error_correction = 'text-error-correction'
+    faq_question_answering = 'faq-question-answering'
     conversational_text_to_sql = 'conversational-text-to-sql'
@@ -0,0 +1,85 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import numpy as np

from modelscope.hub.api import HubApi
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import SbertForFaqQuestionAnswering
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import FaqQuestionAnsweringPipeline
from modelscope.preprocessors import FaqQuestionAnsweringPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class FaqQuestionAnsweringTest(unittest.TestCase):
    model_id = 'damo/nlp_structbert_faq-question-answering_chinese-base'
    param = {
        'query_set': ['如何使用优惠券', '在哪里领券', '在哪里领券'],
        'support_set': [{
            'text': '卖品代金券怎么用',
            'label': '6527856'
        }, {
            'text': '怎么使用优惠券',
            'label': '6527856'
        }, {
            'text': '这个可以一起领吗',
            'label': '1000012000'
        }, {
            'text': '付款时送的优惠券哪里领',
            'label': '1000012000'
        }, {
            'text': '购物等级怎么长',
            'label': '13421097'
        }, {
            'text': '购物等级二心',
            'label': '13421097'
        }]
    }

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_direct_file_download(self):
        cache_path = snapshot_download(self.model_id)
        preprocessor = FaqQuestionAnsweringPreprocessor(cache_path)
        model = SbertForFaqQuestionAnswering(cache_path)
        model.load_checkpoint(cache_path)
        pipeline_ins = FaqQuestionAnsweringPipeline(
            model, preprocessor=preprocessor)
        result = pipeline_ins(self.param)
        print(result)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id)
        preprocessor = FaqQuestionAnsweringPreprocessor(model.model_dir)
        pipeline_ins = pipeline(
            task=Tasks.faq_question_answering,
            model=model,
            preprocessor=preprocessor)
        result = pipeline_ins(self.param)
        print(result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_name(self):
        pipeline_ins = pipeline(
            task=Tasks.faq_question_answering, model=self.model_id)
        result = pipeline_ins(self.param)
        print(result)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_default_model(self):
        pipeline_ins = pipeline(task=Tasks.faq_question_answering)
        print(pipeline_ins(self.param, max_seq_length=20))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_sentence_embedding(self):
        pipeline_ins = pipeline(task=Tasks.faq_question_answering)
        sentence_vec = pipeline_ins.get_sentence_embedding(
            ['今天星期六', '明天星期几明天星期几'])
        print(np.shape(sentence_vec))


if __name__ == '__main__':
    unittest.main()