Browse Source

[to #42322933]新增FAQ问答模型

Maas新增FAQ问答模型
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9797053
master
tanfan.zjh yingda.chen 3 years ago
parent
commit
dc45fce542
13 changed files with 526 additions and 8 deletions
  1. +2
    -2
      modelscope/hub/errors.py
  2. +2
    -0
      modelscope/metainfo.py
  3. +2
    -1
      modelscope/models/cv/skin_retouching/retinaface/box_utils.py
  4. +2
    -0
      modelscope/models/nlp/__init__.py
  5. +249
    -0
      modelscope/models/nlp/sbert_for_faq_question_answering.py
  6. +11
    -0
      modelscope/outputs.py
  7. +3
    -1
      modelscope/pipelines/builder.py
  8. +3
    -1
      modelscope/pipelines/nlp/__init__.py
  9. +76
    -0
      modelscope/pipelines/nlp/faq_question_answering_pipeline.py
  10. +4
    -2
      modelscope/preprocessors/__init__.py
  11. +86
    -1
      modelscope/preprocessors/nlp.py
  12. +1
    -0
      modelscope/utils/constant.py
  13. +85
    -0
      tests/pipelines/test_faq_question_answering.py

+ 2
- 2
modelscope/hub/errors.py View File

@@ -49,8 +49,8 @@ def handle_http_response(response, logger, cookies, model_id):
except HTTPError:
if cookies is None: # code in [403] and
logger.error(
f'Authentication token does not exist, failed to access model {model_id} which may not exist or may be private. \
Please login first.')
f'Authentication token does not exist, failed to access model {model_id} which may not exist or may be \
private. Please login first.')
raise




+ 2
- 0
modelscope/metainfo.py View File

@@ -138,6 +138,7 @@ class Pipelines(object):
dialog_state_tracking = 'dialog-state-tracking'
zero_shot_classification = 'zero-shot-classification'
text_error_correction = 'text-error-correction'
faq_question_answering = 'faq-question-answering'
conversational_text_to_sql = 'conversational-text-to-sql'

# audio tasks
@@ -220,6 +221,7 @@ class Preprocessors(object):
text_error_correction = 'text-error-correction'
word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor'
fill_mask = 'fill-mask'
faq_question_answering_preprocessor = 'faq-question-answering-preprocessor'
conversational_text_to_sql = 'conversational-text-to-sql'

# audio preprocessor


+ 2
- 1
modelscope/models/cv/skin_retouching/retinaface/box_utils.py View File

@@ -6,7 +6,8 @@ import torch


def point_form(boxes: torch.Tensor) -> torch.Tensor:
"""Convert prior_boxes to (x_min, y_min, x_max, y_max) representation for comparison to point form ground truth data.
"""Convert prior_boxes to (x_min, y_min, x_max, y_max) representation for comparison to point form \
ground truth data.

Args:
boxes: center-size default boxes from priorbox layers.


+ 2
- 0
modelscope/models/nlp/__init__.py View File

@@ -21,6 +21,7 @@ if TYPE_CHECKING:
from .task_models.task_model import SingleBackboneTaskModelBase
from .bart_for_text_error_correction import BartForTextErrorCorrection
from .gpt3 import GPT3ForTextGeneration
from .sbert_for_faq_question_answering import SbertForFaqQuestionAnswering

else:
_import_structure = {
@@ -44,6 +45,7 @@ else:
'task_model': ['SingleBackboneTaskModelBase'],
'bart_for_text_error_correction': ['BartForTextErrorCorrection'],
'gpt3': ['GPT3ForTextGeneration'],
'sbert_for_faq_question_answering': ['SbertForFaqQuestionAnswering']
}

import sys


+ 249
- 0
modelscope/models/nlp/sbert_for_faq_question_answering.py View File

@@ -0,0 +1,249 @@
import math
import os
from collections import namedtuple
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from modelscope.metainfo import Models
from modelscope.models.builder import MODELS
from modelscope.models.nlp.structbert import SbertConfig, SbertModel
from modelscope.models.nlp.task_models.task_model import BaseTaskModel
from modelscope.utils.config import Config, ConfigFields
from modelscope.utils.constant import ModelFile, Tasks

