@@ -2,81 +2,218 @@
BertEmbedding的各种用法
==============================

fastNLP的BertEmbedding以pytorch-transformer.BertModel的代码为基础,是一个使用BERT对words进行编码的Embedding。
Bert自从在 `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding <https://arxiv.org/abs/1810.04805>`_
中被提出后,因其性能卓越受到了极大的关注。在这里我们展示如何在fastNLP中使用Bert进行各类任务,其中中文Bert使用的模型权重来自
`中文Bert预训练 <https://github.com/ymcui/Chinese-BERT-wwm>`_ 。

使用BertEmbedding和fastNLP.models.bert里面的模型,可以搭建应用于五种下游任务的BERT模型。

预训练好的Embedding参数及数据集的介绍和自动下载功能见 :doc:`/tutorials/tutorial_3_embedding` 和
:doc:`/tutorials/tutorial_4_load_dataset` 。

----------------------------------
中文任务
----------------------------------

下面我们将介绍如何使用Bert来进行文本分类、中文命名实体识别、文本匹配和中文问答。

1. 使用Bert进行文本分类
----------------------------------

文本分类是指给定一段文字,判定其所属的类别。例如下面的文本情感分类

.. code-block:: text

    1, 商务大床房,房间很大,床有2M宽,整体感觉经济实惠不错!

这里我们使用fastNLP提供的可自动下载的微博情感分类数据集(WeiboSenti100k)进行测试

.. code-block:: python

    import warnings
    import torch
    warnings.filterwarnings("ignore")

    # 载入数据集
    from fastNLP.io import WeiboSenti100kPipe
    data_bundle = WeiboSenti100kPipe().process_from_file()
    data_bundle.rename_field('chars', 'words')

    # 载入BertEmbedding
    from fastNLP.embeddings import BertEmbedding
    embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn-wwm', include_cls_sep=True)

    # 载入模型
    from fastNLP.models import BertForSequenceClassification
    model = BertForSequenceClassification(embed, len(data_bundle.get_vocab('target')))

    # 训练模型
    from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam
    trainer = Trainer(data_bundle.get_dataset('train'), model,
                      optimizer=Adam(model_params=model.parameters(), lr=2e-5),
                      loss=CrossEntropyLoss(), device=0,
                      batch_size=8, dev_data=data_bundle.get_dataset('dev'),
                      metrics=AccuracyMetric(), n_epochs=2, print_every=1)
    trainer.train()

    # 测试结果
    from fastNLP import Tester
    tester = Tester(data_bundle.get_dataset('test'), model, batch_size=128, metrics=AccuracyMetric())
    tester.test()

输出结果::

    In Epoch:1/Step:12499, got best dev performance:
    AccuracyMetric: acc=0.9838
    Reloaded the best model.

    Evaluate data in 63.84 seconds!
    [tester]
    AccuracyMetric: acc=0.9815
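
训练完成后,可以用PyTorch自带的序列化接口把模型保存下来,方便之后直接加载预测(示意代码,文件名为假设):

.. code-block:: python

    import torch

    # 直接序列化整个模型对象;也可以只保存model.state_dict()
    torch.save(model, 'weibo_senti_bert.pt')
    # 之后用torch.load重新载入即可继续评测或预测
    model = torch.load('weibo_senti_bert.pt')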

2. 使用Bert进行命名实体识别
----------------------------------

命名实体识别是给定一句话,标记出其中的实体。一般序列标注任务都使用conll格式,conll格式是指一行中通过制表符分隔不同的内容,并用空行分隔
两句话,例如下面的例子

.. code-block:: text

    中 B-ORG
    共 I-ORG
    中 I-ORG
    央 I-ORG
    致 O
    中 B-ORG
    国 I-ORG
    致 I-ORG
    公 I-ORG
    党 I-ORG
    十 I-ORG
    一 I-ORG
    大 I-ORG
    的 O
    贺 O
    词 O

这部分内容请参考 :doc:`快速实现序列标注模型 </tutorials/tutorial_9_seq_labeling>`
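
作为补充,下面给出一个用BertEmbedding搭建中文NER模型的简单示意(这里假设使用fastNLP可自动下载的MSRA-NER数据与BiLSTMCRF模型,具体接口请以上述教程为准):

.. code-block:: python

    from fastNLP.io import MsraNERPipe  # 假设:fastNLP提供的中文NER数据Pipe
    from fastNLP.embeddings import BertEmbedding
    from fastNLP.models import BiLSTMCRF
    from fastNLP import SpanFPreRecMetric

    # 载入并处理数据;中文任务的field名为chars,统一改名为words
    data_bundle = MsraNERPipe().process_from_file()
    data_bundle.rename_field('chars', 'words')

    # 用BERT对字序列编码,上层接BiLSTM-CRF进行标签解码
    embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn-wwm')
    model = BiLSTMCRF(embed, num_classes=len(data_bundle.get_vocab('target')),
                      target_vocab=data_bundle.get_vocab('target'))

    # NER一般使用span级别的F1(SpanFPreRecMetric)进行评测
    metric = SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))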

3. 使用Bert进行文本匹配
----------------------------------

文本匹配任务是指给定两句话判断它们的关系。比如,给定两句话判断前一句是否和后一句具有因果关系或是否是矛盾关系;或者给定两句话判断两句话是否
具有相同的意思。这里我们使用中文XNLI数据(对应下面的CNXNLIBertPipe)作为例子。

.. code-block:: python

    # 载入数据集
    from fastNLP.io import CNXNLIBertPipe
    data_bundle = CNXNLIBertPipe().process_from_file(paths)  # paths为CNXNLI数据所在的文件夹路径
    data_bundle.rename_field('chars', 'words')
    print(data_bundle)

    # 载入BertEmbedding
    from fastNLP.embeddings import BertEmbedding
    embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn-wwm', include_cls_sep=True)

    # 载入模型
    from fastNLP.models import BertForSentenceMatching
    model = BertForSentenceMatching(embed, len(data_bundle.get_vocab('target')))

    # 训练模型
    from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric
    from fastNLP.core.optimizer import AdamW
    from fastNLP.core.callback import WarmupCallback

    callbacks = [WarmupCallback(warmup=0.1, schedule='linear'), ]

    trainer = Trainer(data_bundle.get_dataset('train'), model,
                      optimizer=AdamW(params=model.parameters(), lr=4e-5),
                      loss=CrossEntropyLoss(), device=0,
                      batch_size=8, dev_data=data_bundle.get_dataset('dev'),
                      metrics=AccuracyMetric(), n_epochs=5, print_every=1,
                      update_every=8, callbacks=callbacks)
    trainer.train()

    # 测试模型
    from fastNLP import Tester
    tester = Tester(data_bundle.get_dataset('test'), model, batch_size=8, metrics=AccuracyMetric())
    tester.test()
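
这里batch_size=8配合update_every=8,相当于以64的有效batch size做梯度累积更新;WarmupCallback(warmup=0.1)会在前10%的训练步数内
将学习率从0线性增大到设定值,之后按linear schedule衰减,这是微调BERT时常用的设置。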

运行结果::

    In Epoch:3/Step:73632, got best dev performance:
    AccuracyMetric: acc=0.781928
    Reloaded the best model.

    Evaluate data in 18.54 seconds!
    [tester]
    AccuracyMetric: acc=0.783633

4. 使用Bert进行中文问答
----------------------------------

问答任务是给定一段内容以及一个问题,需要从这段内容中找到答案。例如下面的例子

.. code-block:: text

"context": "锣鼓经是大陆传统器乐及戏曲里面常用的打击乐记谱方法,以中文字的声音模拟敲击乐的声音,纪录打击乐的各种不同的演奏方法。常 | |||||
用的节奏型称为「锣鼓点」。而锣鼓是戏曲节奏的支柱,除了加强演员身段动作的节奏感,也作为音乐的引子和尾声,提示音乐的板式和速度,以及 | |||||
作为唱腔和念白的伴奏,令诗句的韵律更加抑扬顿锉,段落分明。锣鼓的运用有约定俗成的程式,依照角色行当的身份、性格、情绪以及环境,配合 | |||||
相应的锣鼓点。锣鼓亦可以模仿大自然的音响效果,如雷电、波浪等等。戏曲锣鼓所运用的敲击乐器主要分为鼓、锣、钹和板四类型:鼓类包括有单 | |||||
皮鼓(板鼓)、大鼓、大堂鼓(唐鼓)、小堂鼓、怀鼓、花盆鼓等;锣类有大锣、小锣(手锣)、钲锣、筛锣、马锣、镗锣、云锣;钹类有铙钹、大 | |||||
钹、小钹、水钹、齐钹、镲钹、铰子、碰钟等;打拍子用的檀板、木鱼、梆子等。因为京剧的锣鼓通常由四位乐师负责,又称为四大件,领奏的师 | |||||
傅称为:「鼓佬」,其职责有如西方乐队的指挥,负责控制速度以及利用各种手势提示乐师演奏不同的锣鼓点。粤剧吸收了部份京剧的锣鼓,但以木鱼 | |||||
和沙的代替了京剧的板和鼓,作为打拍子的主要乐器。以下是京剧、昆剧和粤剧锣鼓中乐器对应的口诀用字:", | |||||
"question": "锣鼓经是什么?", | |||||
"answers": [ | |||||
{ | |||||
"text": "大陆传统器乐及戏曲里面常用的打击乐记谱方法", | |||||
"answer_start": 4 | |||||
}, | |||||
{ | |||||
"text": "大陆传统器乐及戏曲里面常用的打击乐记谱方法", | |||||
"answer_start": 4 | |||||
}, | |||||
{ | |||||
"text": "大陆传统器乐及戏曲里面常用的打击乐记谱方法", | |||||
"answer_start": 4 | |||||
} | |||||
] | |||||

您可以通过以下的代码训练 `CMRC2018 <https://github.com/ymcui/cmrc2018>`_ :

.. code-block:: python

    from fastNLP.embeddings import BertEmbedding
    from fastNLP.models import BertForQuestionAnswering
    from fastNLP.core.losses import CMRC2018Loss
    from fastNLP.core.metrics import CMRC2018Metric
    from fastNLP.io.pipe.qa import CMRC2018BertPipe
    from fastNLP import Trainer, BucketSampler
    from fastNLP import WarmupCallback, GradientClipCallback
    from fastNLP.core.optimizer import AdamW

    data_bundle = CMRC2018BertPipe().process_from_file()
    data_bundle.rename_field('chars', 'words')

    print(data_bundle)

    embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn', requires_grad=True,
                          include_cls_sep=False, auto_truncate=True,
                          dropout=0.5, word_dropout=0.01)
    model = BertForQuestionAnswering(embed)
    loss = CMRC2018Loss()
    metric = CMRC2018Metric()

    wm_callback = WarmupCallback(schedule='linear')
    gc_callback = GradientClipCallback(clip_value=1, clip_type='norm')
    callbacks = [wm_callback, gc_callback]

    optimizer = AdamW(model.parameters(), lr=5e-5)

    trainer = Trainer(data_bundle.get_dataset('train'), model, loss=loss, optimizer=optimizer,
                      sampler=BucketSampler(seq_len_field_name='context_len'),
                      dev_data=data_bundle.get_dataset('dev'), metrics=metric,
                      callbacks=callbacks, device=0, batch_size=6, num_workers=2, n_epochs=2, print_every=1,
                      test_use_tqdm=False, update_every=10)

    trainer.train(load_best_model=False)

训练结果(和论文中报道的基本一致)::

    In Epoch:2/Step:1692, got best dev performance:
    CMRC2018Metric: f1=85.61, em=66.08
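
训练完成后,若想直接查看模型抽取出的答案文本,可以仿照下文CMRC2018Metric中的解码方式:先在context范围内取开始位置的最大打分,
再在该位置之后取结束位置的最大打分。下面是一个示意性的解码函数(假设pred_start、pred_end为模型前向输出的打分张量):

.. code-block:: python

    import torch
    from fastNLP.core.utils import seq_len_to_mask

    def extract_answer(pred_start, pred_end, raw_chars, context_len):
        # pred_start/pred_end: batch_size x max_len;raw_chars: 每条样本的字符list
        batch_size, max_len = pred_start.size()
        mask = seq_len_to_mask(context_len, max_len=max_len).eq(0)
        pred_start = pred_start.masked_fill(mask, float('-inf'))
        pred_end = pred_end.masked_fill(mask, float('-inf'))
        max_start, start_idx = pred_start.max(dim=-1, keepdim=True)
        # 结束位置只允许出现在开始位置之后
        end_mask = pred_start.eq(max_start).cumsum(dim=-1).eq(0)
        pred_end = pred_end.masked_fill(end_mask, float('-inf'))
        end_idx = pred_end.argmax(dim=-1) + 1
        return [''.join(raw_chars[i][s:e]) for i, (s, e) in
                enumerate(zip(start_idx.flatten().tolist(), end_idx.tolist()))]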

@@ -57,11 +57,12 @@ __all__ = [
    "BCELoss",
    "NLLLoss",
    "LossInForward",
    "CMRC2018Loss",

    "AccuracyMetric",
    "SpanFPreRecMetric",
    "CMRC2018Metric",

    "Optimizer",
    "SGD",
    "Adam",

@@ -82,8 +83,8 @@ from .const import Const
from .dataset import DataSet
from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder
from .instance import Instance
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, CMRC2018Loss
from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric
from .optimizer import Optimizer, SGD, Adam, AdamW
from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler
from .tester import Tester

@@ -468,7 +468,8 @@ class GradientClipCallback(Callback):
            if getattr(self.trainer, 'fp16', ''):
                _check_fp16()
                self.clip_fun(amp.master_params(self.optimizer), self.clip_value)
            else:
                self.clip_fun(self.model.parameters(), self.clip_value)
        else:
            self.clip_fun(self.parameters, self.clip_value)

@@ -354,6 +354,9 @@ class DataSet(object):
        assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx)
        return self.dataset.field_arrays[item][self.idx]

    def __setitem__(self, key, value):
        raise TypeError("You cannot modify value directly.")

    def items(self):
        ins = self.dataset[self.idx]
        return ins.items()

@@ -45,6 +45,9 @@ class Instance(object):
        """
        return self.fields.items()

    def __contains__(self, item):
        return item in self.fields

    def __getitem__(self, name):
        if name in self.fields:
            return self.fields[name]

@@ -11,7 +11,10 @@ __all__ = [
    "CrossEntropyLoss",
    "BCELoss",
    "L1Loss",
    "NLLLoss",

    "CMRC2018Loss"
]

import inspect

@@ -344,6 +347,47 @@ class LossInForward(LossBase):
        return loss


class CMRC2018Loss(LossBase):
    """
    用于计算CMRC2018中文问答任务的loss。

    """
    def __init__(self, target_start=None, target_end=None, context_len=None, pred_start=None, pred_end=None,
                 reduction='mean'):
        super().__init__()

        assert reduction in ('mean', 'sum')

        self._init_param_map(target_start=target_start, target_end=target_end, context_len=context_len,
                             pred_start=pred_start, pred_end=pred_end)
        self.reduction = reduction

    def get_loss(self, target_start, target_end, context_len, pred_start, pred_end):
        """

        :param target_start: batch_size
        :param target_end: batch_size
        :param context_len: batch_size
        :param pred_start: batch_size x max_len
        :param pred_end: batch_size x max_len
        :return:
        """
        batch_size, max_len = pred_end.size()

        # 将context之外的位置mask成-inf,保证只在context范围内预测答案
        mask = seq_len_to_mask(context_len, max_len).eq(0)

        pred_start = pred_start.masked_fill(mask, float('-inf'))
        pred_end = pred_end.masked_fill(mask, float('-inf'))

        # 开始位置与结束位置分别计算交叉熵,两者求和后取平均
        start_loss = F.cross_entropy(pred_start, target_start, reduction='sum')
        end_loss = F.cross_entropy(pred_end, target_end, reduction='sum')

        loss = start_loss + end_loss

        if self.reduction == 'mean':
            loss = loss / batch_size

        return loss / 2
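
# 一个最小的调用示例(示意),直接调用get_loss,张量形状见其docstring:
#   loss_fn = CMRC2018Loss()
#   pred_start = torch.randn(2, 5)   # batch_size=2, max_len=5
#   pred_end = torch.randn(2, 5)
#   loss = loss_fn.get_loss(target_start=torch.tensor([1, 0]),
#                           target_end=torch.tensor([2, 3]),
#                           context_len=torch.tensor([4, 5]),
#                           pred_start=pred_start, pred_end=pred_end)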

def _prepare_losser(losser):
    if losser is None:
        losser = LossInForward()

@@ -6,7 +6,7 @@ __all__ = [
    "MetricBase",
    "AccuracyMetric",
    "SpanFPreRecMetric",
    "CMRC2018Metric"
]

import inspect

@@ -14,6 +14,7 @@ import warnings
from abc import abstractmethod
from collections import defaultdict
from typing import Union
import re

import numpy as np
import torch

@@ -116,7 +117,7 @@ class MetricBase(object):
    self.get_metric将统计当前的评价指标并返回评价结果, 返回值需要是一个dict, key是指标名称,value是指标的值

    """
    def __init__(self):
        self._param_map = {}  # key is param in function, value is input param.
        self._checked = False

@@ -139,7 +140,7 @@ class MetricBase(object):
    def get_metric(self, reset=True):
        raise NotImplemented

    def set_metric_name(self, name: str):
        """
        设置metric的名称,默认是Metric的class name.

@@ -156,7 +157,7 @@ class MetricBase(object):
        :return:
        """
        return self._metric_name

    def _init_param_map(self, key_map=None, **kwargs):
        """检查key_map和其他参数map,并将这些映射关系添加到self._param_map

@@ -189,7 +190,7 @@ class MetricBase(object):
        for value, key_set in value_counter.items():
            if len(key_set) > 1:
                raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.")

        # check consistence between signature and _param_map
        func_spect = inspect.getfullargspec(self.evaluate)
        func_args = [arg for arg in func_spect.args if arg != 'self']

@@ -198,7 +199,7 @@ class MetricBase(object):
            raise NameError(
                f"Parameter `{func_param}` is not in {_get_func_signature(self.evaluate)}. Please check the "
                f"initialization parameters, or change its signature.")

    def _fast_param_map(self, pred_dict, target_dict):
        """Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map.
        such as pred_dict has one element, target_dict has one element

@@ -213,7 +214,7 @@ class MetricBase(object):
            fast_param['target'] = list(target_dict.values())[0]
            return fast_param
        return fast_param

    def __call__(self, pred_dict, target_dict):
        """
        这个方法会调用self.evaluate 方法.

@@ -228,12 +229,12 @@ class MetricBase(object):
        :param target_dict: DataSet.batch_y里的键-值对所组成的dict(即is_target=True的fields的内容)
        :return:
        """
        fast_param = self._fast_param_map(pred_dict, target_dict)
        if fast_param:
            self.evaluate(**fast_param)
            return

        if not self._checked:
            if not callable(self.evaluate):
                raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")

@@ -243,14 +244,14 @@ class MetricBase(object):
            for func_arg, input_arg in self._param_map.items():
                if func_arg not in func_args:
                    raise NameError(f"`{func_arg}` not in {_get_func_signature(self.evaluate)}.")

            # 2. only part of the _param_map are passed, left are not
            for arg in func_args:
                if arg not in self._param_map:
                    self._param_map[arg] = arg  # This param does not need mapping.
            self._evaluate_args = func_args
            self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self._param_map.items()}

        # need to wrap inputs in dict.
        mapped_pred_dict = {}
        mapped_target_dict = {}

@@ -259,7 +260,7 @@ class MetricBase(object):
                mapped_pred_dict[mapped_arg] = pred_dict[input_arg]
            if input_arg in target_dict:
                mapped_target_dict[mapped_arg] = target_dict[input_arg]

        # missing
        if not self._checked:
            duplicated = []

@@ -274,23 +275,23 @@ class MetricBase(object):
            for idx, func_arg in enumerate(missing):
                # Don't delete `` in this information, nor add ``
                replaced_missing[idx] = f"{self._param_map[func_arg]}" + f"(assign to `{func_arg}` " \
                                        f"in `{self.__class__.__name__}`)"

            check_res = _CheckRes(missing=replaced_missing,
                                  unused=check_res.unused,
                                  duplicated=duplicated,
                                  required=check_res.required,
                                  all_needed=check_res.all_needed,
                                  varargs=check_res.varargs)

            if check_res.missing or check_res.duplicated:
                raise _CheckError(check_res=check_res,
                                  func_signature=_get_func_signature(self.evaluate))
            self._checked = True
        refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict)

        self.evaluate(**refined_args)
        return

@@ -298,7 +299,7 @@ class AccuracyMetric(MetricBase):
    """
    准确率Metric(其它的Metric参见 :mod:`fastNLP.core.metrics` )
    """

    def __init__(self, pred=None, target=None, seq_len=None):
        """

@@ -306,14 +307,14 @@ class AccuracyMetric(MetricBase):
        :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target`
        :param seq_len: 参数映射表中 `seq_len` 的映射关系,None表示映射关系为 `seq_len` -> `seq_len`
        """

        super().__init__()

        self._init_param_map(pred=pred, target=target, seq_len=seq_len)

        self.total = 0
        self.acc_count = 0

    def evaluate(self, pred, target, seq_len=None):
        """
        evaluate函数将针对一个批次的预测结果做评价指标的累计

@@ -333,28 +334,28 @@ class AccuracyMetric(MetricBase):
        if not isinstance(target, torch.Tensor):
            raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                            f"got {type(target)}.")

        if seq_len is not None and not isinstance(seq_len, torch.Tensor):
            raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                            f"got {type(seq_len)}.")

        if seq_len is not None and target.dim() > 1:
            max_len = target.size(1)
            masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len)
        else:
            masks = None

        if pred.dim() == target.dim():
            pass
        elif pred.dim() == target.dim() + 1:
            pred = pred.argmax(dim=-1)
            if seq_len is None and target.dim() > 1:
                warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")
        else:
            raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
                               f"size:{pred.size()}, target should have size: {pred.size()} or "
                               f"{pred.size()[:-1]}, got {target.size()}.")

        target = target.to(pred)
        if masks is not None:
            self.acc_count += torch.sum(torch.eq(pred, target).masked_fill(masks.eq(0), 0)).item()

@@ -362,7 +363,7 @@ class AccuracyMetric(MetricBase):
        else:
            self.acc_count += torch.sum(torch.eq(pred, target)).item()
            self.total += np.prod(list(pred.size()))

    def get_metric(self, reset=True):
        """
        get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.

@@ -388,7 +389,7 @@ def _bmes_tag_to_spans(tags, ignore_labels=None):
    :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])]
    """
    ignore_labels = set(ignore_labels) if ignore_labels else set()

    spans = []
    prev_bmes_tag = None
    for idx, tag in enumerate(tags):

@@ -417,7 +418,7 @@ def _bmeso_tag_to_spans(tags, ignore_labels=None):
    :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])]
    """
    ignore_labels = set(ignore_labels) if ignore_labels else set()

    spans = []
    prev_bmes_tag = None
    for idx, tag in enumerate(tags):

@@ -479,7 +480,7 @@ def _bio_tag_to_spans(tags, ignore_labels=None):
    :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])]
    """
    ignore_labels = set(ignore_labels) if ignore_labels else set()

    spans = []
    prev_bio_tag = None
    for idx, tag in enumerate(tags):

@@ -497,7 +498,7 @@ def _bio_tag_to_spans(tags, ignore_labels=None):
    return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels]


def _get_encoding_type_from_tag_vocab(tag_vocab: Union[Vocabulary, dict]) -> str:
    """
    给定Vocabulary自动判断是哪种类型的encoding, 支持判断bmes, bioes, bmeso, bio

@@ -533,7 +534,7 @@ def _get_encoding_type_from_tag_vocab(tag_vocab:Union[Vocabulary, dict])->str:
                             "'bio', 'bmes', 'bmeso', 'bioes' type.")


def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encoding_type: str):
    """
    检查vocab中的tag是否与encoding_type是匹配的

@@ -557,7 +558,7 @@ def _check_tag_vocab_and_encoding_type(tag_vocab:Union[Vocabulary, dict], encodi
    tags = encoding_type
    for tag in tag_set:
        assert tag in tags, f"{tag} is not a valid tag in encoding type:{encoding_type}. Please check your " \
                            f"encoding_type."
        tags = tags.replace(tag, '')  # 删除该值
    if tags:  # 如果不为空,说明出现了未使用的tag
        warnings.warn(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your "

@@ -589,7 +590,7 @@ class SpanFPreRecMetric(MetricBase):
        ...
    }

    """
    def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type=None, ignore_labels=None,
                 only_gross=True, f_type='micro', beta=1):
        r"""

@@ -616,7 +617,7 @@ class SpanFPreRecMetric(MetricBase):
            _check_tag_vocab_and_encoding_type(tag_vocab, encoding_type)
            self.encoding_type = encoding_type
        else:
            self.encoding_type = _get_encoding_type_from_tag_vocab(tag_vocab)

        if self.encoding_type == 'bmes':
            self.tag_to_span_func = _bmes_tag_to_spans

@@ -628,22 +629,22 @@ class SpanFPreRecMetric(MetricBase):
            self.tag_to_span_func = _bioes_tag_to_spans
        else:
            raise ValueError("Only support 'bio', 'bmes', 'bmeso', 'bioes' type.")

        self.ignore_labels = ignore_labels
        self.f_type = f_type
        self.beta = beta
        self.beta_square = self.beta ** 2
        self.only_gross = only_gross

        super().__init__()
        self._init_param_map(pred=pred, target=target, seq_len=seq_len)

        self.tag_vocab = tag_vocab

        self._true_positives = defaultdict(int)
        self._false_positives = defaultdict(int)
        self._false_negatives = defaultdict(int)

    def evaluate(self, pred, target, seq_len):
        """evaluate函数将针对一个批次的预测结果做评价指标的累计

@@ -658,11 +659,11 @@ class SpanFPreRecMetric(MetricBase):
        if not isinstance(target, torch.Tensor):
            raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                            f"got {type(target)}.")

        if not isinstance(seq_len, torch.Tensor):
            raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                            f"got {type(seq_len)}.")

        if pred.size() == target.size() and len(target.size()) == 2:
            pass
        elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2:

@@ -675,20 +676,20 @@ class SpanFPreRecMetric(MetricBase):
            raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
                               f"size:{pred.size()}, target should have size: {pred.size()} or "
                               f"{pred.size()[:-1]}, got {target.size()}.")

        batch_size = pred.size(0)
        pred = pred.tolist()
        target = target.tolist()
        for i in range(batch_size):
            pred_tags = pred[i][:int(seq_len[i])]
            gold_tags = target[i][:int(seq_len[i])]

            pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags]
            gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags]

            pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels)
            gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels)

            for span in pred_spans:
                if span in gold_spans:
                    self._true_positives[span[0]] += 1

@@ -697,7 +698,7 @@ class SpanFPreRecMetric(MetricBase):
                    self._false_positives[span[0]] += 1
            for span in gold_spans:
                self._false_negatives[span[0]] += 1

    def get_metric(self, reset=True):
        """get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果."""
        evaluate_result = {}

@@ -723,12 +724,12 @@ class SpanFPreRecMetric(MetricBase):
                    evaluate_result[f_key] = f
                    evaluate_result[pre_key] = pre
                    evaluate_result[rec_key] = rec

            if self.f_type == 'macro':
                evaluate_result['f'] = f_sum / len(tags)
                evaluate_result['pre'] = pre_sum / len(tags)
                evaluate_result['rec'] = rec_sum / len(tags)

        if self.f_type == 'micro':
            f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()),
                                                  sum(self._false_negatives.values()),

@@ -736,17 +737,17 @@ class SpanFPreRecMetric(MetricBase):
            evaluate_result['f'] = f
            evaluate_result['pre'] = pre
            evaluate_result['rec'] = rec

        if reset:
            self._true_positives = defaultdict(int)
            self._false_positives = defaultdict(int)
            self._false_negatives = defaultdict(int)

        for key, value in evaluate_result.items():
            evaluate_result[key] = round(value, 6)

        return evaluate_result

    def _compute_f_pre_rec(self, tp, fn, fp):
        """

@@ -758,7 +759,7 @@ class SpanFPreRecMetric(MetricBase):
        pre = tp / (fp + tp + 1e-13)
        rec = tp / (fn + tp + 1e-13)
        f = (1 + self.beta_square) * pre * rec / (self.beta_square * pre + rec + 1e-13)

        return f, pre, rec

@@ -827,168 +828,129 @@ def _pred_topk(y_prob, k=1):
    return y_pred_topk, y_prob_topk


class CMRC2018Metric(MetricBase):
    """
    CMRC2018中文问答任务的metric,会累计计算预测答案的f1与em(exact match)值。
    """
    def __init__(self, answers=None, raw_chars=None, context_len=None, pred_start=None, pred_end=None):
        super().__init__()
        self._init_param_map(answers=answers, raw_chars=raw_chars, context_len=context_len, pred_start=pred_start,
                             pred_end=pred_end)
        self.em = 0
        self.total = 0
        self.f1 = 0

    def evaluate(self, answers, raw_chars, context_len, pred_start, pred_end):
        """

        :param list[str] answers: 如[["答案1", "答案2", "答案3"], [...], ...]
        :param list[str] raw_chars: [["这", "是", ...], [...]]
        :param tensor context_len: context长度, batch_size
        :param tensor pred_start: batch_size x length
        :param tensor pred_end: batch_size x length
        :return:
        """
        batch_size, max_len = pred_start.size()
        context_mask = seq_len_to_mask(context_len, max_len=max_len).eq(0)
        # 将context之外的位置打分置为-inf,保证答案只从context中抽取
        pred_start.masked_fill_(context_mask, float('-inf'))
        pred_end.masked_fill_(context_mask, float('-inf'))
        max_pred_start, pred_start_index = pred_start.max(dim=-1, keepdim=True)  # batch_size,
        pred_start_mask = pred_start.eq(max_pred_start).cumsum(dim=-1).eq(0)  # 只能预测这之后的值
        pred_end.masked_fill_(pred_start_mask, float('-inf'))
        pred_end_index = pred_end.argmax(dim=-1) + 1
        pred_ans = []
        for index, (start, end) in enumerate(zip(pred_start_index.flatten().tolist(), pred_end_index.tolist())):
            pred_ans.append(''.join(raw_chars[index][start:end]))
        for answer, pred_an in zip(answers, pred_ans):
            pred_an = pred_an.strip()
            self.f1 += _calc_cmrc2018_f1_score(answer, pred_an)
            self.total += 1
            self.em += _calc_cmrc2018_em_score(answer, pred_an)

    def get_metric(self, reset=True):
        """get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果."""
        eval_res = {'f1': round(self.f1 / self.total * 100, 2), 'em': round(self.em / self.total * 100, 2)}
        if reset:
            self.em = 0
            self.total = 0
            self.f1 = 0
        return eval_res


# split Chinese
def _cn_segmentation(in_str, rm_punc=False):
    in_str = str(in_str).lower().strip()
    segs_out = []
    temp_str = ""
    sp_char = {'-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', ',', '。', ':', '?', '!', '“', '”', ';', '’', '《',
               '》', '……', '·', '、', '「', '」', '(', ')', '-', '~', '『', '』'}
    for char in in_str:
        if rm_punc and char in sp_char:
            continue
        if re.search(r'[\u4e00-\u9fa5]', char) or char in sp_char:
            # 遇到中文或标点时,先将积累的非中文片段按字符切开放入结果
            if temp_str != "":
                ss = list(temp_str)
                segs_out.extend(ss)
                temp_str = ""
            segs_out.append(char)
        else:
            temp_str += char

    # handling last part
    if temp_str != "":
        ss = list(temp_str)
        segs_out.extend(ss)

    return segs_out
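
# 例(示意):_cn_segmentation("Bert模型!", rm_punc=True)
#   -> ['b', 'e', 'r', 't', '模', '型'](输入会先lower,全角标点被去除,英文按字符切分)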


# remove punctuation
def _remove_punctuation(in_str):
    in_str = str(in_str).lower().strip()
    sp_char = ['-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=',
               ',', '。', ':', '?', '!', '“', '”', ';', '’', '《', '》', '……', '·', '、',
               '「', '」', '(', ')', '-', '~', '『', '』']
    out_segs = []
    for char in in_str:
        if char in sp_char:
            continue
        else:
            out_segs.append(char)
    return ''.join(out_segs)


# find longest common substring,动态规划求最长公共子串
def _find_lcs(s1, s2):
    m = [[0 for i in range(len(s2) + 1)] for j in range(len(s1) + 1)]
    mmax = 0
    p = 0
    for i in range(len(s1)):
        for j in range(len(s2)):
            if s1[i] == s2[j]:
                m[i + 1][j + 1] = m[i][j] + 1
                if m[i + 1][j + 1] > mmax:
                    mmax = m[i + 1][j + 1]
                    p = i + 1
    return s1[p - mmax:p], mmax
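
# 例(示意):_find_lcs("abcde", "xbcdy") -> ("bcd", 3)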


def _calc_cmrc2018_f1_score(answers, prediction):
    f1_scores = []
    for ans in answers:
        ans_segs = _cn_segmentation(ans, rm_punc=True)
        prediction_segs = _cn_segmentation(prediction, rm_punc=True)
        lcs, lcs_len = _find_lcs(ans_segs, prediction_segs)
        if lcs_len == 0:
            f1_scores.append(0)
            continue
        precision = 1.0 * lcs_len / len(prediction_segs)
        recall = 1.0 * lcs_len / len(ans_segs)
        f1 = (2 * precision * recall) / (precision + recall)
        f1_scores.append(f1)
    # 有多个参考答案时,取与预测匹配度最高的f1
    return max(f1_scores)


def _calc_cmrc2018_em_score(answers, prediction):
    em = 0
    for ans in answers:
        ans_ = _remove_punctuation(ans)
        prediction_ = _remove_punctuation(prediction)
        if ans_ == prediction_:
            em = 1
            break
    return em

@@ -51,6 +51,8 @@ __all__ = [
    "BQCorpusLoader",
    "LCQMCLoader",

    "CMRC2018Loader",

    "Pipe",

    "YelpFullPipe",

@@ -113,6 +115,8 @@ __all__ = [
    "GranularizePipe",
    "MachingTruncatePipe",

    "CMRC2018BertPipe",

    'ModelLoader',
    'ModelSaver',

@@ -118,6 +118,9 @@ DATASET_DIR = {
    # Summarization, English
    "ext-cnndm": "ext-cnndm.zip",

    # Question & answer
    "cmrc2018": "cmrc2018.zip"
}

PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR,

@@ -80,7 +80,9 @@ __all__ = [
    "BQCorpusLoader",
    "LCQMCLoader",

    "CoReferenceLoader",

    "CMRC2018Loader"
]

from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, \
    ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader

@@ -93,3 +95,5 @@ from .json import JsonLoader
from .loader import Loader
from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader, CNXNLILoader, BQCorpusLoader, \
    LCQMCLoader
from .qa import CMRC2018Loader

@@ -0,0 +1,74 @@
"""
该文件中的Loader主要用于读取问答式任务的数据。

"""

from . import Loader
import json
from ...core import DataSet, Instance

__all__ = ['CMRC2018Loader']


class CMRC2018Loader(Loader):
    """
    请直接使用从fastNLP下载的数据进行处理。该数据集未提供测试集,测试需要通过上传到对应的系统进行评测。

    读取之后训练集DataSet将具备以下的内容,每个问题的答案只有一个

    .. csv-table::
        :header: "title", "context", "question", "answers", "answer_starts", "id"

        "范廷颂", "范廷颂枢机(,),圣名保禄·若瑟()...", "范廷颂是什么时候被任为主教的?", ["1963年"], ["30"], "TRAIN_186_QUERY_0"
        "范廷颂", "范廷颂枢机(,),圣名保禄·若瑟()...", "1990年,范廷颂担任什么职务?", ["1990年被擢升为天..."], ["41"], "TRAIN_186_QUERY_1"
        "...", "...", "...", "...", ".", "..."

    其中title是文本的标题,多条记录可能是相同的title;id是该问题的id,具备唯一性。

    验证集DataSet将具备以下的内容,每个问题的答案可能有三个(有时候只是3个重复的答案)

    .. csv-table::
        :header: "title", "context", "question", "answers", "answer_starts", "id"

        "战国无双3", "《战国无双3》()是由光荣和ω-force开发...", "《战国无双3》是由哪两个公司合作开发的?", ["光荣和ω-force", "光荣和ω-force", "光荣和ω-force"], ["30", "30", "30"], "DEV_0_QUERY_0"
        "战国无双3", "《战国无双3》()是由光荣和ω-force开发...", "男女主角亦有专属声优这一模式是由谁改编的?", ["村雨城", "村雨城", "任天堂游戏谜之村雨城"], ["226", "226", "219"], "DEV_0_QUERY_1"
        "...", "...", "...", "...", ".", "..."

    其中answer_starts是从0开始的index。例如"我来自a复旦大学?",其中"复"的开始index为4。另外"Russell评价说"中"说"的index为9,
    因为英文和数字都是直接按照character计量的。
    """
    def __init__(self):
        super().__init__()

    def _load(self, path: str) -> DataSet:
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)['data']
            ds = DataSet()
            for entry in data:
                title = entry['title']
                para = entry['paragraphs'][0]
                context = para['context']
                qas = para['qas']
                for qa in qas:
                    question = qa['question']
                    ans = qa['answers']
                    answers = []
                    answer_starts = []
                    id = qa['id']
                    for an in ans:
                        answers.append(an['text'])
                        answer_starts.append(an['answer_start'])
                    ds.append(Instance(title=title, context=context, question=question, answers=answers,
                                       answer_starts=answer_starts, id=id))
            return ds

    def download(self) -> str:
        """
        如果您使用了本数据,请引用 A Span-Extraction Dataset for Chinese Machine Reading Comprehension. Yiming Cui, Ting Liu, etc.

        :return:
        """
        output_dir = self._get_dataset_path('cmrc2018')
        return output_dir

@@ -50,7 +50,9 @@ __all__ = [
    "GranularizePipe",
    "MachingTruncatePipe",

    "CoReferencePipe",

    "CMRC2018BertPipe"
]

from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, \

@@ -63,3 +65,4 @@ from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe
    MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe, CNXNLIBertPipe, CNXNLIPipe, BQCorpusBertPipe, \
    LCQMCPipe, BQCorpusPipe, LCQMCBertPipe, RenamePipe, GranularizePipe, MachingTruncatePipe
from .pipe import Pipe
from .qa import CMRC2018BertPipe

@@ -0,0 +1,142 @@
"""
本文件中的Pipe主要用于处理问答任务的数据。

"""

from copy import deepcopy

from .pipe import Pipe
from .. import DataBundle
from ..loader.qa import CMRC2018Loader
from .utils import get_tokenizer
from ...core import DataSet
from ...core import Vocabulary

__all__ = ['CMRC2018BertPipe']
def _concat_clip(data_bundle, tokenizer, max_len, concat_field_name='raw_chars'): | |||||
""" | |||||
处理data_bundle中的DataSet,将context与question进行tokenize,然后使用[SEP]将两者连接起来。 | |||||
会新增field: context_len(int), raw_words(list[str]), target_start(int), target_end(int)其中target_start | |||||
与target_end是与raw_chars等长的。其中target_start和target_end是前闭后闭的区间。 | |||||
:param DataBundle data_bundle: 类似["a", "b", "[SEP]", "c", ] | |||||
:return: | |||||
""" | |||||
for name in list(data_bundle.datasets.keys()): | |||||
ds = data_bundle.get_dataset(name) | |||||
data_bundle.delete_dataset(name) | |||||
new_ds = DataSet() | |||||
for ins in ds: | |||||
new_ins = deepcopy(ins) | |||||
context = ins['context'] | |||||
question = ins['question'] | |||||
cnt_lst = tokenizer(context) | |||||
q_lst = tokenizer(question) | |||||
answer_start = -1 | |||||
if len(cnt_lst) + len(q_lst) + 3 > max_len: # 预留开头的[CLS]和[SEP]和中间的[sep] | |||||
if 'answer_starts' in ins and 'answers' in ins: | |||||
answer_start = int(ins['answer_starts'][0]) | |||||
answer = ins['answers'][0] | |||||
answer_end = answer_start + len(answer) | |||||
if answer_end > max_len - 3 - len(q_lst): | |||||
span_start = answer_end + 3 + len(q_lst) - max_len | |||||
span_end = answer_end | |||||
else: | |||||
span_start = 0 | |||||
span_end = max_len - 3 - len(q_lst) | |||||
cnt_lst = cnt_lst[span_start:span_end] | |||||
answer_start = int(ins['answer_starts'][0]) | |||||
answer_start -= span_start | |||||
answer_end = answer_start + len(ins['answers'][0]) | |||||
else: | |||||
cnt_lst = cnt_lst[:max_len - len(q_lst) - 3] | |||||
else: | |||||
if 'answer_starts' in ins and 'answers' in ins: | |||||
answer_start = int(ins['answer_starts'][0]) | |||||
answer_end = answer_start + len(ins['answers'][0]) | |||||
tokens = cnt_lst + ['[SEP]'] + q_lst | |||||
new_ins['context_len'] = len(cnt_lst) | |||||
new_ins[concat_field_name] = tokens | |||||
if answer_start != -1: | |||||
new_ins['target_start'] = answer_start | |||||
new_ins['target_end'] = answer_end - 1 | |||||
new_ds.append(new_ins) | |||||
data_bundle.set_dataset(new_ds, name) | |||||
return data_bundle | |||||
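# A worked example of the clipping above, with hypothetical numbers for illustration:
# with max_len=10, a 12-token context and a 3-token question, the context budget is
# 10 - 3 - 3 = 4 tokens. If the answer occupies tokens [5, 7), answer_end(7) > 4, so
# span_start = 7 + 3 + 3 - 10 = 3 and the context is clipped to tokens[3:7]; the
# shifted answer_start is 5 - 3 = 2 and target_end becomes 3, so the answer stays
# inside the retained window.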
class CMRC2018BertPipe(Pipe): | |||||
""" | |||||
After processing, the DataSets will contain the following new fields (fields passed in are still kept):
.. csv-table:: | |||||
:header: "context_len", "raw_chars", "target_start", "target_end", "chars" | |||||
492, ['范', '廷', '颂... ], 30, 34, [21, 25, ...] | |||||
491, ['范', '廷', '颂... ], 41, 61, [21, 25, ...] | |||||
".", "...", "...","...", "..." | |||||
The raw_chars column is the concatenation of context and question (joined by [SEP]); chars is the same
sequence converted to indices. target_start is the index in raw_chars of the first answer character and
target_end the index of the last one (a closed interval); context_len gives the length of the context
part of the chars column.
The meta information of each field is as follows:
+-------------+-------------+-----------+--------------+------------+-------+---------+ | |||||
| field_names | context_len | raw_chars | target_start | target_end | chars | answers | | |||||
+-------------+-------------+-----------+--------------+------------+-------+---------+
| is_input | False | False | False | False | True | False | | |||||
| is_target | True | True | True | True | False | True | | |||||
| ignore_type | False | True | False | False | False | True | | |||||
| pad_value | 0 | 0 | 0 | 0 | 0 | 0 | | |||||
+-------------+-------------+-----------+--------------+------------+-------+---------+ | |||||
""" | |||||
def __init__(self, max_len=510): | |||||
super().__init__() | |||||
self.max_len = max_len | |||||
def process(self, data_bundle: DataBundle) -> DataBundle: | |||||
""" | |||||
The incoming DataSets should contain the following fields:
.. csv-table:: | |||||
:header:"title", "context", "question", "answers", "answer_starts", "id" | |||||
"范廷颂", "范廷颂枢机(,),圣名保禄·若瑟()...", "范廷颂是什么时候被任为主教的?", ["1963年"], ["30"], "TRAIN_186_QUERY_0" | |||||
"范廷颂", "范廷颂枢机(,),圣名保禄·若瑟()...", "1990年,范廷颂担任什么职务?", ["1990年被擢升为天..."], ["41"],"TRAIN_186_QUERY_1" | |||||
"...", "...", "...","...", ".", "..." | |||||
:param data_bundle: the DataBundle to process, with DataSets as described above
:return: DataBundle
""" | |||||
_tokenizer = get_tokenizer('cn-char', lang='cn') | |||||
data_bundle = _concat_clip(data_bundle, tokenizer=_tokenizer, max_len=self.max_len, concat_field_name='raw_chars') | |||||
src_vocab = Vocabulary() | |||||
src_vocab.from_dataset(*[ds for name, ds in data_bundle.iter_datasets() if 'train' in name], | |||||
field_name='raw_chars', | |||||
no_create_entry_dataset=[ds for name, ds in data_bundle.iter_datasets() | |||||
if 'train' not in name] | |||||
) | |||||
src_vocab.index_dataset(*data_bundle.datasets.values(), field_name='raw_chars', new_field_name='chars') | |||||
data_bundle.set_vocab(src_vocab, 'chars') | |||||
data_bundle.set_ignore_type('raw_chars', 'answers', flag=True) | |||||
data_bundle.set_input('chars') | |||||
data_bundle.set_target('raw_chars', 'answers', 'target_start', 'target_end', 'context_len') | |||||
return data_bundle | |||||
def process_from_file(self, paths=None) -> DataBundle: | |||||
data_bundle = CMRC2018Loader().load(paths) | |||||
return self.process(data_bundle) |
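For orientation, here is a minimal sketch of how the new pipe is driven end to end. The directory path is a placeholder for any folder holding CMRC2018-format train/dev json files (when paths is None the loader falls back to its auto-download, cf. the loader's download method at the top of this patch):

.. code-block:: python

    from fastNLP.io.pipe.qa import CMRC2018BertPipe

    # 'path/to/cmrc' is a placeholder directory with CMRC2018-format json files
    data_bundle = CMRC2018BertPipe(max_len=510).process_from_file('path/to/cmrc')
    train_data = data_bundle.get_dataset('train')
    vocab = data_bundle.get_vocab('chars')  # character vocabulary built in process()
    ins = train_data[0]
    print(ins['context_len'], ins['target_start'], ins['target_end'])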
@@ -231,10 +231,10 @@ class BertForTokenClassification(BaseModel): | |||||
class BertForQuestionAnswering(BaseModel): | class BertForQuestionAnswering(BaseModel): | ||||
""" | """ | ||||
BERT model for classification. | |||||
Bert model for question answering (Q&A). For SQuAD 2.0, set the BertEmbedding's include_cls_sep to True; for SQuAD 1.0 or CMRC, set it to False.
""" | """ | ||||
def __init__(self, embed: BertEmbedding, num_labels=2): | |||||
def __init__(self, embed: BertEmbedding): | |||||
""" | """ | ||||
:param fastNLP.embeddings.BertEmbedding embed: the encoder of the downstream model. | :param fastNLP.embeddings.BertEmbedding embed: the encoder of the downstream model. | ||||
@@ -243,15 +243,7 @@ class BertForQuestionAnswering(BaseModel): | |||||
super(BertForQuestionAnswering, self).__init__() | super(BertForQuestionAnswering, self).__init__() | ||||
self.bert = embed | self.bert = embed | ||||
self.num_labels = num_labels | |||||
self.qa_outputs = nn.Linear(self.bert.embedding_dim, self.num_labels) | |||||
if not self.bert.model.include_cls_sep: | |||||
self.bert.model.include_cls_sep = True | |||||
warn_msg = "Bert for question answering excepts BertEmbedding `include_cls_sep` True, " \ | |||||
"but got False. FastNLP has changed it to True." | |||||
logger.warning(warn_msg) | |||||
warnings.warn(warn_msg) | |||||
self.qa_outputs = nn.Linear(self.bert.embedding_dim, 2) | |||||
def forward(self, words): | def forward(self, words): | ||||
""" | """ | ||||
@@ -261,12 +253,7 @@ class BertForQuestionAnswering(BaseModel): | |||||
sequence_output = self.bert(words) | sequence_output = self.bert(words) | ||||
logits = self.qa_outputs(sequence_output) # [batch_size, seq_len, num_labels] | logits = self.qa_outputs(sequence_output) # [batch_size, seq_len, num_labels] | ||||
return {Const.OUTPUTS(i): logits[:, :, i] for i in range(self.num_labels)} | |||||
return {'pred_start': logits[:, :, 0], 'pred_end': logits[:, :, 1]} | |||||
def predict(self, words): | def predict(self, words): | ||||
""" | |||||
:param torch.LongTensor words: [batch_size, seq_len] | |||||
:return: 一个包含num_labels个logit的dict,每一个logit的形状都是[batch_size] | |||||
""" | |||||
logits = self.forward(words) | |||||
return {Const.OUTPUTS(i): torch.argmax(logits[Const.OUTPUTS(i)], dim=-1) for i in range(self.num_labels)} | |||||
return self.forward(words) |
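Since forward() and predict() now both return the raw start/end logits under the keys pred_start and pred_end, turning them into an answer span is left to the caller (CMRC2018Metric consumes the logits directly, as the metric test later in this patch shows). A greedy decoder might look like the following sketch; it is an illustration, not part of this patch:

.. code-block:: python

    import torch

    def decode_span(pred_start: torch.Tensor, pred_end: torch.Tensor):
        # pred_start/pred_end: [batch_size, seq_len] logits from the model
        start = pred_start.argmax(dim=-1)  # [batch_size]
        end = pred_end.argmax(dim=-1)      # [batch_size]
        end = torch.max(start, end)        # crude fallback to keep start <= end
        return start, end                  # inclusive token indices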
@@ -135,6 +135,14 @@ class TestDataSetMethods(unittest.TestCase): | |||||
ds.apply(lambda ins: (len(ins["x"]), "hahaha"), new_field_name="k", ignore_type=True) | ds.apply(lambda ins: (len(ins["x"]), "hahaha"), new_field_name="k", ignore_type=True) | ||||
# expect no exception raised | # expect no exception raised | ||||
def test_apply_cannot_modify_instance(self): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||||
def modify_inplace(instance): | |||||
instance['words'] = 1 | |||||
with self.assertRaises(TypeError): | |||||
ds.apply(modify_inplace) | |||||
def test_drop(self): | def test_drop(self): | ||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20}) | ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20}) | ||||
ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True) | ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True) | ||||
@@ -7,7 +7,7 @@ from fastNLP import AccuracyMetric | |||||
from fastNLP.core.metrics import _pred_topk, _accuracy_topk | from fastNLP.core.metrics import _pred_topk, _accuracy_topk | ||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
from collections import Counter | from collections import Counter | ||||
from fastNLP.core.metrics import SpanFPreRecMetric, ExtractiveQAMetric | |||||
from fastNLP.core.metrics import SpanFPreRecMetric, CMRC2018Metric | |||||
def _generate_tags(encoding_type, number_labels=4): | def _generate_tags(encoding_type, number_labels=4): | ||||
@@ -413,6 +413,29 @@ class SpanFPreRecMetricTest(unittest.TestCase): | |||||
vocab = Vocabulary().add_word_lst(list('bmes')) | vocab = Vocabulary().add_word_lst(list('bmes')) | ||||
metric = SpanFPreRecMetric(vocab, encoding_type='bmeso') | metric = SpanFPreRecMetric(vocab, encoding_type='bmeso') | ||||
class TestCMRC2018Metric(unittest.TestCase): | |||||
def test_case1(self): | |||||
# check that the metric values are computed correctly
import torch | |||||
metric = CMRC2018Metric() | |||||
raw_chars = [list("abcsdef"), list("123456s789")] | |||||
context_len = torch.LongTensor([3, 6]) | |||||
answers = [["abc", "abc", "abc"], ["12", "12", "12"]] | |||||
pred_start = torch.randn(2, max(map(len, raw_chars))) | |||||
pred_end = torch.randn(2, max(map(len, raw_chars))) | |||||
pred_start[0, 0] = 1000 # decoded span is exactly "abc"
pred_end[0, 2] = 1000 | |||||
pred_start[1, 1] = 1000 # picks out "234"
pred_end[1, 3] = 1000 | |||||
metric.evaluate(answers, raw_chars, context_len, pred_start, pred_end) | |||||
eval_res = metric.get_metric() | |||||
self.assertDictEqual(eval_res, {'f1': 70.0, 'em': 50.0}) | |||||
class TestUsefulFunctions(unittest.TestCase): | class TestUsefulFunctions(unittest.TestCase): | ||||
# test some helper functions in metrics.py that look useful | # test some helper functions in metrics.py that look useful | ||||
def test_case_1(self): | def test_case_1(self): | ||||
@@ -423,44 +446,4 @@ class TestUsefulFunctions(unittest.TestCase): | |||||
# just needs to run through without errors | # just needs to run through without errors | ||||
class TestExtractiveQAMetric(unittest.TestCase): | |||||
def test_cast_1(self): | |||||
qa_prediction = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, | |||||
-0.3782, 0.8240], | |||||
[-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, -1.1563, | |||||
-0.3562, -1.4116], | |||||
[-1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858, | |||||
-2.0023, 0.0075], | |||||
[-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186, | |||||
0.3832, -0.1540], | |||||
[-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120, | |||||
-1.3508, -0.9513], | |||||
[1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, | |||||
-0.0842, -0.4294]], | |||||
[[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, | |||||
-1.4138, -0.8853], | |||||
[-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, | |||||
-1.0726, 0.0364], | |||||
[0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, | |||||
-0.8836, -0.9320], | |||||
[0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, | |||||
-1.6857, 1.1571], | |||||
[1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, | |||||
3.5837, 1.0184], | |||||
[1.6495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, | |||||
-0.9025, 0.0864]]]) | |||||
qa_prediction = qa_prediction.permute(1, 2, 0) | |||||
pred1, pred2 = qa_prediction.split(1, dim=-1) | |||||
pred1 = pred1.squeeze(-1) | |||||
pred2 = pred2.squeeze(-1) | |||||
target1 = torch.LongTensor([3, 0, 2, 4, 4, 0]) | |||||
target2 = torch.LongTensor([4, 1, 6, 8, 7, 1]) | |||||
metric = ExtractiveQAMetric() | |||||
metric.evaluate(pred1, pred2, target1, target2) | |||||
result = metric.get_metric() | |||||
truth = {'EM': 62.5, 'f_1': 72.5, 'noAns-f_1': 50.0, 'noAns-EM': 50.0, 'hasAns-f_1': 95.0, 'hasAns-EM': 75.0} | |||||
for k, v in truth.items(): | |||||
self.assertTrue(k in result) | |||||
self.assertEqual(v, result[k]) | |||||
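The expected values in TestCMRC2018Metric.test_case1 follow from character-level overlap between the decoded span and the gold answer. The sketch below re-derives them; it is a back-of-the-envelope version of the scoring rule, not the library implementation (the real metric also takes the max over all gold answers, which are identical in the test):

.. code-block:: python

    from collections import Counter

    def char_f1(pred: str, gold: str) -> float:
        # character-level F1 between a predicted and a gold answer string
        overlap = sum((Counter(pred) & Counter(gold)).values())
        if overlap == 0:
            return 0.0
        precision, recall = overlap / len(pred), overlap / len(gold)
        return 2 * precision * recall / (precision + recall)

    # sample 0: "abc" vs "abc" -> f1 = 1.0, em = 1
    # sample 1: "234" vs "12"  -> f1 = 2*(1/3)*(1/2)/(1/3+1/2) = 0.4, em = 0
    # batch averages: f1 = 70.0, em = 50.0, matching the assertion above
    print(char_f1("abc", "abc"), char_f1("234", "12"))  # 1.0 0.4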
@@ -0,0 +1,155 @@ | |||||
{ | |||||
"version": "v1.0", | |||||
"data": [ | |||||
{ | |||||
"paragraphs": [ | |||||
{ | |||||
"id": "DEV_0", | |||||
"context": "《战国无双3》()是由光荣和ω-force开发的战国无双系列的正统第三续作。本作以三大故事为主轴,分别是以武田信玄等人为主的《关东三国志》,织田信长等人为主的《战国三杰》,石田三成等人为主的《关原的年轻武者》,丰富游戏内的剧情。此部份专门介绍角色,欲知武器情报、奥义字或擅长攻击类型等,请至战国无双系列1.由于乡里大辅先生因故去世,不得不寻找其他声优接手。从猛将传 and Z开始。2.战国无双 编年史的原创男女主角亦有专属声优。此模式是任天堂游戏谜之村雨城改编的新增模式。本作中共有20张战场地图(不含村雨城),后来发行的猛将传再新增3张战场地图。但游戏内战役数量繁多,部分地图会有兼用的状况,战役虚实则是以光荣发行的2本「战国无双3 人物真书」内容为主,以下是相关介绍。(注:前方加☆者为猛将传新增关卡及地图。)合并本篇和猛将传的内容,村雨城模式剔除,战国史模式可直接游玩。主打两大模式「战史演武」&「争霸演武」。系列作品外传作品", | |||||
"qas": [ | |||||
{ | |||||
"question": "《战国无双3》是由哪两个公司合作开发的?", | |||||
"id": "DEV_0_QUERY_0", | |||||
"answers": [ | |||||
{ | |||||
"text": "光荣和ω-force", | |||||
"answer_start": 11 | |||||
}, | |||||
{ | |||||
"text": "光荣和ω-force", | |||||
"answer_start": 11 | |||||
}, | |||||
{ | |||||
"text": "光荣和ω-force", | |||||
"answer_start": 11 | |||||
} | |||||
] | |||||
}, | |||||
{ | |||||
"question": "男女主角亦有专属声优这一模式是由谁改编的?", | |||||
"id": "DEV_0_QUERY_1", | |||||
"answers": [ | |||||
{ | |||||
"text": "村雨城", | |||||
"answer_start": 226 | |||||
}, | |||||
{ | |||||
"text": "村雨城", | |||||
"answer_start": 226 | |||||
}, | |||||
{ | |||||
"text": "任天堂游戏谜之村雨城", | |||||
"answer_start": 219 | |||||
} | |||||
] | |||||
}, | |||||
{ | |||||
"question": "战国史模式主打哪两个模式?", | |||||
"id": "DEV_0_QUERY_2", | |||||
"answers": [ | |||||
{ | |||||
"text": "「战史演武」&「争霸演武」", | |||||
"answer_start": 395 | |||||
}, | |||||
{ | |||||
"text": "「战史演武」&「争霸演武」", | |||||
"answer_start": 395 | |||||
}, | |||||
{ | |||||
"text": "「战史演武」&「争霸演武」", | |||||
"answer_start": 395 | |||||
} | |||||
] | |||||
} | |||||
] | |||||
} | |||||
], | |||||
"id": "DEV_0", | |||||
"title": "战国无双3" | |||||
}, | |||||
{ | |||||
"paragraphs": [ | |||||
{ | |||||
"id": "DEV_1", | |||||
"context": "锣鼓经是大陆传统器乐及戏曲里面常用的打击乐记谱方法,以中文字的声音模拟敲击乐的声音,纪录打击乐的各种不同的演奏方法。常用的节奏型称为「锣鼓点」。而锣鼓是戏曲节奏的支柱,除了加强演员身段动作的节奏感,也作为音乐的引子和尾声,提示音乐的板式和速度,以及作为唱腔和念白的伴奏,令诗句的韵律更加抑扬顿锉,段落分明。锣鼓的运用有约定俗成的程式,依照角色行当的身份、性格、情绪以及环境,配合相应的锣鼓点。锣鼓亦可以模仿大自然的音响效果,如雷电、波浪等等。戏曲锣鼓所运用的敲击乐器主要分为鼓、锣、钹和板四类型:鼓类包括有单皮鼓(板鼓)、大鼓、大堂鼓(唐鼓)、小堂鼓、怀鼓、花盆鼓等;锣类有大锣、小锣(手锣)、钲锣、筛锣、马锣、镗锣、云锣;钹类有铙钹、大钹、小钹、水钹、齐钹、镲钹、铰子、碰钟等;打拍子用的檀板、木鱼、梆子等。因为京剧的锣鼓通常由四位乐师负责,又称为四大件,领奏的师傅称为:「鼓佬」,其职责有如西方乐队的指挥,负责控制速度以及利用各种手势提示乐师演奏不同的锣鼓点。粤剧吸收了部份京剧的锣鼓,但以木鱼和沙的代替了京剧的板和鼓,作为打拍子的主要乐器。以下是京剧、昆剧和粤剧锣鼓中乐器对应的口诀用字:", | |||||
"qas": [ | |||||
{ | |||||
"question": "锣鼓经是什么?", | |||||
"id": "DEV_1_QUERY_0", | |||||
"answers": [ | |||||
{ | |||||
"text": "大陆传统器乐及戏曲里面常用的打击乐记谱方法", | |||||
"answer_start": 4 | |||||
}, | |||||
{ | |||||
"text": "大陆传统器乐及戏曲里面常用的打击乐记谱方法", | |||||
"answer_start": 4 | |||||
}, | |||||
{ | |||||
"text": "大陆传统器乐及戏曲里面常用的打击乐记谱方法", | |||||
"answer_start": 4 | |||||
} | |||||
] | |||||
}, | |||||
{ | |||||
"question": "锣鼓经常用的节奏型称为什么?", | |||||
"id": "DEV_1_QUERY_1", | |||||
"answers": [ | |||||
{ | |||||
"text": "锣鼓点", | |||||
"answer_start": 67 | |||||
}, | |||||
{ | |||||
"text": "锣鼓点", | |||||
"answer_start": 67 | |||||
}, | |||||
{ | |||||
"text": "锣鼓点", | |||||
"answer_start": 67 | |||||
} | |||||
] | |||||
}, | |||||
{ | |||||
"question": "锣鼓经运用的程式是什么?", | |||||
"id": "DEV_1_QUERY_2", | |||||
"answers": [ | |||||
{ | |||||
"text": "依照角色行当的身份、性格、情绪以及环境,配合相应的锣鼓点。", | |||||
"answer_start": 167 | |||||
}, | |||||
{ | |||||
"text": "依照角色行当的身份、性格、情绪以及环境,配合相应的锣鼓点。", | |||||
"answer_start": 167 | |||||
}, | |||||
{ | |||||
"text": "依照角色行当的身份、性格、情绪以及环境,配合相应的锣鼓点", | |||||
"answer_start": 167 | |||||
} | |||||
] | |||||
}, | |||||
{ | |||||
"question": "戏曲锣鼓所运用的敲击乐器主要有什么类型?", | |||||
"id": "DEV_1_QUERY_3", | |||||
"answers": [ | |||||
{ | |||||
"text": "鼓、锣、钹和板", | |||||
"answer_start": 237 | |||||
}, | |||||
{ | |||||
"text": "鼓、锣、钹和板", | |||||
"answer_start": 237 | |||||
}, | |||||
{ | |||||
"text": "鼓、锣、钹和板", | |||||
"answer_start": 237 | |||||
} | |||||
] | |||||
} | |||||
] | |||||
} | |||||
], | |||||
"id": "DEV_1", | |||||
"title": "锣鼓经" | |||||
} | |||||
] | |||||
} |
@@ -0,0 +1,161 @@ | |||||
{ | |||||
"version": "v1.0", | |||||
"data": [ | |||||
{ | |||||
"paragraphs": [ | |||||
{ | |||||
"id": "TRAIN_186", | |||||
"context": "范廷颂枢机(,),圣名保禄·若瑟(),是越南罗马天主教枢机。1963年被任为主教;1990年被擢升为天主教河内总教区宗座署理;1994年被擢升为总主教,同年年底被擢升为枢机;2009年2月离世。范廷颂于1919年6月15日在越南宁平省天主教发艳教区出生;童年时接受良好教育后,被一位越南神父带到河内继续其学业。范廷颂于1940年在河内大修道院完成神学学业。范廷颂于1949年6月6日在河内的主教座堂晋铎;及后被派到圣女小德兰孤儿院服务。1950年代,范廷颂在河内堂区创建移民接待中心以收容到河内避战的难民。1954年,法越战争结束,越南民主共和国建都河内,当时很多天主教神职人员逃至越南的南方,但范廷颂仍然留在河内。翌年管理圣若望小修院;惟在1960年因捍卫修院的自由、自治及拒绝政府在修院设政治课的要求而被捕。1963年4月5日,教宗任命范廷颂为天主教北宁教区主教,同年8月15日就任;其牧铭为「我信天主的爱」。由于范廷颂被越南政府软禁差不多30年,因此他无法到所属堂区进行牧灵工作而专注研读等工作。范廷颂除了面对战争、贫困、被当局迫害天主教会等问题外,也秘密恢复修院、创建女修会团体等。1990年,教宗若望保禄二世在同年6月18日擢升范廷颂为天主教河内总教区宗座署理以填补该教区总主教的空缺。1994年3月23日,范廷颂被教宗若望保禄二世擢升为天主教河内总教区总主教并兼天主教谅山教区宗座署理;同年11月26日,若望保禄二世擢升范廷颂为枢机。范廷颂在1995年至2001年期间出任天主教越南主教团主席。2003年4月26日,教宗若望保禄二世任命天主教谅山教区兼天主教高平教区吴光杰主教为天主教河内总教区署理主教;及至2005年2月19日,范廷颂因获批辞去总主教职务而荣休;吴光杰同日真除天主教河内总教区总主教职务。范廷颂于2009年2月22日清晨在河内离世,享年89岁;其葬礼于同月26日上午在天主教河内总教区总主教座堂举行。", | |||||
"qas": [ | |||||
{ | |||||
"question": "范廷颂是什么时候被任为主教的?", | |||||
"id": "TRAIN_186_QUERY_0", | |||||
"answers": [ | |||||
{ | |||||
"text": "1963年", | |||||
"answer_start": 30 | |||||
} | |||||
] | |||||
}, | |||||
{ | |||||
"question": "1990年,范廷颂担任什么职务?", | |||||
"id": "TRAIN_186_QUERY_1", | |||||
"answers": [ | |||||
{ | |||||
"text": "1990年被擢升为天主教河内总教区宗座署理", | |||||
"answer_start": 41 | |||||
} | |||||
] | |||||
}, | |||||
{ | |||||
"question": "范廷颂是于何时何地出生的?", | |||||
"id": "TRAIN_186_QUERY_2", | |||||
"answers": [ | |||||
{ | |||||
"text": "范廷颂于1919年6月15日在越南宁平省天主教发艳教区出生", | |||||
"answer_start": 97 | |||||
} | |||||
] | |||||
}, | |||||
{ | |||||
"question": "1994年3月,范廷颂担任什么职务?", | |||||
"id": "TRAIN_186_QUERY_3", | |||||
"answers": [ | |||||
{ | |||||
"text": "1994年3月23日,范廷颂被教宗若望保禄二世擢升为天主教河内总教区总主教并兼天主教谅山教区宗座署理", | |||||
"answer_start": 548 | |||||
} | |||||
] | |||||
}, | |||||
{ | |||||
"question": "范廷颂是何时去世的?", | |||||
"id": "TRAIN_186_QUERY_4", | |||||
"answers": [ | |||||
{ | |||||
"text": "范廷颂于2009年2月22日清晨在河内离世", | |||||
"answer_start": 759 | |||||
} | |||||
] | |||||
} | |||||
] | |||||
} | |||||
], | |||||
"id": "TRAIN_186", | |||||
"title": "范廷颂" | |||||
}, | |||||
{ | |||||
"paragraphs": [ | |||||
{ | |||||
"id": "TRAIN_54", | |||||
"context": "安雅·罗素法(,),来自俄罗斯圣彼得堡的模特儿。她是《全美超级模特儿新秀大赛》第十季的亚军。2008年,安雅宣布改回出生时的名字:安雅·罗素法(Anya Rozova),在此之前是使用安雅·冈()。安雅于俄罗斯出生,后来被一个居住在美国夏威夷群岛欧胡岛檀香山的家庭领养。安雅十七岁时曾参与香奈儿、路易·威登及芬迪(Fendi)等品牌的非正式时装秀。2007年,她于瓦伊帕胡高级中学毕业。毕业后,她当了一名售货员。她曾为Russell Tanoue拍摄照片,Russell Tanoue称赞她是「有前途的新面孔」。安雅在半准决赛面试时说她对模特儿行业充满热诚,所以参加全美超级模特儿新秀大赛。她于比赛中表现出色,曾五次首名入围,平均入围顺序更拿下历届以来最优异的成绩(2.64),另外胜出三次小挑战,分别获得与评判尼祖·百克拍照、为柠檬味道的七喜拍摄广告的机会及十万美元、和盖马蒂洛(Gai Mattiolo)设计的晚装。在最后两强中,安雅与另一名参赛者惠妮·汤姆森为范思哲走秀,但评判认为她在台上不够惠妮突出,所以选了惠妮当冠军,安雅屈居亚军(但就整体表现来说,部份网友认为安雅才是第十季名副其实的冠军。)安雅在比赛拿五次第一,也胜出多次小挑战。安雅赛后再次与Russell Tanoue合作,为2008年4月30日出版的MidWeek杂志拍摄封面及内页照。其后她参加了V杂志与Supreme模特儿公司合办的模特儿选拔赛2008。她其后更与Elite签约。最近她与香港的模特儿公司 Style International Management 签约,并在香港发展其模特儿事业。她曾在很多香港的时装杂志中任模特儿,《Jet》、《东方日报》、《Elle》等。", | |||||
"qas": [ | |||||
{ | |||||
"question": "安雅·罗素法参加了什么比赛获得了亚军?", | |||||
"id": "TRAIN_54_QUERY_0", | |||||
"answers": [ | |||||
{ | |||||
"text": "《全美超级模特儿新秀大赛》第十季", | |||||
"answer_start": 26 | |||||
} | |||||
] | |||||
}, | |||||
{ | |||||
"question": "Russell Tanoue对安雅·罗素法的评价是什么?", | |||||
"id": "TRAIN_54_QUERY_1", | |||||
"answers": [ | |||||
{ | |||||
"text": "有前途的新面孔", | |||||
"answer_start": 247 | |||||
} | |||||
] | |||||
}, | |||||
{ | |||||
"question": "安雅·罗素法合作过的香港杂志有哪些?", | |||||
"id": "TRAIN_54_QUERY_2", | |||||
"answers": [ | |||||
{ | |||||
"text": "《Jet》、《东方日报》、《Elle》等", | |||||
"answer_start": 706 | |||||
} | |||||
] | |||||
}, | |||||
{ | |||||
"question": "毕业后的安雅·罗素法职业是什么?", | |||||
"id": "TRAIN_54_QUERY_3", | |||||
"answers": [ | |||||
{ | |||||
"text": "售货员", | |||||
"answer_start": 202 | |||||
} | |||||
] | |||||
} | |||||
] | |||||
} | |||||
], | |||||
"id": "TRAIN_54", | |||||
"title": "安雅·罗素法" | |||||
}, | |||||
{ | |||||
"paragraphs": [ | |||||
{ | |||||
"id": "TRAIN_756", | |||||
"context": "为日本漫画足球小将翼的一个角色,自小父母离异,与父亲一起四处为家,每个地方也是待一会便离开,但他仍然能够保持优秀的学业成绩。在第一次南葛市生活时,与同样就读于南葛小学的大空翼为黄金拍档,曾效力球队包括南葛小学、南葛高中、日本少年队、日本青年军、日本奥运队。效力日本青年军期间,因救同母异父的妹妹导致被车撞至断脚,在决赛周只在决赛的下半场十五分钟开始上场,成为日本队夺得世青冠军的其中一名功臣。基本资料绰号:球场上的艺术家出身地:日本南葛市诞生日:5月5日星座:金牛座球衣号码:11担任位置:中场、攻击中场、右中场擅长脚:右脚所属队伍:盘田山叶故事发展岬太郎在小学期间不断转换学校,在南葛小学就读时在全国大赛中夺得冠军;国中三年随父亲孤单地在法国留学;回国后三年的高中生涯一直输给日本王牌射手日向小次郎率领的东邦学院。在【Golden 23】年代,大空翼、日向小次郎等名将均转战海外,他与松山光、三杉淳组成了「3M」组合(松山光Hikaru Matsuyama、岬太郎Taro Misaki、三杉淳Jyun Misugi)。必杀技1. 回力刀射门2. S. S. S. 射门3. 双人射门(与大空翼合作)", | |||||
"qas": [ | |||||
{ | |||||
"question": "岬太郎在第一次南葛市生活时的搭档是谁?", | |||||
"id": "TRAIN_756_QUERY_0", | |||||
"answers": [ | |||||
{ | |||||
"text": "大空翼", | |||||
"answer_start": 84 | |||||
} | |||||
] | |||||
}, | |||||
{ | |||||
"question": "日本队夺得世青冠军,岬太郎发挥了什么作用?", | |||||
"id": "TRAIN_756_QUERY_1", | |||||
"answers": [ | |||||
{ | |||||
"text": "在决赛周只在决赛的下半场十五分钟开始上场,成为日本队夺得世青冠军的其中一名功臣。", | |||||
"answer_start": 156 | |||||
} | |||||
] | |||||
}, | |||||
{ | |||||
"question": "岬太郎与谁一起组成了「3M」组合?", | |||||
"id": "TRAIN_756_QUERY_2", | |||||
"answers": [ | |||||
{ | |||||
"text": "他与松山光、三杉淳组成了「3M」组合(松山光Hikaru Matsuyama、岬太郎Taro Misaki、三杉淳Jyun Misugi)。", | |||||
"answer_start": 391 | |||||
} | |||||
] | |||||
} | |||||
] | |||||
} | |||||
], | |||||
"id": "TRAIN_756", | |||||
"title": "岬太郎" | |||||
} | |||||
] | |||||
} |
@@ -0,0 +1,14 @@ | |||||
import unittest | |||||
from fastNLP.io.loader.qa import CMRC2018Loader | |||||
class TestCMRC2018Loader(unittest.TestCase): | |||||
def test__load(self): | |||||
loader = CMRC2018Loader() | |||||
dataset = loader._load('test/data_for_tests/io/cmrc/train.json') | |||||
print(dataset) | |||||
def test_load(self): | |||||
loader = CMRC2018Loader() | |||||
data_bundle = loader.load('test/data_for_tests/io/cmrc/') | |||||
print(data_bundle) |
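A slightly stronger variant of the smoke tests above could also check the per-instance fields; the field names are taken from CMRC2018BertPipe.process's docstring, and the membership test mirrors the `'answer_starts' in ins` pattern used in the pipe itself:

.. code-block:: python

    # sketch only: assert the loader emits the fields the pipe expects
    dataset = CMRC2018Loader()._load('test/data_for_tests/io/cmrc/train.json')
    for ins in dataset:
        assert 'context' in ins and 'question' in ins
        assert 'answers' in ins and 'answer_starts' in ins
        break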
@@ -0,0 +1,24 @@ | |||||
import unittest | |||||
from fastNLP.io.pipe.qa import CMRC2018BertPipe | |||||
from fastNLP.io.loader.qa import CMRC2018Loader | |||||
class CMRC2018PipeTest(unittest.TestCase): | |||||
def test_process(self): | |||||
data_bundle = CMRC2018Loader().load('test/data_for_tests/io/cmrc/') | |||||
pipe = CMRC2018BertPipe() | |||||
data_bundle = pipe.process(data_bundle) | |||||
for name, dataset in data_bundle.iter_datasets(): | |||||
for ins in dataset: | |||||
if 'target_start' in ins: | |||||
# the answer extracted via target_start/target_end matches the gold answer
start_index = ins['target_start'] | |||||
end_index = ins['target_end']+1 | |||||
extract_answer = ''.join(ins['raw_chars'][start_index:end_index]) | |||||
self.assertEqual(extract_answer, ins['answers'][0]) | |||||
# check that context_len is correct
raw_chars = ins['raw_chars'] | |||||
expect_len = raw_chars.index('[SEP]') | |||||
self.assertEqual(expect_len, ins['context_len']) |
@@ -107,41 +107,37 @@ class TestBert(unittest.TestCase): | |||||
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 3)) | self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 3)) | ||||
def test_bert_4(self): | def test_bert_4(self): | ||||
vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split()) | vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split()) | ||||
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', | embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', | ||||
include_cls_sep=True) | |||||
include_cls_sep=False) | |||||
model = BertForQuestionAnswering(embed) | model = BertForQuestionAnswering(embed) | ||||
input_ids = torch.LongTensor([[1, 2, 3], [6, 5, 0]]) | input_ids = torch.LongTensor([[1, 2, 3], [6, 5, 0]]) | ||||
pred = model(input_ids) | pred = model(input_ids) | ||||
self.assertTrue(isinstance(pred, dict)) | self.assertTrue(isinstance(pred, dict)) | ||||
self.assertTrue(Const.OUTPUTS(0) in pred) | |||||
self.assertTrue(Const.OUTPUTS(1) in pred) | |||||
self.assertEqual(tuple(pred[Const.OUTPUTS(0)].shape), (2, 5)) | |||||
self.assertEqual(tuple(pred[Const.OUTPUTS(1)].shape), (2, 5)) | |||||
self.assertTrue('pred_start' in pred) | |||||
self.assertTrue('pred_end' in pred) | |||||
self.assertEqual(tuple(pred['pred_start'].shape), (2, 3)) | |||||
self.assertEqual(tuple(pred['pred_end'].shape), (2, 3)) | |||||
model = BertForQuestionAnswering(embed, 7) | |||||
pred = model(input_ids) | |||||
self.assertTrue(isinstance(pred, dict)) | |||||
self.assertEqual(len(pred), 7) | |||||
def test_bert_for_question_answering_train(self): | |||||
from fastNLP import CMRC2018Loss | |||||
from fastNLP.io import CMRC2018BertPipe | |||||
from fastNLP import Trainer | |||||
def test_bert_4_w(self): | |||||
data_bundle = CMRC2018BertPipe().process_from_file('test/data_for_tests/io/cmrc') | |||||
data_bundle.rename_field('chars', 'words') | |||||
train_data = data_bundle.get_dataset('train') | |||||
vocab = data_bundle.get_vocab('words') | |||||
vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split()) | |||||
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', | embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', | ||||
include_cls_sep=False) | |||||
with self.assertWarns(Warning): | |||||
model = BertForQuestionAnswering(embed) | |||||
input_ids = torch.LongTensor([[1, 2, 3], [6, 5, 0]]) | |||||
include_cls_sep=False, auto_truncate=True) | |||||
model = BertForQuestionAnswering(embed) | |||||
loss = CMRC2018Loss() | |||||
pred = model.predict(input_ids) | |||||
self.assertTrue(isinstance(pred, dict)) | |||||
self.assertTrue(Const.OUTPUTS(1) in pred) | |||||
self.assertEqual(tuple(pred[Const.OUTPUTS(1)].shape), (2,)) | |||||
trainer = Trainer(train_data, model, loss=loss, use_tqdm=False) | |||||
trainer.train(load_best_model=False) | |||||
def test_bert_5(self): | def test_bert_5(self): | ||||