From dc45fce542abab709bb0559248ba2712224a9df6 Mon Sep 17 00:00:00 2001
From: "tanfan.zjh"
Date: Fri, 26 Aug 2022 13:06:41 +0800
Subject: [PATCH] [to #42322933] Add FAQ question-answering model
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add an FAQ question-answering model to MaaS.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9797053
---
 modelscope/hub/errors.py                      |   4 +-
 modelscope/metainfo.py                        |   2 +
 .../skin_retouching/retinaface/box_utils.py   |   3 +-
 modelscope/models/nlp/__init__.py             |   2 +
 .../nlp/sbert_for_faq_question_answering.py   | 249 ++++++++++++++++++
 modelscope/outputs.py                         |  11 +
 modelscope/pipelines/builder.py               |   4 +-
 modelscope/pipelines/nlp/__init__.py          |   4 +-
 .../nlp/faq_question_answering_pipeline.py    |  76 ++++++
 modelscope/preprocessors/__init__.py          |   6 +-
 modelscope/preprocessors/nlp.py               |  87 +++++-
 modelscope/utils/constant.py                  |   1 +
 .../pipelines/test_faq_question_answering.py  |  85 ++++++
 13 files changed, 526 insertions(+), 8 deletions(-)
 create mode 100644 modelscope/models/nlp/sbert_for_faq_question_answering.py
 create mode 100644 modelscope/pipelines/nlp/faq_question_answering_pipeline.py
 create mode 100644 tests/pipelines/test_faq_question_answering.py

diff --git a/modelscope/hub/errors.py b/modelscope/hub/errors.py
index ecd4e1da..e9c008b0 100644
--- a/modelscope/hub/errors.py
+++ b/modelscope/hub/errors.py
@@ -49,8 +49,8 @@ def handle_http_response(response, logger, cookies, model_id):
     except HTTPError:
         if cookies is None:  # code in [403] and
             logger.error(
-                f'Authentication token does not exist, failed to access model {model_id} which may not exist or may be private. \
-                Please login first.')
+                f'Authentication token does not exist, failed to access model {model_id} which may not exist or may be \
+                private. Please login first.')
         raise

diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 8e21c00b..6ea03610 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -138,6 +138,7 @@ class Pipelines(object):
     dialog_state_tracking = 'dialog-state-tracking'
     zero_shot_classification = 'zero-shot-classification'
     text_error_correction = 'text-error-correction'
+    faq_question_answering = 'faq-question-answering'
     conversational_text_to_sql = 'conversational-text-to-sql'

     # audio tasks
@@ -220,6 +221,7 @@ class Preprocessors(object):
     text_error_correction = 'text-error-correction'
     word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor'
     fill_mask = 'fill-mask'
+    faq_question_answering_preprocessor = 'faq-question-answering-preprocessor'
     conversational_text_to_sql = 'conversational-text-to-sql'

     # audio preprocessor

diff --git a/modelscope/models/cv/skin_retouching/retinaface/box_utils.py b/modelscope/models/cv/skin_retouching/retinaface/box_utils.py
index 89cf8bf6..a4aeffd1 100644
--- a/modelscope/models/cv/skin_retouching/retinaface/box_utils.py
+++ b/modelscope/models/cv/skin_retouching/retinaface/box_utils.py
@@ -6,7 +6,8 @@ import torch


 def point_form(boxes: torch.Tensor) -> torch.Tensor:
-    """Convert prior_boxes to (x_min, y_min, x_max, y_max) representation for comparison to point form ground truth data.
+    """Convert prior_boxes to (x_min, y_min, x_max, y_max) representation for comparison to point form \
+    ground truth data.

     Args:
         boxes: center-size default boxes from priorbox layers.
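Note on the `point_form` docstring reformatted above: the function converts center-size boxes (cx, cy, w, h) into corner form. Its body is untouched by this patch; the following standalone sketch of the conversion is an illustration, not the repository code:

    import torch

    def point_form_sketch(boxes: torch.Tensor) -> torch.Tensor:
        # boxes: [N, 4] rows of (cx, cy, w, h); returns (x_min, y_min, x_max, y_max)
        return torch.cat(
            (boxes[:, :2] - boxes[:, 2:] / 2,   # top-left     = center - size / 2
             boxes[:, :2] + boxes[:, 2:] / 2),  # bottom-right = center + size / 2
            dim=1)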
diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py
index 3fd76f98..13be9096 100644
--- a/modelscope/models/nlp/__init__.py
+++ b/modelscope/models/nlp/__init__.py
@@ -21,6 +21,7 @@ if TYPE_CHECKING:
     from .task_models.task_model import SingleBackboneTaskModelBase
     from .bart_for_text_error_correction import BartForTextErrorCorrection
     from .gpt3 import GPT3ForTextGeneration
+    from .sbert_for_faq_question_answering import SbertForFaqQuestionAnswering

 else:
     _import_structure = {
@@ -44,6 +45,7 @@ else:
         'task_model': ['SingleBackboneTaskModelBase'],
         'bart_for_text_error_correction': ['BartForTextErrorCorrection'],
         'gpt3': ['GPT3ForTextGeneration'],
+        'sbert_for_faq_question_answering': ['SbertForFaqQuestionAnswering']
     }

     import sys

diff --git a/modelscope/models/nlp/sbert_for_faq_question_answering.py b/modelscope/models/nlp/sbert_for_faq_question_answering.py
new file mode 100644
index 00000000..23ccdcc5
--- /dev/null
+++ b/modelscope/models/nlp/sbert_for_faq_question_answering.py
@@ -0,0 +1,249 @@
+import math
+import os
+from collections import namedtuple
+from typing import Dict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+from modelscope.metainfo import Models
+from modelscope.models.builder import MODELS
+from modelscope.models.nlp.structbert import SbertConfig, SbertModel
+from modelscope.models.nlp.task_models.task_model import BaseTaskModel
+from modelscope.utils.config import Config, ConfigFields
+from modelscope.utils.constant import ModelFile, Tasks
+
+__all__ = ['SbertForFaqQuestionAnswering']
+
+
+class SbertForFaqQuestionAnsweringBase(BaseTaskModel):
+    """Base class for FAQ models.
+    """
+
+    def __init__(self, model_dir, *args, **kwargs):
+        super(SbertForFaqQuestionAnsweringBase,
+              self).__init__(model_dir, *args, **kwargs)
+
+        backbone_cfg = SbertConfig.from_pretrained(model_dir)
+        self.bert = SbertModel(backbone_cfg)
+
+        model_config = Config.from_file(
+            os.path.join(model_dir,
+                         ModelFile.CONFIGURATION)).get(ConfigFields.model, {})
+
+        metric = model_config.get('metric', 'cosine')
+        pooling_method = model_config.get('pooling', 'avg')
+
+        Arg = namedtuple('args', [
+            'metrics', 'proj_hidden_size', 'hidden_size', 'dropout', 'pooling'
+        ])
+        args = Arg(
+            metrics=metric,
+            proj_hidden_size=self.bert.config.hidden_size,
+            hidden_size=self.bert.config.hidden_size,
+            dropout=0.0,
+            pooling=pooling_method)
+
+        self.metrics_layer = MetricsLayer(args)
+        self.pooling = PoolingLayer(args)
+
+    def _get_onehot_labels(self, labels, support_size, num_cls):
+        labels_ = labels.view(support_size, 1)
+        target_oh = torch.zeros(support_size, num_cls).to(labels)
+        target_oh.scatter_(dim=1, index=labels_, value=1)
+        return target_oh.view(support_size, num_cls).float()
+
+    def forward_sentence_embedding(self, inputs: Dict[str, Tensor]):
+        input_ids = inputs['input_ids']
+        input_mask = inputs['attention_mask']
+        if not isinstance(input_ids, Tensor):
+            input_ids = torch.IntTensor(input_ids)
+        if not isinstance(input_mask, Tensor):
+            input_mask = torch.IntTensor(input_mask)
+        rst = self.bert(input_ids, input_mask)
+        last_hidden_states = rst.last_hidden_state
+        if len(input_mask.shape) == 2:
+            input_mask = input_mask.unsqueeze(-1)
+        pooled_representation = self.pooling(last_hidden_states, input_mask)
+        return pooled_representation
+
+
+@MODELS.register_module(
+    Tasks.faq_question_answering, module_name=Models.structbert)
+class SbertForFaqQuestionAnswering(SbertForFaqQuestionAnsweringBase):
+    _backbone_prefix = ''
+
+    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        assert not self.training
+        query = input['query']
+        support = input['support']
+        if isinstance(query, list):
+            query = torch.stack(query)
+        if isinstance(support, list):
+            support = torch.stack(support)
+        n_query = query.shape[0]
+        n_support = support.shape[0]
+        query_mask = torch.ne(query, 0).view([n_query, -1])
+        support_mask = torch.ne(support, 0).view([n_support, -1])
+
+        support_labels = input['support_labels']
+        num_cls = torch.max(support_labels) + 1
+        onehot_labels = self._get_onehot_labels(support_labels, n_support,
+                                                num_cls)
+
+        input_ids = torch.cat([query, support])
+        input_mask = torch.cat([query_mask, support_mask], dim=0)
+        pooled_representation = self.forward_sentence_embedding({
+            'input_ids':
+            input_ids,
+            'attention_mask':
+            input_mask
+        })
+        z_query = pooled_representation[:n_query]
+        z_support = pooled_representation[n_query:]
+        cls_n_support = torch.sum(onehot_labels, dim=-2) + 1e-5
+        protos = torch.matmul(onehot_labels.transpose(0, 1),
+                              z_support) / cls_n_support.unsqueeze(-1)
+        scores = self.metrics_layer(z_query, protos).view([n_query, num_cls])
+        if self.metrics_layer.name == 'relation':
+            scores = torch.sigmoid(scores)
+        return {'scores': scores}
+
+
+activations = {
+    'relu': F.relu,
+    'tanh': torch.tanh,
+    'linear': lambda x: x,
+}
+
+activation_coeffs = {
+    'relu': math.sqrt(2),
+    'tanh': 5 / 3,
+    'linear': 1.,
+}
+
+
+class LinearProjection(nn.Module):
+
+    def __init__(self,
+                 in_features,
+                 out_features,
+                 activation='linear',
+                 bias=True):
+        super().__init__()
+        self.activation = activations[activation]
+        activation_coeff = activation_coeffs[activation]
+        linear = nn.Linear(in_features, out_features, bias=bias)
+        nn.init.normal_(
+            linear.weight, std=math.sqrt(1. / in_features) * activation_coeff)
+        if bias:
+            nn.init.zeros_(linear.bias)
+        self.model = nn.utils.weight_norm(linear)
+
+    def forward(self, x):
+        return self.activation(self.model(x))
+
+
+class RelationModule(nn.Module):
+
+    def __init__(self, args):
+        super(RelationModule, self).__init__()
+        input_size = args.proj_hidden_size * 4
+        self.prediction = torch.nn.Sequential(
+            LinearProjection(
+                input_size, args.proj_hidden_size * 4, activation='relu'),
+            nn.Dropout(args.dropout),
+            LinearProjection(args.proj_hidden_size * 4, 1))
+
+    def forward(self, query, protos):
+        n_cls = protos.shape[0]
+        n_query = query.shape[0]
+        protos = protos.unsqueeze(0).repeat(n_query, 1, 1)
+        query = query.unsqueeze(1).repeat(1, n_cls, 1)
+        input_feat = torch.cat(
+            [query, protos, (protos - query).abs(), query * protos], dim=-1)
+        dists = self.prediction(input_feat)  # [n_query, n_cls, 1]
+        return dists.squeeze(-1)
+
+
+class MetricsLayer(nn.Module):
+
+    def __init__(self, args):
+        super(MetricsLayer, self).__init__()
+        self.args = args
+        assert args.metrics in ('relation', 'cosine')
+        if args.metrics == 'relation':
+            self.relation_net = RelationModule(args)
+
+    @property
+    def name(self):
+        return self.args.metrics
+
+    def forward(self, query, protos):
+        """query:  [n_query, dim]
+        protos: [n_cls, dim]
+        """
+        if self.args.metrics == 'cosine':
+            supervised_dists = self.cosine_similarity(query, protos)
+            if self.training:
+                supervised_dists *= 5
+        elif self.args.metrics in ('relation', ):
+            supervised_dists = self.relation_net(query, protos)
+        else:
+            raise NotImplementedError
+        return supervised_dists
+
+    def cosine_similarity(self, x, y):
+        # x: [n_query, dim]
+        # y: [n_cls, dim]
+        n_query = x.shape[0]
+        n_cls = y.shape[0]
+        dim = x.shape[-1]
+        x = x.unsqueeze(1).expand([n_query, n_cls, dim])
+        y = y.unsqueeze(0).expand([n_query, n_cls, dim])
+        return F.cosine_similarity(x, y, -1)
+
+
+class AveragePooling(nn.Module):
+
+    def forward(self, x, mask, dim=1):
+        return torch.sum(
+            x * mask.float(), dim=dim) / torch.sum(
+                mask.float(), dim=dim)
+
+
+class AttnPooling(nn.Module):
+
+    def __init__(self, input_size, hidden_size=None, output_size=None):
+        super().__init__()
+        self.input_proj = nn.Sequential(
+            LinearProjection(input_size, hidden_size), nn.Tanh(),
+            LinearProjection(hidden_size, 1, bias=False))
+        self.output_proj = LinearProjection(
+            input_size, output_size) if output_size else lambda x: x
+
+    def forward(self, x, mask):
+        score = self.input_proj(x)
+        score = score * mask.float() + -1e4 * (1. - mask.float())
+        score = F.softmax(score, dim=1)
+        features = self.output_proj(x)
+        return torch.matmul(score.transpose(1, 2), features).squeeze(1)
+
+
+class PoolingLayer(nn.Module):
+
+    def __init__(self, args):
+        super(PoolingLayer, self).__init__()
+        if args.pooling == 'attn':
+            self.pooling = AttnPooling(args.proj_hidden_size,
+                                       args.proj_hidden_size,
+                                       args.proj_hidden_size)
+        elif args.pooling == 'avg':
+            self.pooling = AveragePooling()
+        else:
+            raise NotImplementedError(args.pooling)
+
+    def forward(self, x, mask):
+        return self.pooling(x, mask)

diff --git a/modelscope/outputs.py b/modelscope/outputs.py
index 640d67fa..2edd76a2 100644
--- a/modelscope/outputs.py
+++ b/modelscope/outputs.py
@@ -7,6 +7,7 @@ class OutputKeys(object):
     LOSS = 'loss'
     LOGITS = 'logits'
     SCORES = 'scores'
+    SCORE = 'score'
     LABEL = 'label'
     LABELS = 'labels'
     INPUT_IDS = 'input_ids'
@@ -504,6 +505,16 @@ TASK_OUTPUTS = {
     # }
     Tasks.visual_entailment: [OutputKeys.SCORES, OutputKeys.LABELS],

+    # {
+    #     'output': [
+    #         [{'label': '6527856', 'score': 0.9942756295204163}, {'label': '1000012000', 'score': 0.0379515215754509},
+    #          {'label': '13421097', 'score': 2.2825044965202324e-08}],
+    #         [{'label': '1000012000', 'score': 0.910681426525116}, {'label': '6527856', 'score': 0.0005046309670433402},
+    #          {'label': '13421097', 'score': 2.75914817393641e-06}],
+    #         [{'label': '1000012000', 'score': 0.910681426525116}, {'label': '6527856', 'score': 0.0005046309670433402},
+    #          {'label': '13421097', 'score': 2.75914817393641e-06}]]
+    # }
+    Tasks.faq_question_answering: [OutputKeys.OUTPUT],
     # image person reid result for single sample
     # {
     #     "img_embedding": np.array with shape [1, D],

diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index 52dfa41b..fa6705a7 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -129,6 +129,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                              'damo/cv_convnextTiny_ocr-recognition-general_damo'),
     Tasks.skin_retouching: (Pipelines.skin_retouching,
                             'damo/cv_unet_skin-retouching'),
+    Tasks.faq_question_answering:
+    (Pipelines.faq_question_answering,
+     'damo/nlp_structbert_faq-question-answering_chinese-base'),
     Tasks.crowd_counting: (Pipelines.crowd_counting,
                            'damo/cv_hrnet_crowd-counting_dcanet'),
     Tasks.video_single_object_tracking:
@@ -218,7 +221,6 @@ def pipeline(task: str = None,
         f'model should be either None, str, List[str], Model, or List[Model], but got {type(model)}'

     model = normalize_model_input(model, model_revision)
-
     if pipeline_name is None:
         # get default pipeline for this task
         if isinstance(model, str) \

diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py
index 0cdb633c..51803872 100644
--- a/modelscope/pipelines/nlp/__init__.py
+++ b/modelscope/pipelines/nlp/__init__.py
@@ -20,6 +20,7 @@ if TYPE_CHECKING:
     from .summarization_pipeline import SummarizationPipeline
     from .text_classification_pipeline import TextClassificationPipeline
     from .text_error_correction_pipeline import TextErrorCorrectionPipeline
+    from .faq_question_answering_pipeline import FaqQuestionAnsweringPipeline

 else:
     _import_structure = {
@@ -44,7 +45,8 @@ else:
         'translation_pipeline': ['TranslationPipeline'],
         'summarization_pipeline': ['SummarizationPipeline'],
         'text_classification_pipeline': ['TextClassificationPipeline'],
-        'text_error_correction_pipeline': ['TextErrorCorrectionPipeline']
+        'text_error_correction_pipeline': ['TextErrorCorrectionPipeline'],
+        'faq_question_answering_pipeline': ['FaqQuestionAnsweringPipeline']
     }

     import sys
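The `forward` pass of `SbertForFaqQuestionAnswering` above is a prototypical-network style few-shot classifier: support embeddings are averaged per label into class prototypes, and each query embedding is scored against every prototype. A minimal standalone sketch of the same math on toy tensors (shapes are illustrative assumptions, not the model's real dimensions):

    import torch
    import torch.nn.functional as F

    z_query = torch.randn(2, 8)    # 2 query embeddings, dim 8
    z_support = torch.randn(4, 8)  # 4 support embeddings
    support_labels = torch.tensor([0, 0, 1, 1])

    num_cls = int(support_labels.max()) + 1
    onehot = F.one_hot(support_labels, num_cls).float()            # [4, 2]
    protos = onehot.t() @ z_support / onehot.sum(0).unsqueeze(-1)  # class means, [2, 8]
    scores = F.cosine_similarity(                                  # [2, 2], query x class
        z_query.unsqueeze(1), protos.unsqueeze(0), dim=-1)

The model additionally scales cosine scores by 5 during training and offers a learned 'relation' metric as an alternative to cosine.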
diff --git a/modelscope/pipelines/nlp/faq_question_answering_pipeline.py b/modelscope/pipelines/nlp/faq_question_answering_pipeline.py
new file mode 100644
index 00000000..65831a17
--- /dev/null
+++ b/modelscope/pipelines/nlp/faq_question_answering_pipeline.py
@@ -0,0 +1,76 @@
+from typing import Any, Dict, Union
+
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.models import Model
+from modelscope.models.nlp import SbertForFaqQuestionAnswering
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import FaqQuestionAnsweringPreprocessor
+from modelscope.utils.constant import Tasks
+
+__all__ = ['FaqQuestionAnsweringPipeline']
+
+
+@PIPELINES.register_module(
+    Tasks.faq_question_answering, module_name=Pipelines.faq_question_answering)
+class FaqQuestionAnsweringPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[str, SbertForFaqQuestionAnswering],
+                 preprocessor: FaqQuestionAnsweringPreprocessor = None,
+                 **kwargs):
+        model = model if isinstance(
+            model,
+            SbertForFaqQuestionAnswering) else Model.from_pretrained(model)
+        model.eval()
+        if preprocessor is None:
+            preprocessor = FaqQuestionAnsweringPreprocessor(
+                model.model_dir, **kwargs)
+        self.preprocessor = preprocessor
+        super(FaqQuestionAnsweringPipeline, self).__init__(
+            model=model, preprocessor=preprocessor, **kwargs)
+
+    def _sanitize_parameters(self, **pipeline_parameters):
+        return pipeline_parameters, pipeline_parameters, pipeline_parameters
+
+    def get_sentence_embedding(self, inputs, max_len=None):
+        inputs = self.preprocessor.batch_encode(inputs, max_length=max_len)
+        sentence_vecs = self.model.forward_sentence_embedding(inputs)
+        sentence_vecs = sentence_vecs.detach().tolist()
+        return sentence_vecs
+
+    def forward(self, inputs: Union[list, Dict[str, Any]],
+                **forward_params) -> Dict[str, Any]:
+        with torch.no_grad():
+            return self.model(inputs)
+
+    def postprocess(self, inputs: Union[list, Dict[str, Any]],
+                    **postprocess_params) -> Dict[str, Any]:
+        scores = inputs['scores']
+        labels = []
+        for item in scores:
+            tmplabels = [
+                self.preprocessor.get_label(label_id)
+                for label_id in range(len(item))
+            ]
+            labels.append(tmplabels)
+
+        predictions = []
+        for tmp_scores, tmp_labels in zip(scores.tolist(), labels):
+            prediction = []
+            for score, label in zip(tmp_scores, tmp_labels):
+                prediction.append({
+                    OutputKeys.LABEL: label,
+                    OutputKeys.SCORE: score
+                })
+            predictions.append(
+                list(
+                    sorted(
+                        prediction,
+                        key=lambda d: d[OutputKeys.SCORE],
+                        reverse=True)))
+
+        return {OutputKeys.OUTPUT: predictions}
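`postprocess` above emits, for each query, all support labels ranked by score under the 'output' key, matching the format documented in outputs.py earlier in this patch. A small helper for consuming that structure (hypothetical, for illustration only):

    def top_prediction(result):
        # result: {'output': [[{'label': ..., 'score': ...}, ...], ...]};
        # each inner list is already sorted by score, highest first.
        return [ranked[0] for ranked in result['output']]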
diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py
index 0328b91a..ce9df454 100644
--- a/modelscope/preprocessors/__init__.py
+++ b/modelscope/preprocessors/__init__.py
@@ -21,7 +21,8 @@ if TYPE_CHECKING:
                       SingleSentenceClassificationPreprocessor,
                       PairSentenceClassificationPreprocessor, FillMaskPreprocessor,
                       ZeroShotClassificationPreprocessor,
-                      NERPreprocessor, TextErrorCorrectionPreprocessor)
+                      NERPreprocessor, TextErrorCorrectionPreprocessor,
+                      FaqQuestionAnsweringPreprocessor)
     from .space import (DialogIntentPredictionPreprocessor,
                         DialogModelingPreprocessor,
                         DialogStateTrackingPreprocessor)
@@ -48,7 +49,8 @@ else:
             'SingleSentenceClassificationPreprocessor',
             'PairSentenceClassificationPreprocessor', 'FillMaskPreprocessor',
             'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
-            'TextErrorCorrectionPreprocessor'
+            'TextErrorCorrectionPreprocessor',
+            'FaqQuestionAnsweringPreprocessor'
         ],
         'space': [
             'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor',

diff --git a/modelscope/preprocessors/nlp.py b/modelscope/preprocessors/nlp.py
index 222a219a..094cbfe2 100644
--- a/modelscope/preprocessors/nlp.py
+++ b/modelscope/preprocessors/nlp.py
@@ -5,10 +5,12 @@ import uuid
 from typing import Any, Dict, Iterable, Optional, Tuple, Union

 import numpy as np
+import torch
 from transformers import AutoTokenizer

 from modelscope.metainfo import Models, Preprocessors
 from modelscope.outputs import OutputKeys
+from modelscope.utils.config import ConfigFields
 from modelscope.utils.constant import Fields, InputFields, ModeKeys
 from modelscope.utils.hub import get_model_type, parse_label_mapping
 from modelscope.utils.type_assert import type_assert
@@ -21,7 +23,7 @@ __all__ = [
     'PairSentenceClassificationPreprocessor',
     'SingleSentenceClassificationPreprocessor', 'FillMaskPreprocessor',
     'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
-    'TextErrorCorrectionPreprocessor'
+    'TextErrorCorrectionPreprocessor', 'FaqQuestionAnsweringPreprocessor'
 ]


@@ -645,3 +647,86 @@ class TextErrorCorrectionPreprocessor(Preprocessor):
         sample = dict()
         sample['net_input'] = {'src_tokens': inputs, 'src_lengths': lengths}
         return sample
+
+
+@PREPROCESSORS.register_module(
+    Fields.nlp, module_name=Preprocessors.faq_question_answering_preprocessor)
+class FaqQuestionAnsweringPreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        super(FaqQuestionAnsweringPreprocessor, self).__init__(
+            model_dir, pair=False, mode=ModeKeys.INFERENCE, **kwargs)
+        import os
+        from transformers import BertTokenizer
+
+        from modelscope.utils.config import Config
+        from modelscope.utils.constant import ModelFile
+        self.tokenizer = BertTokenizer.from_pretrained(model_dir)
+        preprocessor_config = Config.from_file(
+            os.path.join(model_dir, ModelFile.CONFIGURATION)).get(
+                ConfigFields.preprocessor, {})
+        self.MAX_LEN = preprocessor_config.get('max_seq_length', 50)
+        self.label_dict = None
+
+    def pad(self, samples, max_len):
+        result = []
+        for sample in samples:
+            pad_len = max_len - len(sample[:max_len])
+            result.append(sample[:max_len] +
+                          [self.tokenizer.pad_token_id] * pad_len)
+        return result
+
+    def set_label_dict(self, label_dict):
+        self.label_dict = label_dict
+
+    def get_label(self, label_id):
+        assert self.label_dict is not None and label_id < len(self.label_dict)
+        return self.label_dict[label_id]
+
+    def encode_plus(self, text):
+        return [
+            self.tokenizer.cls_token_id
+        ] + self.tokenizer.convert_tokens_to_ids(
+            self.tokenizer.tokenize(text)) + [self.tokenizer.sep_token_id]
+
+    @type_assert(object, Dict)
+    def __call__(self, data: Dict[str, Any],
+                 **preprocessor_param) -> Dict[str, Any]:
+        tmp_max_len = preprocessor_param.get('max_seq_length', self.MAX_LEN)
+        queryset = data['query_set']
+        if not isinstance(queryset, list):
+            queryset = [queryset]
+        supportset = data['support_set']
+        supportset = sorted(supportset, key=lambda d: d['label'])
+
+        queryset_tokenized = [self.encode_plus(text) for text in queryset]
+        supportset_tokenized = [
+            self.encode_plus(item['text']) for item in supportset
+        ]
+
+        max_len = max(
+            [len(seq) for seq in queryset_tokenized + supportset_tokenized])
+        max_len = min(tmp_max_len, max_len)
+        queryset_padded = self.pad(queryset_tokenized, max_len)
+        supportset_padded = self.pad(supportset_tokenized, max_len)
+
+        supportset_labels_ori = [item['label'] for item in supportset]
+        label_dict = []
+        for label in supportset_labels_ori:
+            if label not in label_dict:
+                label_dict.append(label)
+        self.set_label_dict(label_dict)
+        supportset_labels_ids = [
+            label_dict.index(label) for label in supportset_labels_ori
+        ]
+        return {
+            'query': queryset_padded,
+            'support': supportset_padded,
+            'support_labels': supportset_labels_ids
+        }
+
+    def batch_encode(self, sentence_list: list, max_length=None):
+        if not max_length:
+            max_length = self.MAX_LEN
+        return self.tokenizer.batch_encode_plus(
+            sentence_list, padding=True, max_length=max_length)
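To make the preprocessor contract concrete: `__call__` returns padded token-id matrices for the queries and support examples, plus integer label ids indexing `preprocessor.label_dict`. A hedged sketch of the shapes (the English texts and `model_dir` are assumptions for illustration):

    data = {
        'query_set': ['how to use a coupon'],
        'support_set': [
            {'text': 'how do I use a voucher', 'label': 'coupon'},
            {'text': 'how do I level up', 'label': 'level'},
        ],
    }
    # preprocessor = FaqQuestionAnsweringPreprocessor(model_dir)  # model_dir assumed
    # features = preprocessor(data)
    # features['query']          -> [[cls_id, tok..., sep_id, pad...]]  one row per query
    # features['support']        -> same padded layout, one row per support example
    # features['support_labels'] -> [0, 1]  indices into preprocessor.label_dict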
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index 4ef34812..52c08594 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -95,6 +95,7 @@ class NLPTasks(object):
     zero_shot_classification = 'zero-shot-classification'
     backbone = 'backbone'
     text_error_correction = 'text-error-correction'
+    faq_question_answering = 'faq-question-answering'
     conversational_text_to_sql = 'conversational-text-to-sql'
diff --git a/tests/pipelines/test_faq_question_answering.py b/tests/pipelines/test_faq_question_answering.py
new file mode 100644
index 00000000..3a87643c
--- /dev/null
+++ b/tests/pipelines/test_faq_question_answering.py
@@ -0,0 +1,85 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+import numpy as np
+
+from modelscope.hub.api import HubApi
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.models import Model
+from modelscope.models.nlp import SbertForFaqQuestionAnswering
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.nlp import FaqQuestionAnsweringPipeline
+from modelscope.preprocessors import FaqQuestionAnsweringPreprocessor
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class FaqQuestionAnsweringTest(unittest.TestCase):
+    model_id = 'damo/nlp_structbert_faq-question-answering_chinese-base'
+    param = {
+        'query_set': ['如何使用优惠券', '在哪里领券', '在哪里领券'],
+        'support_set': [{
+            'text': '卖品代金券怎么用',
+            'label': '6527856'
+        }, {
+            'text': '怎么使用优惠券',
+            'label': '6527856'
+        }, {
+            'text': '这个可以一起领吗',
+            'label': '1000012000'
+        }, {
+            'text': '付款时送的优惠券哪里领',
+            'label': '1000012000'
+        }, {
+            'text': '购物等级怎么长',
+            'label': '13421097'
+        }, {
+            'text': '购物等级二心',
+            'label': '13421097'
+        }]
+    }
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_with_direct_file_download(self):
+        cache_path = snapshot_download(self.model_id)
+        preprocessor = FaqQuestionAnsweringPreprocessor(cache_path)
+        model = SbertForFaqQuestionAnswering(cache_path)
+        model.load_checkpoint(cache_path)
+        pipeline_ins = FaqQuestionAnsweringPipeline(
+            model, preprocessor=preprocessor)
+        result = pipeline_ins(self.param)
+        print(result)
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_model_from_modelhub(self):
+        model = Model.from_pretrained(self.model_id)
+        preprocessor = FaqQuestionAnsweringPreprocessor(model.model_dir)
+        pipeline_ins = pipeline(
+            task=Tasks.faq_question_answering,
+            model=model,
+            preprocessor=preprocessor)
+        result = pipeline_ins(self.param)
+        print(result)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_name(self):
+        pipeline_ins = pipeline(
+            task=Tasks.faq_question_answering, model=self.model_id)
+        result = pipeline_ins(self.param)
+        print(result)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_with_default_model(self):
+        pipeline_ins = pipeline(task=Tasks.faq_question_answering)
+        print(pipeline_ins(self.param, max_seq_length=20))
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_sentence_embedding(self):
+        pipeline_ins = pipeline(task=Tasks.faq_question_answering)
+        sentence_vec = pipeline_ins.get_sentence_embedding(
+            ['今天星期六', '明天星期几明天星期几'])
+        print(np.shape(sentence_vec))
+
+
+if __name__ == '__main__':
+    unittest.main()
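For reviewers trying the patch locally, a minimal usage sketch distilled from the tests above (downloading the model from the hub is assumed to succeed):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    faq = pipeline(
        task=Tasks.faq_question_answering,
        model='damo/nlp_structbert_faq-question-answering_chinese-base')

    result = faq({
        'query_set': ['如何使用优惠券'],
        'support_set': [
            {'text': '怎么使用优惠券', 'label': '6527856'},
            {'text': '在哪里领券', 'label': '1000012000'},
        ],
    })
    print(result['output'])  # per query: labels ranked by score, highest first

    # Sentence embeddings from the same backbone:
    vecs = faq.get_sentence_embedding(['今天星期六'])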