__all__ = ['SbertForFaqQuestionAnswering']


class SbertForFaqQuestionAnsweringBase(BaseTaskModel):
"""base class for faq models
"""

def __init__(self, model_dir, *args, **kwargs):
super(SbertForFaqQuestionAnsweringBase,
self).__init__(model_dir, *args, **kwargs)

backbone_cfg = SbertConfig.from_pretrained(model_dir)
self.bert = SbertModel(backbone_cfg)

model_config = Config.from_file(
os.path.join(model_dir,
ModelFile.CONFIGURATION)).get(ConfigFields.model, {})

metric = model_config.get('metric', 'cosine')
pooling_method = model_config.get('pooling', 'avg')

Arg = namedtuple('args', [
'metrics', 'proj_hidden_size', 'hidden_size', 'dropout', 'pooling'
])
args = Arg(
metrics=metric,
proj_hidden_size=self.bert.config.hidden_size,
hidden_size=self.bert.config.hidden_size,
dropout=0.0,
pooling=pooling_method)

self.metrics_layer = MetricsLayer(args)
self.pooling = PoolingLayer(args)

def _get_onehot_labels(self, labels, support_size, num_cls):
labels_ = labels.view(support_size, 1)
target_oh = torch.zeros(support_size, num_cls).to(labels)
target_oh.scatter_(dim=1, index=labels_, value=1)
return target_oh.view(support_size, num_cls).float()

def forward_sentence_embedding(self, inputs: Dict[str, Tensor]):
input_ids = inputs['input_ids']
input_mask = inputs['attention_mask']
if not isinstance(input_ids, Tensor):
input_ids = torch.IntTensor(input_ids)
if not isinstance(input_mask, Tensor):
input_mask = torch.IntTensor(input_mask)
rst = self.bert(input_ids, input_mask)
last_hidden_states = rst.last_hidden_state
if len(input_mask.shape) == 2:
input_mask = input_mask.unsqueeze(-1)
pooled_representation = self.pooling(last_hidden_states, input_mask)
return pooled_representation


@MODELS.register_module(
Tasks.faq_question_answering, module_name=Models.structbert)
class SbertForFaqQuestionAnswering(SbertForFaqQuestionAnsweringBase):
_backbone_prefix = ''

def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
assert not self.training
query = input['query']
support = input['support']
if isinstance(query, list):
query = torch.stack(query)
if isinstance(support, list):
support = torch.stack(support)
n_query = query.shape[0]
n_support = support.shape[0]
query_mask = torch.ne(query, 0).view([n_query, -1])
support_mask = torch.ne(support, 0).view([n_support, -1])

support_labels = input['support_labels']
num_cls = torch.max(support_labels) + 1
onehot_labels = self._get_onehot_labels(support_labels, n_support,
num_cls)

input_ids = torch.cat([query, support])
input_mask = torch.cat([query_mask, support_mask], dim=0)
pooled_representation = self.forward_sentence_embedding({
'input_ids':
input_ids,
'attention_mask':
input_mask
})
z_query = pooled_representation[:n_query]
z_support = pooled_representation[n_query:]
cls_n_support = torch.sum(onehot_labels, dim=-2) + 1e-5
protos = torch.matmul(onehot_labels.transpose(0, 1),
z_support) / cls_n_support.unsqueeze(-1)
scores = self.metrics_layer(z_query, protos).view([n_query, num_cls])
if self.metrics_layer.name == 'relation':
scores = torch.sigmoid(scores)
return {'scores': scores}


activations = {
'relu': F.relu,
'tanh': torch.tanh,
'linear': lambda x: x,
}

activation_coeffs = {
'relu': math.sqrt(2),
'tanh': 5 / 3,
'linear': 1.,
}


class LinearProjection(nn.Module):

def __init__(self,
in_features,
out_features,
activation='linear',
bias=True):
super().__init__()
self.activation = activations[activation]
activation_coeff = activation_coeffs[activation]
linear = nn.Linear(in_features, out_features, bias=bias)
nn.init.normal_(
linear.weight, std=math.sqrt(1. / in_features) * activation_coeff)
if bias:
nn.init.zeros_(linear.bias)
self.model = nn.utils.weight_norm(linear)

