@@ -24,7 +24,7 @@ from .utils import seq_len_to_mask | |||
from .vocabulary import Vocabulary | |||
from abc import abstractmethod | |||
import warnings | |||
from typing import Union | |||
class MetricBase(object): | |||
""" | |||
@@ -337,15 +337,18 @@ class AccuracyMetric(MetricBase): | |||
raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(seq_len)}.") | |||
if seq_len is not None: | |||
masks = seq_len_to_mask(seq_len=seq_len) | |||
if seq_len is not None and target.dim()>1: | |||
max_len = target.size(1) | |||
masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len) | |||
else: | |||
masks = None | |||
if pred.size() == target.size(): | |||
if pred.dim() == target.dim(): | |||
pass | |||
elif len(pred.size()) == len(target.size()) + 1: | |||
elif pred.dim() == target.dim() + 1: | |||
pred = pred.argmax(dim=-1) | |||
if seq_len is None: | |||
warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") | |||
else: | |||
raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " | |||
f"size:{pred.size()}, target should have size: {pred.size()} or " | |||
@@ -493,20 +496,63 @@ def _bio_tag_to_spans(tags, ignore_labels=None): | |||
return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels] | |||
def _check_tag_vocab_and_encoding_type(vocab:Vocabulary, encoding_type:str): | |||
def _get_encoding_type_from_tag_vocab(tag_vocab:Union[Vocabulary, dict])->str: | |||
""" | |||
给定Vocabulary自动判断是哪种类型的encoding, 支持判断bmes, bioes, bmeso, bio | |||
:param tag_vocab: 支持传入tag Vocabulary; 或者传入形如{0:"O", 1:"B-tag1"},即index在前,tag在后的dict。 | |||
:return: | |||
""" | |||
tag_set = set() | |||
unk_token = '<unk>' | |||
pad_token = '<pad>' | |||
if isinstance(tag_vocab, Vocabulary): | |||
unk_token = tag_vocab.unknown | |||
pad_token = tag_vocab.padding | |||
tag_vocab = tag_vocab.idx2word | |||
for idx, tag in tag_vocab.items(): | |||
if tag in (unk_token, pad_token): | |||
continue | |||
tag = tag[:1].lower() | |||
tag_set.add(tag) | |||
bmes_tag_set = set('bmes') | |||
if tag_set == bmes_tag_set: | |||
return 'bmes' | |||
bio_tag_set = set('bio') | |||
if tag_set == bio_tag_set: | |||
return 'bio' | |||
bmeso_tag_set = set('bmeso') | |||
if tag_set == bmeso_tag_set: | |||
return 'bmeso' | |||
bioes_tag_set = set('bioes') | |||
if tag_set == bioes_tag_set: | |||
return 'bioes' | |||
raise RuntimeError("encoding_type cannot be inferred automatically. Only support " | |||
"'bio', 'bmes', 'bmeso', 'bioes' type.") | |||
def _check_tag_vocab_and_encoding_type(tag_vocab:Union[Vocabulary, dict], encoding_type:str): | |||
""" | |||
检查vocab中的tag是否与encoding_type是匹配的 | |||
:param vocab: target的Vocabulary | |||
:param tag_vocab: 支持传入tag Vocabulary; 或者传入形如{0:"O", 1:"B-tag1"},即index在前,tag在后的dict。 | |||
:param encoding_type: bio, bmes, bioes, bmeso | |||
:return: | |||
""" | |||
tag_set = set() | |||
for tag, idx in vocab: | |||
if idx in (vocab.unknown_idx, vocab.padding_idx): | |||
unk_token = '<unk>' | |||
pad_token = '<pad>' | |||
if isinstance(tag_vocab, Vocabulary): | |||
unk_token = tag_vocab.unknown | |||
pad_token = tag_vocab.padding | |||
tag_vocab = tag_vocab.idx2word | |||
for idx, tag in tag_vocab.items(): | |||
if tag in (unk_token, pad_token): | |||
continue | |||
tag = tag[:1].lower() | |||
tag_set.add(tag) | |||
tags = encoding_type | |||
for tag in tag_set: | |||
assert tag in tags, f"{tag} is not a valid tag in encoding type:{encoding_type}. Please check your " \ | |||
@@ -549,7 +595,7 @@ class SpanFPreRecMetric(MetricBase): | |||
:param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用 `pred` 取数据 | |||
:param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用 `target` 取数据 | |||
:param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用 `seq_len` 取数据。 | |||
:param str encoding_type: 目前支持bio, bmes, bmeso, bioes | |||
:param str encoding_type: 目前支持bio, bmes, bmeso, bioes。默认为None,通过tag_vocab自动判断. | |||
:param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'这 | |||
个label | |||
:param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个 | |||
@@ -560,18 +606,21 @@ class SpanFPreRecMetric(MetricBase): | |||
常用为beta=0.5, 1, 2. 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||
""" | |||
def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type='bio', ignore_labels=None, | |||
def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type=None, ignore_labels=None, | |||
only_gross=True, f_type='micro', beta=1): | |||
encoding_type = encoding_type.lower() | |||
if not isinstance(tag_vocab, Vocabulary): | |||
raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) | |||
if f_type not in ('micro', 'macro'): | |||
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | |||
self.encoding_type = encoding_type | |||
_check_tag_vocab_and_encoding_type(tag_vocab, encoding_type) | |||
if encoding_type: | |||
encoding_type = encoding_type.lower() | |||
_check_tag_vocab_and_encoding_type(tag_vocab, encoding_type) | |||
self.encoding_type = encoding_type | |||
else: | |||
self.encoding_type = _get_encoding_type_from_tag_vocab(tag_vocab) | |||
if self.encoding_type == 'bmes': | |||
self.tag_to_span_func = _bmes_tag_to_spans | |||
elif self.encoding_type == 'bio': | |||
@@ -581,7 +630,7 @@ class SpanFPreRecMetric(MetricBase): | |||
elif self.encoding_type == 'bioes': | |||
self.tag_to_span_func = _bioes_tag_to_spans | |||
else: | |||
raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.") | |||
raise ValueError("Only support 'bio', 'bmes', 'bmeso', 'bioes' type.") | |||
self.ignore_labels = ignore_labels | |||
self.f_type = f_type | |||
@@ -39,7 +39,7 @@ def _check_build_vocab(func): | |||
@wraps(func) # to solve missing docstring | |||
def _wrapper(self, *args, **kwargs): | |||
if self.word2idx is None or self.rebuild is True: | |||
if self._word2idx is None or self.rebuild is True: | |||
self.build_vocab() | |||
return func(self, *args, **kwargs) | |||
@@ -95,12 +95,30 @@ class Vocabulary(object): | |||
self.word_count = Counter() | |||
self.unknown = unknown | |||
self.padding = padding | |||
self.word2idx = None | |||
self.idx2word = None | |||
self._word2idx = None | |||
self._idx2word = None | |||
self.rebuild = True | |||
# 用于承载不需要单独创建entry的词语,具体见from_dataset()方法 | |||
self._no_create_word = Counter() | |||
@property | |||
@_check_build_vocab | |||
def word2idx(self): | |||
return self._word2idx | |||
@word2idx.setter | |||
def word2idx(self, value): | |||
self._word2idx = value | |||
@property | |||
@_check_build_vocab | |||
def idx2word(self): | |||
return self._idx2word | |||
@idx2word.setter | |||
def idx2word(self, value): | |||
self._word2idx = value | |||
@_check_build_status | |||
def update(self, word_lst, no_create_entry=False): | |||
"""依次增加序列中词在词典中的出现频率 | |||
@@ -187,21 +205,21 @@ class Vocabulary(object): | |||
但已经记录在词典中的词, 不会改变对应的 `int` | |||
""" | |||
if self.word2idx is None: | |||
self.word2idx = {} | |||
if self._word2idx is None: | |||
self._word2idx = {} | |||
if self.padding is not None: | |||
self.word2idx[self.padding] = len(self.word2idx) | |||
self._word2idx[self.padding] = len(self._word2idx) | |||
if self.unknown is not None: | |||
self.word2idx[self.unknown] = len(self.word2idx) | |||
self._word2idx[self.unknown] = len(self._word2idx) | |||
max_size = min(self.max_size, len(self.word_count)) if self.max_size else None | |||
words = self.word_count.most_common(max_size) | |||
if self.min_freq is not None: | |||
words = filter(lambda kv: kv[1] >= self.min_freq, words) | |||
if self.word2idx is not None: | |||
words = filter(lambda kv: kv[0] not in self.word2idx, words) | |||
start_idx = len(self.word2idx) | |||
self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) | |||
if self._word2idx is not None: | |||
words = filter(lambda kv: kv[0] not in self._word2idx, words) | |||
start_idx = len(self._word2idx) | |||
self._word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) | |||
self.build_reverse_vocab() | |||
self.rebuild = False | |||
return self | |||
@@ -211,12 +229,12 @@ class Vocabulary(object): | |||
基于 `word to index` dict, 构建 `index to word` dict. | |||
""" | |||
self.idx2word = {i: w for w, i in self.word2idx.items()} | |||
self._idx2word = {i: w for w, i in self._word2idx.items()} | |||
return self | |||
@_check_build_vocab | |||
def __len__(self): | |||
return len(self.word2idx) | |||
return len(self._word2idx) | |||
@_check_build_vocab | |||
def __contains__(self, item): | |||
@@ -226,7 +244,7 @@ class Vocabulary(object): | |||
:param item: the word | |||
:return: True or False | |||
""" | |||
return item in self.word2idx | |||
return item in self._word2idx | |||
def has_word(self, w): | |||
""" | |||
@@ -248,10 +266,10 @@ class Vocabulary(object): | |||
vocab[w] | |||
""" | |||
if w in self.word2idx: | |||
return self.word2idx[w] | |||
if w in self._word2idx: | |||
return self._word2idx[w] | |||
if self.unknown is not None: | |||
return self.word2idx[self.unknown] | |||
return self._word2idx[self.unknown] | |||
else: | |||
raise ValueError("word `{}` not in vocabulary".format(w)) | |||
@@ -405,7 +423,7 @@ class Vocabulary(object): | |||
""" | |||
if self.unknown is None: | |||
return None | |||
return self.word2idx[self.unknown] | |||
return self._word2idx[self.unknown] | |||
@property | |||
@_check_build_vocab | |||
@@ -415,7 +433,7 @@ class Vocabulary(object): | |||
""" | |||
if self.padding is None: | |||
return None | |||
return self.word2idx[self.padding] | |||
return self._word2idx[self.padding] | |||
@_check_build_vocab | |||
def to_word(self, idx): | |||
@@ -425,7 +443,7 @@ class Vocabulary(object): | |||
:param int idx: the index | |||
:return str word: the word | |||
""" | |||
return self.idx2word[idx] | |||
return self._idx2word[idx] | |||
def clear(self): | |||
""" | |||
@@ -434,8 +452,8 @@ class Vocabulary(object): | |||
:return: | |||
""" | |||
self.word_count.clear() | |||
self.word2idx = None | |||
self.idx2word = None | |||
self._word2idx = None | |||
self._idx2word = None | |||
self.rebuild = True | |||
self._no_create_word.clear() | |||
return self | |||
@@ -446,8 +464,8 @@ class Vocabulary(object): | |||
""" | |||
len(self) # make sure vocab has been built | |||
state = self.__dict__.copy() | |||
# no need to pickle idx2word as it can be constructed from word2idx | |||
del state['idx2word'] | |||
# no need to pickle _idx2word as it can be constructed from _word2idx | |||
del state['_idx2word'] | |||
return state | |||
def __setstate__(self, state): | |||
@@ -462,5 +480,5 @@ class Vocabulary(object): | |||
@_check_build_vocab | |||
def __iter__(self): | |||
for word, index in self.word2idx.items(): | |||
for word, index in self._word2idx.items(): | |||
yield word, index |
@@ -8,7 +8,7 @@ __all__ = [ | |||
from ..core.dataset import DataSet | |||
from ..core.vocabulary import Vocabulary | |||
from typing import Union | |||
class DataBundle: | |||
""" | |||
@@ -191,7 +191,7 @@ class DataBundle: | |||
raise KeyError(f"{field_name} not found DataSet:{name}.") | |||
return self | |||
def rename_field(self, field_name, new_field_name, ignore_miss_dataset=True): | |||
def rename_field(self, field_name, new_field_name, ignore_miss_dataset=True, rename_vocab=True): | |||
""" | |||
将DataBundle中所有DataSet中名为field_name的field重命名为new_field_name. | |||
@@ -199,6 +199,7 @@ class DataBundle: | |||
:param str new_field_name: | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:param bool rename_vocab: 如果该field同时也存在于vocabs中,会将该field的名称对应修改 | |||
:return: self | |||
""" | |||
for name, dataset in self.datasets.items(): | |||
@@ -206,15 +207,20 @@ class DataBundle: | |||
dataset.rename_field(field_name=field_name, new_field_name=new_field_name) | |||
elif not ignore_miss_dataset: | |||
raise KeyError(f"{field_name} not found DataSet:{name}.") | |||
if rename_vocab: | |||
if field_name in self.vocabs: | |||
self.vocabs[new_field_name] = self.vocabs.pop(field_name) | |||
return self | |||
def delete_field(self, field_name, ignore_miss_dataset=True): | |||
def delete_field(self, field_name, ignore_miss_dataset=True, delete_vocab=True): | |||
""" | |||
将DataBundle中所有DataSet中名为field_name的field删除掉. | |||
:param str field_name: | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:param bool delete_vocab: 如果该field也在vocabs中存在,将该值也一并删除 | |||
:return: self | |||
""" | |||
for name, dataset in self.datasets.items(): | |||
@@ -222,8 +228,39 @@ class DataBundle: | |||
dataset.delete_field(field_name=field_name) | |||
elif not ignore_miss_dataset: | |||
raise KeyError(f"{field_name} not found DataSet:{name}.") | |||
if delete_vocab: | |||
if field_name in self.vocabs: | |||
self.vocabs.pop(field_name) | |||
return self | |||
def iter_datasets(self)->Union[str, DataSet]: | |||
""" | |||
迭代data_bundle中的DataSet | |||
Example:: | |||
for name, dataset in data_bundle.iter_datasets(): | |||
pass | |||
:return: | |||
""" | |||
for name, dataset in self.datasets.items(): | |||
yield name, dataset | |||
def iter_vocabs(self)->Union[str, Vocabulary]: | |||
""" | |||
迭代data_bundle中的DataSet | |||
Example: | |||
for field_name, vocab in data_bundle.iter_vocabs(): | |||
pass | |||
:return: | |||
""" | |||
for field_name, vocab in self.vocabs.items(): | |||
yield field_name, vocab | |||
def apply_field(self, func, field_name:str, new_field_name:str, ignore_miss_dataset=True, **kwargs): | |||
""" | |||
对DataBundle中所有的dataset使用apply_field方法 | |||
@@ -193,7 +193,7 @@ class OntoNotesNERPipe(_NERPipe): | |||
""" | |||
处理OntoNotes的NER数据,处理之后DataSet中的field情况为 | |||
.. csv-table:: Following is a demo layout of DataSet returned by Conll2003Loader | |||
.. csv-table:: | |||
:header: "raw_words", "words", "target", "seq_len" | |||
"[Nadim, Ladki]", "[2, 3]", "[1, 2]", 2 | |||
@@ -207,7 +207,7 @@ class ArcBiaffine(nn.Module): | |||
output = dep.matmul(self.U) | |||
output = output.bmm(head.transpose(-1, -2)) | |||
if self.has_bias: | |||
output += head.matmul(self.bias).unsqueeze(1) | |||
output = output + head.matmul(self.bias).unsqueeze(1) | |||
return output | |||
@@ -234,7 +234,7 @@ class LabelBilinear(nn.Module): | |||
:return output: [batch, seq_len, num_cls] 每个元素对应类别的概率图 | |||
""" | |||
output = self.bilinear(x1, x2) | |||
output += self.lin(torch.cat([x1, x2], dim=2)) | |||
output = output + self.lin(torch.cat([x1, x2], dim=2)) | |||
return output | |||
@@ -363,7 +363,7 @@ class BiaffineParser(GraphParser): | |||
# print('forward {} {}'.format(batch_size, seq_len)) | |||
# get sequence mask | |||
mask = seq_len_to_mask(seq_len).long() | |||
mask = seq_len_to_mask(seq_len, max_len=length).long() | |||
word = self.word_embedding(words1) # [N,L] -> [N,L,C_0] | |||
pos = self.pos_embedding(words2) # [N,L] -> [N,L,C_1] | |||
@@ -435,10 +435,10 @@ class BiaffineParser(GraphParser): | |||
""" | |||
batch_size, length, _ = pred1.shape | |||
mask = seq_len_to_mask(seq_len) | |||
mask = seq_len_to_mask(seq_len, max_len=length) | |||
flip_mask = (mask == 0) | |||
_arc_pred = pred1.clone() | |||
_arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf')) | |||
_arc_pred = _arc_pred.masked_fill(flip_mask.unsqueeze(1), -float('inf')) | |||
arc_logits = F.log_softmax(_arc_pred, dim=2) | |||
label_logits = F.log_softmax(pred2, dim=2) | |||
batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1) | |||
@@ -446,9 +446,8 @@ class BiaffineParser(GraphParser): | |||
arc_loss = arc_logits[batch_index, child_index, target1] | |||
label_loss = label_logits[batch_index, child_index, target2] | |||
byte_mask = flip_mask.byte() | |||
arc_loss.masked_fill_(byte_mask, 0) | |||
label_loss.masked_fill_(byte_mask, 0) | |||
arc_loss = arc_loss.masked_fill(flip_mask, 0) | |||
label_loss = label_loss.masked_fill(flip_mask, 0) | |||
arc_nll = -arc_loss.mean() | |||
label_nll = -label_loss.mean() | |||
return arc_nll + label_nll | |||
@@ -10,33 +10,45 @@ from torch import nn | |||
from ..utils import initial_parameter | |||
from ...core.vocabulary import Vocabulary | |||
from ...core.metrics import _get_encoding_type_from_tag_vocab, _check_tag_vocab_and_encoding_type | |||
from typing import Union | |||
def allowed_transitions(id2target, encoding_type='bio', include_start_end=False): | |||
def allowed_transitions(tag_vocab:Union[Vocabulary, dict], encoding_type=None, include_start_end=False): | |||
""" | |||
别名::class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.allowed_transitions` | |||
给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。 | |||
:param dict, ~fastNLP.Vocabulary id2target: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | |||
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.idx2word得到id2label。 | |||
:param str encoding_type: 支持"bio", "bmes", "bmeso", "bioes"。 | |||
:param ~fastNLP.Vocabulary,dict tag_vocab: 支持类型为tag或tag-label。只有tag的,比如"B", "M"; 也可以是"B-NN", "M-NN", | |||
tag和label之间一定要用"-"隔开。如果传入dict,格式需要形如{0:"O", 1:"B-tag1"},即index在前,tag在后。 | |||
:param str encoding_type: 支持"bio", "bmes", "bmeso", "bioes"。默认为None,通过vocab自动推断 | |||
:param bool include_start_end: 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; | |||
为True,返回的结果中会包含(start_idx, b_idx), (start_idx, o_idx), 但是不包含(start_idx, i_idx); | |||
start_idx=len(id2label), end_idx=len(id2label)+1。为False, 返回的结果中不含与开始结尾相关的内容 | |||
:return: List[Tuple(int, int)]], 内部的Tuple是可以进行跳转的(from_tag_id, to_tag_id)。 | |||
""" | |||
if isinstance(id2target, Vocabulary): | |||
id2target = id2target.idx2word | |||
num_tags = len(id2target) | |||
if encoding_type is None: | |||
encoding_type = _get_encoding_type_from_tag_vocab(tag_vocab) | |||
else: | |||
encoding_type = encoding_type.lower() | |||
_check_tag_vocab_and_encoding_type(tag_vocab, encoding_type) | |||
pad_token = '<pad>' | |||
unk_token = '<unk>' | |||
if isinstance(tag_vocab, Vocabulary): | |||
id_label_lst = list(tag_vocab.idx2word.items()) | |||
pad_token = tag_vocab.padding | |||
unk_token = tag_vocab.unknown | |||
else: | |||
id_label_lst = list(tag_vocab.items()) | |||
num_tags = len(tag_vocab) | |||
start_idx = num_tags | |||
end_idx = num_tags + 1 | |||
encoding_type = encoding_type.lower() | |||
allowed_trans = [] | |||
id_label_lst = list(id2target.items()) | |||
if include_start_end: | |||
id_label_lst += [(start_idx, 'start'), (end_idx, 'end')] | |||
def split_tag_label(from_label): | |||
from_label = from_label.lower() | |||
if from_label in ['start', 'end']: | |||
@@ -48,11 +60,11 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=False) | |||
return from_tag, from_label | |||
for from_id, from_label in id_label_lst: | |||
if from_label in ['<pad>', '<unk>']: | |||
if from_label in [pad_token, unk_token]: | |||
continue | |||
from_tag, from_label = split_tag_label(from_label) | |||
for to_id, to_label in id_label_lst: | |||
if to_label in ['<pad>', '<unk>']: | |||
if to_label in [pad_token, unk_token]: | |||
continue | |||
to_tag, to_label = split_tag_label(to_label) | |||
if _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||
@@ -11,6 +11,12 @@ from fastNLP.core.metrics import SpanFPreRecMetric, ExtractiveQAMetric | |||
def _generate_tags(encoding_type, number_labels=4): | |||
""" | |||
:param encoding_type: 例如BIOES, BMES, BIO等 | |||
:param number_labels: 多少个label,大于1 | |||
:return: | |||
""" | |||
vocab = {} | |||
for i in range(number_labels): | |||
label = str(i) | |||
@@ -184,7 +190,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||
self.assertDictEqual(metric.get_metric(), {'acc': 1.}) | |||
class SpanF1PreRecMetric(unittest.TestCase): | |||
class SpanFPreRecMetricTest(unittest.TestCase): | |||
def test_case1(self): | |||
from fastNLP.core.metrics import _bmes_tag_to_spans | |||
from fastNLP.core.metrics import _bio_tag_to_spans | |||
@@ -338,6 +344,39 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||
for key, value in expected_metric.items(): | |||
self.assertAlmostEqual(value, metric_value[key], places=5) | |||
def test_auto_encoding_type_infer(self): | |||
# 检查是否可以自动check encode的类型 | |||
vocabs = {} | |||
import random | |||
for encoding_type in ['bio', 'bioes', 'bmeso']: | |||
vocab = Vocabulary(unknown=None, padding=None) | |||
for i in range(random.randint(10, 100)): | |||
label = str(random.randint(1, 10)) | |||
for tag in encoding_type: | |||
if tag!='o': | |||
vocab.add_word(f'{tag}-{label}') | |||
else: | |||
vocab.add_word('o') | |||
vocabs[encoding_type] = vocab | |||
for e in ['bio', 'bioes', 'bmeso']: | |||
with self.subTest(e=e): | |||
metric = SpanFPreRecMetric(tag_vocab=vocabs[e]) | |||
assert metric.encoding_type == e | |||
bmes_vocab = _generate_tags('bmes') | |||
vocab = Vocabulary() | |||
for tag, index in bmes_vocab.items(): | |||
vocab.add_word(tag) | |||
metric = SpanFPreRecMetric(vocab) | |||
assert metric.encoding_type == 'bmes' | |||
# 一些无法check的情况 | |||
vocab = Vocabulary() | |||
for i in range(10): | |||
vocab.add_word(str(i)) | |||
with self.assertRaises(Exception): | |||
metric = SpanFPreRecMetric(vocab) | |||
def test_encoding_type(self): | |||
# 检查传入的tag_vocab与encoding_type不符合时,是否会报错 | |||
vocabs = {} | |||
@@ -1,6 +1,6 @@ | |||
import unittest | |||
from fastNLP import Vocabulary | |||
class TestCRF(unittest.TestCase): | |||
def test_case1(self): | |||
@@ -14,7 +14,8 @@ class TestCRF(unittest.TestCase): | |||
id2label = {0: 'B', 1:'M', 2:'E', 3:'S'} | |||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} | |||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES', include_start_end=True))) | |||
self.assertSetEqual(expected_res, set( | |||
allowed_transitions(id2label, encoding_type='BMES', include_start_end=True))) | |||
id2label = {0: 'B', 1: 'I', 2:'O', 3: '<pad>', 4:"<unk>"} | |||
allowed_transitions(id2label, include_start_end=True) | |||
@@ -37,7 +38,100 @@ class TestCRF(unittest.TestCase): | |||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), | |||
(3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), | |||
(7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} | |||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES', include_start_end=True))) | |||
self.assertSetEqual(expected_res, set( | |||
allowed_transitions(id2label, include_start_end=True))) | |||
def test_case11(self): | |||
# 测试自动推断encoding类型 | |||
from fastNLP.modules.decoder.crf import allowed_transitions | |||
id2label = {0: 'B', 1: 'I', 2: 'O'} | |||
expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), | |||
(2, 4), (3, 0), (3, 2)} | |||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, include_start_end=True))) | |||
id2label = {0: 'B', 1: 'M', 2: 'E', 3: 'S'} | |||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} | |||
self.assertSetEqual(expected_res, set( | |||
allowed_transitions(id2label, include_start_end=True))) | |||
id2label = {0: 'B', 1: 'I', 2: 'O', 3: '<pad>', 4: "<unk>"} | |||
allowed_transitions(id2label, include_start_end=True) | |||
labels = ['O'] | |||
for label in ['X', 'Y']: | |||
for tag in 'BI': | |||
labels.append('{}-{}'.format(tag, label)) | |||
id2label = {idx: label for idx, label in enumerate(labels)} | |||
expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1), | |||
(2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3), | |||
(4, 4), (4, 6), (5, 0), (5, 1), (5, 3)} | |||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, include_start_end=True))) | |||
labels = [] | |||
for label in ['X', 'Y']: | |||
for tag in 'BMES': | |||
labels.append('{}-{}'.format(tag, label)) | |||
id2label = {idx: label for idx, label in enumerate(labels)} | |||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), | |||
(3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), | |||
(7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} | |||
self.assertSetEqual(expected_res, set( | |||
allowed_transitions(id2label, include_start_end=True))) | |||
def test_case12(self): | |||
# 测试能否通过vocab生成转移矩阵 | |||
from fastNLP.modules.decoder.crf import allowed_transitions | |||
id2label = {0: 'B', 1: 'I', 2: 'O'} | |||
vocab = Vocabulary(unknown=None, padding=None) | |||
for idx, tag in id2label.items(): | |||
vocab.add_word(tag) | |||
expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), | |||
(2, 4), (3, 0), (3, 2)} | |||
self.assertSetEqual(expected_res, set(allowed_transitions(vocab, include_start_end=True))) | |||
id2label = {0: 'B', 1: 'M', 2: 'E', 3: 'S'} | |||
vocab = Vocabulary(unknown=None, padding=None) | |||
for idx, tag in id2label.items(): | |||
vocab.add_word(tag) | |||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} | |||
self.assertSetEqual(expected_res, set( | |||
allowed_transitions(vocab, include_start_end=True))) | |||
id2label = {0: 'B', 1: 'I', 2: 'O', 3: '<pad>', 4: "<unk>"} | |||
vocab = Vocabulary() | |||
for idx, tag in id2label.items(): | |||
vocab.add_word(tag) | |||
allowed_transitions(vocab, include_start_end=True) | |||
labels = ['O'] | |||
for label in ['X', 'Y']: | |||
for tag in 'BI': | |||
labels.append('{}-{}'.format(tag, label)) | |||
id2label = {idx: label for idx, label in enumerate(labels)} | |||
expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1), | |||
(2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3), | |||
(4, 4), (4, 6), (5, 0), (5, 1), (5, 3)} | |||
vocab = Vocabulary(unknown=None, padding=None) | |||
for idx, tag in id2label.items(): | |||
vocab.add_word(tag) | |||
self.assertSetEqual(expected_res, set(allowed_transitions(vocab, include_start_end=True))) | |||
labels = [] | |||
for label in ['X', 'Y']: | |||
for tag in 'BMES': | |||
labels.append('{}-{}'.format(tag, label)) | |||
id2label = {idx: label for idx, label in enumerate(labels)} | |||
vocab = Vocabulary(unknown=None, padding=None) | |||
for idx, tag in id2label.items(): | |||
vocab.add_word(tag) | |||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), | |||
(3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), | |||
(7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} | |||
self.assertSetEqual(expected_res, set( | |||
allowed_transitions(vocab, include_start_end=True))) | |||
def test_case2(self): | |||
# 测试CRF能否避免解码出非法跃迁, 使用allennlp做了验证。 | |||