Browse Source

[to #42322933] add new Task - document segmentation

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9942858

    * [Add] add document-segmentation
master
shichen.fsc yingda.chen 3 years ago
parent
commit
2b380f0410
9 changed files with 572 additions and 0 deletions
  1. +3
    -0
      modelscope/metainfo.py
  2. +2
    -0
      modelscope/models/nlp/__init__.py
  3. +108
    -0
      modelscope/models/nlp/bert_for_document_segmentation.py
  4. +2
    -0
      modelscope/pipelines/nlp/__init__.py
  5. +175
    -0
      modelscope/pipelines/nlp/document_segmentation_pipeline.py
  6. +2
    -0
      modelscope/preprocessors/__init__.py
  7. +223
    -0
      modelscope/preprocessors/slp.py
  8. +1
    -0
      modelscope/utils/constant.py
  9. +56
    -0
      tests/pipelines/test_document_segmentation.py

+ 3
- 0
modelscope/metainfo.py View File

@@ -45,6 +45,7 @@ class Models(object):
tcrf = 'transformer-crf'
bart = 'bart'
gpt3 = 'gpt3'
bert_for_ds = 'bert-for-document-segmentation'

# audio models
sambert_hifigan = 'sambert-hifigan'
@@ -151,6 +152,7 @@ class Pipelines(object):
text_error_correction = 'text-error-correction'
faq_question_answering = 'faq-question-answering'
conversational_text_to_sql = 'conversational-text-to-sql'
document_segmentation = 'document-segmentation'

# audio tasks
sambert_hifigan_tts = 'sambert-hifigan-tts'
@@ -240,6 +242,7 @@ class Preprocessors(object):
fill_mask = 'fill-mask'
faq_question_answering_preprocessor = 'faq-question-answering-preprocessor'
conversational_text_to_sql = 'conversational-text-to-sql'
document_segmentation = 'document-segmentation'

# audio preprocessor
linear_aec_fbank = 'linear-aec-fbank'


+ 2
- 0
modelscope/models/nlp/__init__.py View File

@@ -7,6 +7,7 @@ if TYPE_CHECKING:
from .backbones import SbertModel
from .heads import SequenceClassificationHead
from .bert_for_sequence_classification import BertForSequenceClassification
from .bert_for_document_segmentation import BertForDocumentSegmentation
from .csanmt_for_translation import CsanmtForTranslation
from .masked_language import (StructBertForMaskedLM, VecoForMaskedLM,
BertForMaskedLM)
@@ -30,6 +31,7 @@ else:
'heads': ['SequenceClassificationHead'],
'csanmt_for_translation': ['CsanmtForTranslation'],
'bert_for_sequence_classification': ['BertForSequenceClassification'],
'bert_for_document_segmentation': ['BertForDocumentSegmentation'],
'masked_language':
['StructBertForMaskedLM', 'VecoForMaskedLM', 'BertForMaskedLM'],
'nncrf_for_named_entity_recognition':


+ 108
- 0
modelscope/models/nlp/bert_for_document_segmentation.py View File

@@ -0,0 +1,108 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Any, Dict

from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.bert.modeling_bert import (BertModel,
BertPreTrainedModel)

from modelscope.metainfo import Models
from modelscope.models.base import Model
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks

__all__ = ['BertForDocumentSegmentation']


@MODELS.register_module(
Tasks.document_segmentation, module_name=Models.bert_for_ds)
class BertForDocumentSegmentation(Model):

def __init__(self, model_dir: str, *args, **kwargs):
super().__init__(model_dir, *args, **kwargs)

def build_with_config(self, config):
self.bert_model = BertForDocumentSegmentationBase.from_pretrained(
self.model_dir, from_tf=False, config=config)
return self.bert_model

def forward(self, input: Dict[str, Dict]) -> Dict[str, Any]:
pass


class BertForDocumentSegmentationBase(BertPreTrainedModel):

_keys_to_ignore_on_load_unexpected = [r'pooler']

def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.sentence_pooler_type = None
self.bert = BertModel(config, add_pooling_layer=False)

classifier_dropout = config.hidden_dropout_prob
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.class_weights = None
self.init_weights()

def forward(self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
sentence_attention_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None):

return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

sequence_output = outputs[0]
if self.sentence_pooler_type is not None:
raise NotImplementedError
else:
sequence_output = self.dropout(sequence_output)

logits = self.classifier(sequence_output)

loss = None
if labels is not None:
loss_fct = CrossEntropyLoss(weight=self.class_weights)
if sentence_attention_mask is not None:
active_loss = sentence_attention_mask.view(-1) == 1
active_logits = logits.view(-1, self.num_labels)
active_labels = torch.where(
active_loss, labels.view(-1),
torch.tensor(loss_fct.ignore_index).type_as(labels))
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(
logits.view(-1, self.num_labels), labels.view(-1))

if not return_dict:
output = (logits, ) + outputs[2:]
return ((loss, ) + output) if loss is not None else output

return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

+ 2
- 0
modelscope/pipelines/nlp/__init__.py View File

@@ -8,6 +8,7 @@ if TYPE_CHECKING:
from .dialog_intent_prediction_pipeline import DialogIntentPredictionPipeline
from .dialog_modeling_pipeline import DialogModelingPipeline
from .dialog_state_tracking_pipeline import DialogStateTrackingPipeline
from .document_segmentation_pipeline import DocumentSegmentationPipeline
from .fill_mask_pipeline import FillMaskPipeline
from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline
from .pair_sentence_classification_pipeline import PairSentenceClassificationPipeline
@@ -30,6 +31,7 @@ else:
['DialogIntentPredictionPipeline'],
'dialog_modeling_pipeline': ['DialogModelingPipeline'],
'dialog_state_tracking_pipeline': ['DialogStateTrackingPipeline'],
'document_segmentation_pipeline': ['DocumentSegmentationPipeline'],
'fill_mask_pipeline': ['FillMaskPipeline'],
'single_sentence_classification_pipeline':
['SingleSentenceClassificationPipeline'],


+ 175
- 0
modelscope/pipelines/nlp/document_segmentation_pipeline.py View File

@@ -0,0 +1,175 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import re
from typing import Any, Dict, List, Union

import numpy as np
import torch
from datasets import Dataset
from transformers.models.bert.modeling_bert import BertConfig

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline, Tensor
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import DocumentSegmentationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()

__all__ = ['DocumentSegmentationPipeline']


@PIPELINES.register_module(
Tasks.document_segmentation, module_name=Pipelines.document_segmentation)
class DocumentSegmentationPipeline(Pipeline):

def __init__(self,
model: Union[Model, str],
preprocessor: DocumentSegmentationPreprocessor = None,
**kwargs):

model = model if isinstance(model,
Model) else Model.from_pretrained(model)

self.model_dir = model.model_dir
config = BertConfig.from_pretrained(model.model_dir, num_labels=2)

self.document_segmentation_model = model.build_with_config(
config=config)

if preprocessor is None:
preprocessor = DocumentSegmentationPreprocessor(
self.model_dir, config)
super().__init__(model=model, preprocessor=preprocessor, **kwargs)

self.preprocessor = preprocessor

def __call__(self, documents: Union[List[str], str]) -> Dict[str, Any]:
output = self.predict(documents)
output = self.postprocess(output)
return output

def predict(self, documents: Union[List[str], str]) -> Dict[str, Any]:
pred_samples = self.cut_documents(documents)
predict_examples = Dataset.from_dict(pred_samples)

# Predict Feature Creation
predict_dataset = self.preprocessor(predict_examples)
num_examples = len(
predict_examples[self.preprocessor.context_column_name])
num_samples = len(
predict_dataset[self.preprocessor.context_column_name])

predict_dataset.pop('segment_ids')
labels = predict_dataset.pop('labels')
sentences = predict_dataset.pop('sentences')
example_ids = predict_dataset.pop(
self.preprocessor.example_id_column_name)

with torch.no_grad():
input = {
key: torch.tensor(val)
for key, val in predict_dataset.items()
}
predictions = self.document_segmentation_model.forward(
**input).logits

predictions = np.argmax(predictions, axis=2)
assert len(sentences) == len(
predictions), 'sample {} infer_sample {} prediction {}'.format(
num_samples, len(sentences), len(predictions))
# Remove ignored index (special tokens)
true_predictions = [
[
self.preprocessor.label_list[p]
for (p, l) in zip(prediction, label) if l != -100 # noqa *
] for prediction, label in zip(predictions, labels)
]

true_labels = [
[
self.preprocessor.label_list[l]
for (p, l) in zip(prediction, label) if l != -100 # noqa *
] for prediction, label in zip(predictions, labels)
]

# Save predictions
out = []
for i in range(num_examples):
out.append({'sentences': [], 'labels': [], 'predictions': []})

for prediction, sentence_list, label, example_id in zip(
true_predictions, sentences, true_labels, example_ids):
if len(label) < len(sentence_list):
label.append('B-EOP')
prediction.append('B-EOP')
assert len(sentence_list) == len(prediction), '{} {}'.format(
len(sentence_list), len(prediction))
assert len(sentence_list) == len(label), '{} {}'.format(
len(sentence_list), len(label))
out[example_id]['sentences'].extend(sentence_list)
out[example_id]['labels'].extend(label)
out[example_id]['predictions'].extend(prediction)

return out

def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""process the prediction results

Args:
inputs (Dict[str, Any]): _description_

Returns:
Dict[str, str]: the prediction results
"""
result = []
list_count = len(inputs)
for num in range(list_count):
res = []
for s, p in zip(inputs[num]['sentences'],
inputs[num]['predictions']):
s = s.strip()
if p == 'B-EOP':
s = ''.join([s, '\n\t'])
res.append(s)

document = ('\t' + ''.join(res))
result.append(document)

if list_count == 1:
return {OutputKeys.TEXT: result[0]}
else:
return {OutputKeys.TEXT: result}

def cut_documents(self, para: Union[List[str], str]):
document_list = para
if isinstance(para, str):
document_list = [para]
sentences = []
labels = []
example_id = []
id = 0
for document in document_list:
sentence = self.cut_sentence(document)
label = ['O'] * (len(sentence) - 1) + ['B-EOP']
sentences.append(sentence)
labels.append(label)
example_id.append(id)
id += 1

return {
'example_id': example_id,
'sentences': sentences,
'labels': labels
}

def cut_sentence(self, para):
para = re.sub(r'([。!.!?\?])([^”’])', r'\1\n\2', para) # noqa *
para = re.sub(r'(\.{6})([^”’])', r'\1\n\2', para) # noqa *
para = re.sub(r'(\…{2})([^”’])', r'\1\n\2', para) # noqa *
para = re.sub(r'([。!?\?][”’])([^,。!?\?])', r'\1\n\2', para) # noqa *
para = para.rstrip()
return [_ for _ in para.split('\n') if _]

+ 2
- 0
modelscope/preprocessors/__init__.py View File

@@ -23,6 +23,7 @@ if TYPE_CHECKING:
FillMaskPreprocessor, ZeroShotClassificationPreprocessor,
NERPreprocessor, TextErrorCorrectionPreprocessor,
FaqQuestionAnsweringPreprocessor)
from .slp import DocumentSegmentationPreprocessor
from .space import (DialogIntentPredictionPreprocessor,
DialogModelingPreprocessor,
DialogStateTrackingPreprocessor)
@@ -52,6 +53,7 @@ else:
'TextErrorCorrectionPreprocessor',
'FaqQuestionAnsweringPreprocessor'
],
'slp': ['DocumentSegmentationPreprocessor'],
'space': [
'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor',
'DialogStateTrackingPreprocessor', 'InputFeatures'


+ 223
- 0
modelscope/preprocessors/slp.py View File

@@ -0,0 +1,223 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Any, Dict

from transformers import BertTokenizerFast

from modelscope.metainfo import Preprocessors
from modelscope.utils.constant import Fields
from modelscope.utils.hub import get_model_type, parse_label_mapping
from modelscope.utils.type_assert import type_assert
from .base import Preprocessor
from .builder import PREPROCESSORS

__all__ = ['DocumentSegmentationPreprocessor']


@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.document_segmentation)
class DocumentSegmentationPreprocessor(Preprocessor):

def __init__(self, model_dir: str, config, *args, **kwargs):
"""preprocess the data

Args:
model_dir (str): model path
"""

super().__init__(*args, **kwargs)

self.tokenizer = BertTokenizerFast.from_pretrained(
model_dir,
use_fast=True,
)
self.question_column_name = 'labels'
self.context_column_name = 'sentences'
self.example_id_column_name = 'example_id'
self.label_to_id = {'B-EOP': 0, 'O': 1}
self.target_specical_ids = set()
self.target_specical_ids.add(self.tokenizer.eos_token_id)
self.max_seq_length = config.max_position_embeddings
self.label_list = ['B-EOP', 'O']

def __call__(self, examples) -> Dict[str, Any]:
questions = examples[self.question_column_name]
contexts = examples[self.context_column_name]
example_ids = examples[self.example_id_column_name]
num_examples = len(questions)

sentences = []
for sentence_list in contexts:
sentence_list = [_ + '[EOS]' for _ in sentence_list]
sentences.append(sentence_list)

try:
tokenized_examples = self.tokenizer(
sentences,
is_split_into_words=True,
add_special_tokens=False,
return_token_type_ids=True,
return_attention_mask=True,
)
except Exception as e:
print(str(e))
return {}

segment_ids = []
token_seq_labels = []
for example_index in range(num_examples):
example_input_ids = tokenized_examples['input_ids'][example_index]
example_labels = questions[example_index]
example_labels = [
self.label_to_id[_] if _ in self.label_to_id else -100
for _ in example_labels
]
example_token_labels = []
segment_id = []
cur_seg_id = 1
for token_index in range(len(example_input_ids)):
if example_input_ids[token_index] in self.target_specical_ids:
example_token_labels.append(example_labels[cur_seg_id - 1])
segment_id.append(cur_seg_id)
cur_seg_id += 1
else:
example_token_labels.append(-100)
segment_id.append(cur_seg_id)

segment_ids.append(segment_id)
token_seq_labels.append(example_token_labels)

tokenized_examples['segment_ids'] = segment_ids
tokenized_examples['token_seq_labels'] = token_seq_labels

new_segment_ids = []
new_token_seq_labels = []
new_input_ids = []
new_token_type_ids = []
new_attention_mask = []
new_example_ids = []
new_sentences = []

for example_index in range(num_examples):
example_input_ids = tokenized_examples['input_ids'][example_index]
example_token_type_ids = tokenized_examples['token_type_ids'][
example_index]
example_attention_mask = tokenized_examples['attention_mask'][
example_index]
example_segment_ids = tokenized_examples['segment_ids'][
example_index]
example_token_seq_labels = tokenized_examples['token_seq_labels'][
example_index]
example_sentences = contexts[example_index]
example_id = example_ids[example_index]
example_total_num_sentences = len(questions[example_index])
example_total_num_tokens = len(
tokenized_examples['input_ids'][example_index])
accumulate_length = [
i for i, x in enumerate(tokenized_examples['input_ids']
[example_index])
if x == self.tokenizer.eos_token_id
]
samples_boundary = []
left_index = 0
sent_left_index = 0
sent_i = 0

# for sent_i, length in enumerate(accumulate_length):
while sent_i < len(accumulate_length):
length = accumulate_length[sent_i]
right_index = length + 1
sent_right_index = sent_i + 1
if right_index - left_index >= self.max_seq_length - 1 or right_index == example_total_num_tokens:
samples_boundary.append([left_index, right_index])

sample_input_ids = [
self.tokenizer.cls_token_id
] + example_input_ids[left_index:right_index]
sample_input_ids = sample_input_ids[:self.max_seq_length]

sample_token_type_ids = [
0
] + example_token_type_ids[left_index:right_index]
sample_token_type_ids = sample_token_type_ids[:self.
max_seq_length]

sample_attention_mask = [
1
] + example_attention_mask[left_index:right_index]
sample_attention_mask = sample_attention_mask[:self.
max_seq_length]

sample_segment_ids = [
0
] + example_segment_ids[left_index:right_index]
sample_segment_ids = sample_segment_ids[:self.
max_seq_length]

sample_token_seq_labels = [
-100
] + example_token_seq_labels[left_index:right_index]
sample_token_seq_labels = sample_token_seq_labels[:self.
max_seq_length]

if sent_right_index - 1 == sent_left_index:
left_index = right_index
sample_input_ids[-1] = self.tokenizer.eos_token_id
sample_token_seq_labels[-1] = -100
else:
left_index = accumulate_length[sent_i - 1] + 1
if sample_token_seq_labels[-1] != -100:
sample_token_seq_labels[-1] = -100

if sent_right_index - 1 == sent_left_index or right_index == example_total_num_tokens:
sample_sentences = example_sentences[
sent_left_index:sent_right_index]
sent_left_index = sent_right_index
sent_i += 1
else:
sample_sentences = example_sentences[
sent_left_index:sent_right_index - 1]
sent_left_index = sent_right_index - 1

if (len([_ for _ in sample_token_seq_labels if _ != -100
])) != len(sample_sentences) - 1 and (len([
_
for _ in sample_token_seq_labels if _ != -100
])) != len(sample_sentences):
tmp = []
for w_i, w, l in zip(
sample_input_ids,
self.tokenizer.decode(sample_input_ids).split(
' '), sample_token_seq_labels):
tmp.append((w_i, w, l))
while len(sample_input_ids) < self.max_seq_length:
sample_input_ids.append(self.tokenizer.pad_token_id)
sample_token_type_ids.append(0)
sample_attention_mask.append(0)
sample_segment_ids.append(example_total_num_sentences
+ 1)
sample_token_seq_labels.append(-100)

new_input_ids.append(sample_input_ids)
new_token_type_ids.append(sample_token_type_ids)
new_attention_mask.append(sample_attention_mask)
new_segment_ids.append(sample_segment_ids)
new_token_seq_labels.append(sample_token_seq_labels)
new_example_ids.append(example_id)
new_sentences.append(sample_sentences)
else:
sent_i += 1
continue

output_samples = {}

output_samples['input_ids'] = new_input_ids
output_samples['token_type_ids'] = new_token_type_ids
output_samples['attention_mask'] = new_attention_mask

output_samples['segment_ids'] = new_segment_ids
output_samples['example_id'] = new_example_ids
output_samples['labels'] = new_token_seq_labels
output_samples['sentences'] = new_sentences

return output_samples

+ 1
- 0
modelscope/utils/constant.py View File

@@ -98,6 +98,7 @@ class NLPTasks(object):
text_error_correction = 'text-error-correction'
faq_question_answering = 'faq-question-answering'
conversational_text_to_sql = 'conversational-text-to-sql'
document_segmentation = 'document-segmentation'


class AudioTasks(object):


+ 56
- 0
tests/pipelines/test_document_segmentation.py View File

@@ -0,0 +1,56 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import unittest
from typing import Any, Dict

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level

logger = get_logger()


class DocumentSegmentationTest(unittest.TestCase):

model_id = 'damo/nlp_bert_document-segmentation_chinese-base'
eng_model_id = 'damo/nlp_bert_document-segmentation_english-base'
sentences = '近年来,随着端到端语音识别的流行,基于Transformer结构的语音识别系统逐渐成为了主流。然而,由于Transformer是一种自回归模型,需要逐个生成目标文字,计算复杂度随着目标文字数量线性增加,限制了其在工业生产中的应用。针对Transoformer模型自回归生成文字的低计算效率缺陷,学术界提出了非自回归模型来并行的输出目标文字。根据生成目标文字时,迭代轮数,非自回归模型分为:多轮迭代式与单轮迭代非自回归模型。其中实用的是基于单轮迭代的非自回归模型。对于单轮非自回归模型,现有工作往往聚焦于如何更加准确的预测目标文字个数,如CTC-enhanced采用CTC预测输出文字个数,尽管如此,考虑到现实应用中,语速、口音、静音以及噪声等因素的影响,如何准确的预测目标文字个数以及抽取目标文字对应的声学隐变量仍然是一个比较大的挑战;另外一方面,我们通过对比自回归模型与单轮非自回归模型在工业大数据上的错误类型(如下图所示,AR与vanilla NAR),发现,相比于自回归模型,非自回归模型,在预测目标文字个数方面差距较小,但是替换错误显著的增加,我们认为这是由于单轮非自回归模型中条件独立假设导致的语义信息丢失。于此同时,目前非自回归模型主要停留在学术验证阶段,还没有工业大数据上的相关实验与结论。' # noqa *
sentences_1 = '移动端语音唤醒模型,检测关键词为“小云小云”。模型主体为4层FSMN结构,使用CTC训练准则,参数量750K,适用于移动端设备运行。模型输入为Fbank特征,输出为基于char建模的中文全集token预测,测试工具根据每一帧的预测数据进行后处理得到输入音频的实时检测结果。模型训练采用“basetrain + finetune”的模式,basetrain过程使用大量内部移动端数据,在此基础上,使用1万条设备端录制安静场景“小云小云”数据进行微调,得到最终面向业务的模型。后续用户可在basetrain模型基础上,使用其他关键词数据进行微调,得到新的语音唤醒模型,但暂时未开放模型finetune功能。' # noqa *
eng_sentences = 'The Saint Alexander Nevsky Church was established in 1936 by Archbishop Vitaly (Maximenko) () on a tract of land donated by Yulia Martinovna Plavskaya.The initial chapel, dedicated to the memory of the great prince St. Alexander Nevsky (1220–1263), was blessed in May, 1936.The church building was subsequently expanded three times.In 1987, ground was cleared for the construction of the new church and on September 12, 1989, on the Feast Day of St. Alexander Nevsky, the cornerstone was laid and the relics of St. Herman of Alaska placed in the foundation.The imposing edifice, completed in 1997, is the work of Nikolaus Karsanov, architect and Protopresbyter Valery Lukianov, engineer.Funds were raised through donations.The Great blessing of the cathedral took place on October 18, 1997 with seven bishops, headed by Metropolitan Vitaly Ustinov, and 36 priests and deacons officiating, some 800 faithful attended the festivity.The old church was rededicated to Our Lady of Tikhvin.Metropolitan Hilarion (Kapral) announced, that cathedral will officially become the episcopal See of the Ruling Bishop of the Eastern American Diocese and the administrative center of the Diocese on September 12, 2014.At present the parish serves the spiritual needs of 300 members.The parochial school instructs over 90 boys and girls in religion, Russian language and history.The school meets every Saturday.The choir is directed by Andrew Burbelo.The sisterhood attends to the needs of the church and a church council acts in the administration of the community.The cathedral is decorated by frescoes in the Byzantine style.The iconography project was fulfilled by Father Andrew Erastov and his students from 1995 until 2001.' # noqa *

def run_pipeline(self, model_id: str, documents: str) -> Dict[str, Any]:
p = pipeline(task=Tasks.document_segmentation, model=model_id)

result = p(documents=documents)

return result

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_document(self):
logger.info('Run document segmentation with one document ...')

result = self.run_pipeline(
model_id=self.model_id, documents=self.sentences)
print(result[OutputKeys.TEXT])

result = self.run_pipeline(
model_id=self.eng_model_id, documents=self.eng_sentences)
print(result[OutputKeys.TEXT])

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_documents(self):
logger.info('Run document segmentation with many documents ...')

result = self.run_pipeline(
model_id=self.model_id,
documents=[self.sentences, self.sentences_1])

documents_list = result[OutputKeys.TEXT]
for document in documents_list:
print(document)


if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save