def forward(self, x):
return self.activation(self.model(x))


class RelationModule(nn.Module):

def __init__(self, args):
super(RelationModule, self).__init__()
input_size = args.proj_hidden_size * 4
self.prediction = torch.nn.Sequential(
LinearProjection(
input_size, args.proj_hidden_size * 4, activation='relu'),
nn.Dropout(args.dropout),
LinearProjection(args.proj_hidden_size * 4, 1))

def forward(self, query, protos):
n_cls = protos.shape[0]
n_query = query.shape[0]
protos = protos.unsqueeze(0).repeat(n_query, 1, 1)
query = query.unsqueeze(1).repeat(1, n_cls, 1)
input_feat = torch.cat(
[query, protos, (protos - query).abs(), query * protos], dim=-1)
dists = self.prediction(input_feat) # [bsz,n_query,n_cls,1]
return dists.squeeze(-1)


class MetricsLayer(nn.Module):

def __init__(self, args):
super(MetricsLayer, self).__init__()
self.args = args
assert args.metrics in ('relation', 'cosine')
if args.metrics == 'relation':
self.relation_net = RelationModule(args)

@property
def name(self):
return self.args.metrics

def forward(self, query, protos):
""" query : [bsz, n_query, dim]
support : [bsz, n_query, n_cls, dim] | [bsz, n_cls, dim]
"""
if self.args.metrics == 'cosine':
supervised_dists = self.cosine_similarity(query, protos)
if self.training:
supervised_dists *= 5
elif self.args.metrics in ('relation', ):
supervised_dists = self.relation_net(query, protos)
else:
raise NotImplementedError
return supervised_dists

def cosine_similarity(self, x, y):
# x=[bsz, n_query, dim]
# y=[bsz, n_cls, dim]
n_query = x.shape[0]
n_cls = y.shape[0]
dim = x.shape[-1]
x = x.unsqueeze(1).expand([n_query, n_cls, dim])
y = y.unsqueeze(0).expand([n_query, n_cls, dim])
return F.cosine_similarity(x, y, -1)


class AveragePooling(nn.Module):

def forward(self, x, mask, dim=1):
return torch.sum(
x * mask.float(), dim=dim) / torch.sum(
mask.float(), dim=dim)


class AttnPooling(nn.Module):

def __init__(self, input_size, hidden_size=None, output_size=None):
super().__init__()
self.input_proj = nn.Sequential(
LinearProjection(input_size, hidden_size), nn.Tanh(),
LinearProjection(hidden_size, 1, bias=False))
self.output_proj = LinearProjection(
input_size, output_size) if output_size else lambda x: x

def forward(self, x, mask):
score = self.input_proj(x)
score = score * mask.float() + -1e4 * (1. - mask.float())
score = F.softmax(score, dim=1)
features = self.output_proj(x)
return torch.matmul(score.transpose(1, 2), features).squeeze(1)


class PoolingLayer(nn.Module):

def __init__(self, args):
super(PoolingLayer, self).__init__()
if args.pooling == 'attn':
self.pooling = AttnPooling(args.proj_hidden_size,
args.proj_hidden_size,
args.proj_hidden_size)
elif args.pooling == 'avg':
self.pooling = AveragePooling()
else:
raise NotImplementedError(args.pooling)

def forward(self, x, mask):
return self.pooling(x, mask)

+ 11
- 0
modelscope/outputs.py View File

@@ -7,6 +7,7 @@ class OutputKeys(object):
LOSS = 'loss'
LOGITS = 'logits'
SCORES = 'scores'
SCORE = 'score'
LABEL = 'label'
LABELS = 'labels'
INPUT_IDS = 'input_ids'
@@ -504,6 +505,16 @@ TASK_OUTPUTS = {
# }
Tasks.visual_entailment: [OutputKeys.SCORES, OutputKeys.LABELS],

# {
# 'output': [
# [{'label': '6527856', 'score': 0.9942756295204163}, {'label': '1000012000', 'score': 0.0379515215754509},
# {'label': '13421097', 'score': 2.2825044965202324e-08}],
# [{'label': '1000012000', 'score': 0.910681426525116}, {'label': '6527856', 'score': 0.0005046309670433402},
# {'label': '13421097', 'score': 2.75914817393641e-06}],
# [{'label': '1000012000', 'score': 0.910681426525116}, {'label': '6527856', 'score': 0.0005046309670433402},
# {'label': '13421097', 'score': 2.75914817393641e-06}]]
# }
Tasks.faq_question_answering: [OutputKeys.OUTPUT],
# image person reid result for single sample
# {
# "img_embedding": np.array with shape [1, D],


+ 3
- 1
modelscope/pipelines/builder.py View File

@@ -129,6 +129,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_convnextTiny_ocr-recognition-general_damo'),
Tasks.skin_retouching: (Pipelines.skin_retouching,
'damo/cv_unet_skin-retouching'),
Tasks.faq_question_answering:
(Pipelines.faq_question_answering,
'damo/nlp_structbert_faq-question-answering_chinese-base'),
Tasks.crowd_counting: (Pipelines.crowd_counting,
'damo/cv_hrnet_crowd-counting_dcanet'),
Tasks.video_single_object_tracking:
@@ -218,7 +221,6 @@ def pipeline(task: str = None,
f'model should be either None, str, List[str], Model, or List[Model], but got {type(model)}'

model = normalize_model_input(model, model_revision)

if pipeline_name is None:
# get default pipeline for this task
if isinstance(model, str) \


+ 3
- 1
modelscope/pipelines/nlp/__init__.py View File

@@ -20,6 +20,7 @@ if TYPE_CHECKING:
from .summarization_pipeline import SummarizationPipeline
from .text_classification_pipeline import TextClassificationPipeline
from .text_error_correction_pipeline import TextErrorCorrectionPipeline
from .faq_question_answering_pipeline import FaqQuestionAnsweringPipeline

else:
_import_structure = {
@@ -44,7 +45,8 @@ else:
'translation_pipeline': ['TranslationPipeline'],
'summarization_pipeline': ['SummarizationPipeline'],
'text_classification_pipeline': ['TextClassificationPipeline'],
'text_error_correction_pipeline': ['TextErrorCorrectionPipeline']
'text_error_correction_pipeline': ['TextErrorCorrectionPipeline'],
'faq_question_answering_pipeline': ['FaqQuestionAnsweringPipeline']
}

import sys


+ 76
- 0
modelscope/pipelines/nlp/faq_question_answering_pipeline.py View File

@@ -0,0 +1,76 @@
from typing import Any, Dict, Union

import torch

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.models.nlp import SbertForFaqQuestionAnswering
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import FaqQuestionAnsweringPreprocessor
from modelscope.utils.constant import Tasks

__all__ = ['FaqQuestionAnsweringPipeline']


@PIPELINES.register_module(
Tasks.faq_question_answering, module_name=Pipelines.faq_question_answering)
class FaqQuestionAnsweringPipeline(Pipeline):

def __init__(self,
model: Union[str, SbertForFaqQuestionAnswering],
preprocessor: FaqQuestionAnsweringPreprocessor = None,
**kwargs):
model = model if isinstance(
model,
SbertForFaqQuestionAnswering) else Model.from_pretrained(model)
model.eval()
if preprocessor is None:
preprocessor = FaqQuestionAnsweringPreprocessor(
model.model_dir, **kwargs)
self.preprocessor = preprocessor
super(FaqQuestionAnsweringPipeline, self).__init__(
model=model, preprocessor=preprocessor, **kwargs)

def _sanitize_parameters(self, **pipeline_parameters):
return pipeline_parameters, pipeline_parameters, pipeline_parameters

def get_sentence_embedding(self, inputs, max_len=None):
inputs = self.preprocessor.batch_encode(inputs, max_length=max_len)
sentence_vecs = self.model.forward_sentence_embedding(inputs)
sentence_vecs = sentence_vecs.detach().tolist()
return sentence_vecs

def forward(self, inputs: [list, Dict[str, Any]],
**forward_params) -> Dict[str, Any]:
with torch.no_grad():
return self.model(inputs)

def postprocess(self, inputs: [list, Dict[str, Any]],
**postprocess_params) -> Dict[str, Any]:
scores = inputs['scores']
labels = []
for item in scores:
tmplabels = [
self.preprocessor.get_label(label_id)
for label_id in range(len(item))
]
labels.append(tmplabels)

predictions = []
for tmp_scores, tmp_labels in zip(scores.tolist(), labels):
prediction = []
for score, label in zip(tmp_scores, tmp_labels):
prediction.append({
OutputKeys.LABEL: label,
OutputKeys.SCORE: score
})
predictions.append(
list(
sorted(
prediction,
key=lambda d: d[OutputKeys.SCORE],
reverse=True)))

return {OutputKeys.OUTPUT: predictions}

+ 4
- 2
modelscope/preprocessors/__init__.py View File

@@ -21,7 +21,8 @@ if TYPE_CHECKING:
SingleSentenceClassificationPreprocessor,
PairSentenceClassificationPreprocessor,
FillMaskPreprocessor, ZeroShotClassificationPreprocessor,
NERPreprocessor, TextErrorCorrectionPreprocessor)
NERPreprocessor, TextErrorCorrectionPreprocessor,
FaqQuestionAnsweringPreprocessor)
from .space import (DialogIntentPredictionPreprocessor,
DialogModelingPreprocessor,
DialogStateTrackingPreprocessor)
@@ -48,7 +49,8 @@ else:
'SingleSentenceClassificationPreprocessor',
'PairSentenceClassificationPreprocessor', 'FillMaskPreprocessor',
'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
'TextErrorCorrectionPreprocessor'
'TextErrorCorrectionPreprocessor',
'FaqQuestionAnsweringPreprocessor'
],
'space': [
'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor',


+ 86
- 1
modelscope/preprocessors/nlp.py View File

@@ -5,10 +5,12 @@ import uuid
from typing import Any, Dict, Iterable, Optional, Tuple, Union

import numpy as np
import torch
from transformers import AutoTokenizer

from modelscope.metainfo import Models, Preprocessors
from modelscope.outputs import OutputKeys
from modelscope.utils.config import ConfigFields
from modelscope.utils.constant import Fields, InputFields, ModeKeys
from modelscope.utils.hub import get_model_type, parse_label_mapping
from modelscope.utils.type_assert import type_assert
@@ -21,7 +23,7 @@ __all__ = [
'PairSentenceClassificationPreprocessor',
'SingleSentenceClassificationPreprocessor', 'FillMaskPreprocessor',
'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
'TextErrorCorrectionPreprocessor'
'TextErrorCorrectionPreprocessor', 'FaqQuestionAnsweringPreprocessor'
]


@@ -645,3 +647,86 @@ class TextErrorCorrectionPreprocessor(Preprocessor):
sample = dict()
sample['net_input'] = {'src_tokens': inputs, 'src_lengths': lengths}
return sample


@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.faq_question_answering_preprocessor)
class FaqQuestionAnsweringPreprocessor(Preprocessor):

def __init__(self, model_dir: str, *args, **kwargs):
super(FaqQuestionAnsweringPreprocessor, self).__init__(
model_dir, pair=False, mode=ModeKeys.INFERENCE, **kwargs)
import os
from transformers import BertTokenizer

from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile
self.tokenizer = BertTokenizer.from_pretrained(model_dir)
preprocessor_config = Config.from_file(
os.path.join(model_dir, ModelFile.CONFIGURATION)).get(
ConfigFields.preprocessor, {})
self.MAX_LEN = preprocessor_config.get('max_seq_length', 50)
self.label_dict = None

def pad(self, samples, max_len):
result = []
for sample in samples:
pad_len = max_len - len(sample[:max_len])
result.append(sample[:max_len]
+ [self.tokenizer.pad_token_id] * pad_len)
return result

def set_label_dict(self, label_dict):
self.label_dict = label_dict

def get_label(self, label_id):
assert self.label_dict is not None and label_id < len(self.label_dict)
return self.label_dict[label_id]

def encode_plus(self, text):
return [
self.tokenizer.cls_token_id
] + self.tokenizer.convert_tokens_to_ids(
self.tokenizer.tokenize(text)) + [self.tokenizer.sep_token_id]

@type_assert(object, Dict)
def __call__(self, data: Dict[str, Any],
**preprocessor_param) -> Dict[str, Any]:
TMP_MAX_LEN = preprocessor_param.get('max_seq_length', self.MAX_LEN)
queryset = data['query_set']
if not isinstance(queryset, list):
queryset = [queryset]
supportset = data['support_set']
supportset = sorted(supportset, key=lambda d: d['label'])

queryset_tokenized = [self.encode_plus(text) for text in queryset]
supportset_tokenized = [
self.encode_plus(item['text']) for item in supportset
]

max_len = max(
[len(seq) for seq in queryset_tokenized + supportset_tokenized])
max_len = min(TMP_MAX_LEN, max_len)
queryset_padded = self.pad(queryset_tokenized, max_len)
supportset_padded = self.pad(supportset_tokenized, max_len)

supportset_labels_ori = [item['label'] for item in supportset]
label_dict = []
for label in supportset_labels_ori:
if label not in label_dict:
label_dict.append(label)
self.set_label_dict(label_dict)
supportset_labels_ids = [
label_dict.index(label) for label in supportset_labels_ori
]
return {
'query': queryset_padded,
'support': supportset_padded,
'support_labels': supportset_labels_ids
}

def batch_encode(self, sentence_list: list, max_length=None):
if not max_length:
max_length = self.MAX_LEN
return self.tokenizer.batch_encode_plus(
sentence_list, padding=True, max_length=max_length)

+ 1
- 0
modelscope/utils/constant.py View File

@@ -95,6 +95,7 @@ class NLPTasks(object):
zero_shot_classification = 'zero-shot-classification'
backbone = 'backbone'
text_error_correction = 'text-error-correction'
faq_question_answering = 'faq-question-answering'
conversational_text_to_sql = 'conversational-text-to-sql'




+ 85
- 0
tests/pipelines/test_faq_question_answering.py View File

@@ -0,0 +1,85 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import numpy as np

from modelscope.hub.api import HubApi
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import SbertForFaqQuestionAnswering
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import FaqQuestionAnsweringPipeline
from modelscope.preprocessors import FaqQuestionAnsweringPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class FaqQuestionAnsweringTest(unittest.TestCase):
model_id = 'damo/nlp_structbert_faq-question-answering_chinese-base'
param = {
'query_set': ['如何使用优惠券', '在哪里领券', '在哪里领券'],
'support_set': [{
'text': '卖品代金券怎么用',
'label': '6527856'
}, {
'text': '怎么使用优惠券',
'label': '6527856'
}, {
'text': '这个可以一起领吗',
'label': '1000012000'
}, {
'text': '付款时送的优惠券哪里领',
'label': '1000012000'
}, {
'text': '购物等级怎么长',
'label': '13421097'
}, {
'text': '购物等级二心',
'label': '13421097'
}]
}

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_direct_file_download(self):
cache_path = snapshot_download(self.model_id)
preprocessor = FaqQuestionAnsweringPreprocessor(cache_path)
model = SbertForFaqQuestionAnswering(cache_path)
model.load_checkpoint(cache_path)
pipeline_ins = FaqQuestionAnsweringPipeline(
model, preprocessor=preprocessor)
result = pipeline_ins(self.param)
print(result)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
preprocessor = FaqQuestionAnsweringPreprocessor(model.model_dir)
pipeline_ins = pipeline(
task=Tasks.faq_question_answering,
model=model,
preprocessor=preprocessor)
result = pipeline_ins(self.param)
print(result)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name(self):
pipeline_ins = pipeline(
task=Tasks.faq_question_answering, model=self.model_id)
result = pipeline_ins(self.param)
print(result)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipeline_ins = pipeline(task=Tasks.faq_question_answering)
print(pipeline_ins(self.param, max_seq_length=20))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_sentence_embedding(self):
pipeline_ins = pipeline(task=Tasks.faq_question_answering)
sentence_vec = pipeline_ins.get_sentence_embedding(
['今天星期六', '明天星期几明天星期几'])
print(np.shape(sentence_vec))


if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save