@@ -24,7 +24,7 @@ from .utils import seq_len_to_mask
 from .vocabulary import Vocabulary
 from abc import abstractmethod
 import warnings
+from typing import Union
 class MetricBase(object):
     """
@@ -337,15 +337,18 @@ class AccuracyMetric(MetricBase):
             raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(seq_len)}.")
-        if seq_len is not None:
-            masks = seq_len_to_mask(seq_len=seq_len)
+        if seq_len is not None and target.dim() > 1:
+            max_len = target.size(1)
+            masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len)
         else:
             masks = None
-        if pred.size() == target.size():
+        if pred.dim() == target.dim():
             pass
-        elif len(pred.size()) == len(target.size()) + 1:
+        elif pred.dim() == target.dim() + 1:
             pred = pred.argmax(dim=-1)
+            if seq_len is None:
+                warnings.warn("You are not passing `seq_len` to exclude padding when calculating accuracy.")
         else:
             raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred has "
                                f"size:{pred.size()}, target should have size: {pred.size()} or "
@@ -493,20 +496,63 @@ def _bio_tag_to_spans(tags, ignore_labels=None):
     return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels]
-def _check_tag_vocab_and_encoding_type(vocab:Vocabulary, encoding_type:str):
+def _get_encoding_type_from_tag_vocab(tag_vocab:Union[Vocabulary, dict])->str:
+    """
+    Given a Vocabulary, automatically infer its encoding type. Supports bmes, bioes, bmeso and bio.
+    :param tag_vocab: a tag Vocabulary, or a dict such as {0:"O", 1:"B-tag1"}, i.e. index first, tag second.
+    :return:
+    """
+    tag_set = set()
+    unk_token = '<unk>'
+    pad_token = '<pad>'
+    if isinstance(tag_vocab, Vocabulary):
+        unk_token = tag_vocab.unknown
+        pad_token = tag_vocab.padding
+        tag_vocab = tag_vocab.idx2word
+    for idx, tag in tag_vocab.items():
+        if tag in (unk_token, pad_token):
+            continue
+        tag = tag[:1].lower()
+        tag_set.add(tag)
+    bmes_tag_set = set('bmes')
+    if tag_set == bmes_tag_set:
+        return 'bmes'
+    bio_tag_set = set('bio')
+    if tag_set == bio_tag_set:
+        return 'bio'
+    bmeso_tag_set = set('bmeso')
+    if tag_set == bmeso_tag_set:
+        return 'bmeso'
+    bioes_tag_set = set('bioes')
+    if tag_set == bioes_tag_set:
+        return 'bioes'
+    raise RuntimeError("encoding_type cannot be inferred automatically. Only support "
+                       "'bio', 'bmes', 'bmeso', 'bioes' type.")
+def _check_tag_vocab_and_encoding_type(tag_vocab:Union[Vocabulary, dict], encoding_type:str):
     """
     Check whether the tags in the vocab are consistent with the encoding_type.
-    :param vocab: the target Vocabulary
+    :param tag_vocab: a tag Vocabulary, or a dict such as {0:"O", 1:"B-tag1"}, i.e. index first, tag second.
     :param encoding_type: bio, bmes, bioes, bmeso
     :return:
     """
     tag_set = set()
-    for tag, idx in vocab:
-        if idx in (vocab.unknown_idx, vocab.padding_idx):
+    unk_token = '<unk>'
+    pad_token = '<pad>'
+    if isinstance(tag_vocab, Vocabulary):
+        unk_token = tag_vocab.unknown
+        pad_token = tag_vocab.padding
+        tag_vocab = tag_vocab.idx2word
+    for idx, tag in tag_vocab.items():
+        if tag in (unk_token, pad_token):
             continue
         tag = tag[:1].lower()
         tag_set.add(tag)
     tags = encoding_type
     for tag in tag_set:
         assert tag in tags, f"{tag} is not a valid tag in encoding type:{encoding_type}. Please check your " \
@@ -549,7 +595,7 @@ class SpanFPreRecMetric(MetricBase):
     :param str pred: the key used to fetch the prediction data from the dict passed to evaluate(). If None, `pred` is used.
     :param str target: the key used to fetch the target data from the dict passed to evaluate(). If None, `target` is used.
     :param str seq_len: the key used to fetch the sequence length data from the dict passed to evaluate(). If None, `seq_len` is used.
-    :param str encoding_type: currently supports bio, bmes, bmeso, bioes
+    :param str encoding_type: currently supports bio, bmes, bmeso, bioes. Defaults to None, in which case it is inferred automatically from tag_vocab.
    :param list ignore_labels: a list of str. Classes in this list are excluded from the calculation; for example, passing ['NN'] for POS tagging
         skips the 'NN' label.
    :param bool only_gross: whether to compute only the overall f1, precision and recall; if False, besides the overall f1, pre and rec, the values for each
@@ -560,18 +606,21 @@ class SpanFPreRecMetric(MetricBase):
        Common values are beta=0.5, 1, 2. With 0.5, precision is weighted more than recall; with 1 they are weighted equally; with 2, recall outweighs precision.
     """
-    def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type='bio', ignore_labels=None,
+    def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type=None, ignore_labels=None,
                  only_gross=True, f_type='micro', beta=1):
-        encoding_type = encoding_type.lower()
         if not isinstance(tag_vocab, Vocabulary):
             raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab)))
         if f_type not in ('micro', 'macro'):
             raise ValueError("f_type only supports `micro` or `macro`, got {}.".format(f_type))
-        self.encoding_type = encoding_type
-        _check_tag_vocab_and_encoding_type(tag_vocab, encoding_type)
+        if encoding_type:
+            encoding_type = encoding_type.lower()
+            _check_tag_vocab_and_encoding_type(tag_vocab, encoding_type)
+            self.encoding_type = encoding_type
+        else:
+            self.encoding_type = _get_encoding_type_from_tag_vocab(tag_vocab)
         if self.encoding_type == 'bmes':
             self.tag_to_span_func = _bmes_tag_to_spans
         elif self.encoding_type == 'bio':
@@ -581,7 +630,7 @@ class SpanFPreRecMetric(MetricBase):
         elif self.encoding_type == 'bioes':
             self.tag_to_span_func = _bioes_tag_to_spans
         else:
-            raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.")
+            raise ValueError("Only support 'bio', 'bmes', 'bmeso', 'bioes' type.")
         self.ignore_labels = ignore_labels
         self.f_type = f_type
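With the default now encoding_type=None, the metric can be constructed without naming the scheme; a small sketch assuming a BIO tag vocabulary:

    from fastNLP import Vocabulary
    from fastNLP.core.metrics import SpanFPreRecMetric

    vocab = Vocabulary(unknown=None, padding=None)
    vocab.add_word_lst(['O', 'B-PER', 'I-PER'])
    metric = SpanFPreRecMetric(tag_vocab=vocab)  # encoding_type inferred as 'bio'
    assert metric.encoding_type == 'bio'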
@@ -39,7 +39,7 @@ def _check_build_vocab(func):
     @wraps(func)  # to solve missing docstring
     def _wrapper(self, *args, **kwargs):
-        if self.word2idx is None or self.rebuild is True:
+        if self._word2idx is None or self.rebuild is True:
             self.build_vocab()
         return func(self, *args, **kwargs)
@@ -95,12 +95,30 @@ class Vocabulary(object):
         self.word_count = Counter()
         self.unknown = unknown
         self.padding = padding
-        self.word2idx = None
-        self.idx2word = None
+        self._word2idx = None
+        self._idx2word = None
         self.rebuild = True
         # holds words that do not need a separate entry; see the from_dataset() method for details
         self._no_create_word = Counter()
+    @property
+    @_check_build_vocab
+    def word2idx(self):
+        return self._word2idx
+    @word2idx.setter
+    def word2idx(self, value):
+        self._word2idx = value
+    @property
+    @_check_build_vocab
+    def idx2word(self):
+        return self._idx2word
+    @idx2word.setter
+    def idx2word(self, value):
+        self._idx2word = value
     @_check_build_status
     def update(self, word_lst, no_create_entry=False):
         """Increase the recorded frequency of each word in the sequence.
@@ -187,21 +205,21 @@ class Vocabulary(object):
        Words already recorded in the vocabulary keep their existing `int` index.
        """
-        if self.word2idx is None:
-            self.word2idx = {}
+        if self._word2idx is None:
+            self._word2idx = {}
         if self.padding is not None:
-            self.word2idx[self.padding] = len(self.word2idx)
+            self._word2idx[self.padding] = len(self._word2idx)
         if self.unknown is not None:
-            self.word2idx[self.unknown] = len(self.word2idx)
+            self._word2idx[self.unknown] = len(self._word2idx)
         max_size = min(self.max_size, len(self.word_count)) if self.max_size else None
         words = self.word_count.most_common(max_size)
         if self.min_freq is not None:
             words = filter(lambda kv: kv[1] >= self.min_freq, words)
-        if self.word2idx is not None:
-            words = filter(lambda kv: kv[0] not in self.word2idx, words)
-        start_idx = len(self.word2idx)
-        self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)})
+        if self._word2idx is not None:
+            words = filter(lambda kv: kv[0] not in self._word2idx, words)
+        start_idx = len(self._word2idx)
+        self._word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)})
         self.build_reverse_vocab()
         self.rebuild = False
         return self
@@ -211,12 +229,12 @@ class Vocabulary(object):
        Build the `index to word` dict from the `word to index` dict.
        """
-        self.idx2word = {i: w for w, i in self.word2idx.items()}
+        self._idx2word = {i: w for w, i in self._word2idx.items()}
         return self
     @_check_build_vocab
     def __len__(self):
-        return len(self.word2idx)
+        return len(self._word2idx)
     @_check_build_vocab
     def __contains__(self, item):
@@ -226,7 +244,7 @@ class Vocabulary(object):
         :param item: the word
         :return: True or False
         """
-        return item in self.word2idx
+        return item in self._word2idx
     def has_word(self, w):
         """
@@ -248,10 +266,10 @@ class Vocabulary(object):
            vocab[w]
         """
-        if w in self.word2idx:
-            return self.word2idx[w]
+        if w in self._word2idx:
+            return self._word2idx[w]
         if self.unknown is not None:
-            return self.word2idx[self.unknown]
+            return self._word2idx[self.unknown]
         else:
             raise ValueError("word `{}` not in vocabulary".format(w))
@@ -405,7 +423,7 @@ class Vocabulary(object):
         """
         if self.unknown is None:
             return None
-        return self.word2idx[self.unknown]
+        return self._word2idx[self.unknown]
     @property
     @_check_build_vocab
@@ -415,7 +433,7 @@ class Vocabulary(object):
         """
         if self.padding is None:
             return None
-        return self.word2idx[self.padding]
+        return self._word2idx[self.padding]
     @_check_build_vocab
     def to_word(self, idx):
@@ -425,7 +443,7 @@ class Vocabulary(object):
         :param int idx: the index
         :return str word: the word
         """
-        return self.idx2word[idx]
+        return self._idx2word[idx]
     def clear(self):
         """
@@ -434,8 +452,8 @@ class Vocabulary(object):
         :return:
         """
         self.word_count.clear()
-        self.word2idx = None
-        self.idx2word = None
+        self._word2idx = None
+        self._idx2word = None
         self.rebuild = True
         self._no_create_word.clear()
         return self
@@ -446,8 +464,8 @@ class Vocabulary(object):
         """
         len(self)  # make sure vocab has been built
         state = self.__dict__.copy()
-        # no need to pickle idx2word as it can be constructed from word2idx
-        del state['idx2word']
+        # no need to pickle _idx2word as it can be constructed from _word2idx
+        del state['_idx2word']
         return state
     def __setstate__(self, state):
@@ -462,5 +480,5 @@ class Vocabulary(object):
     @_check_build_vocab
     def __iter__(self):
-        for word, index in self.word2idx.items():
+        for word, index in self._word2idx.items():
             yield word, index
@@ -8,7 +8,7 @@ __all__ = [
 from ..core.dataset import DataSet
 from ..core.vocabulary import Vocabulary
+from typing import Union
 class DataBundle:
     """
@@ -191,7 +191,7 @@ class DataBundle:
             raise KeyError(f"{field_name} not found DataSet:{name}.")
         return self
-    def rename_field(self, field_name, new_field_name, ignore_miss_dataset=True):
+    def rename_field(self, field_name, new_field_name, ignore_miss_dataset=True, rename_vocab=True):
         """
         Rename the field named field_name to new_field_name in every DataSet of this DataBundle.
@@ -199,6 +199,7 @@ class DataBundle:
         :param str new_field_name:
         :param bool ignore_miss_dataset: when a DataSet does not contain a field with this name, skip that DataSet if True;
             raise an error if False
+        :param bool rename_vocab: if the field also exists in vocabs, rename it there accordingly
         :return: self
         """
         for name, dataset in self.datasets.items():
@@ -206,15 +207,20 @@ class DataBundle:
                 dataset.rename_field(field_name=field_name, new_field_name=new_field_name)
             elif not ignore_miss_dataset:
                 raise KeyError(f"{field_name} not found DataSet:{name}.")
+        if rename_vocab:
+            if field_name in self.vocabs:
+                self.vocabs[new_field_name] = self.vocabs.pop(field_name)
         return self
-    def delete_field(self, field_name, ignore_miss_dataset=True):
+    def delete_field(self, field_name, ignore_miss_dataset=True, delete_vocab=True):
         """
         Delete the field named field_name from every DataSet in this DataBundle.
         :param str field_name:
         :param bool ignore_miss_dataset: when a DataSet does not contain a field with this name, skip that DataSet if True;
             raise an error if False
+        :param bool delete_vocab: if the field also exists in vocabs, delete that entry as well
         :return: self
         """
         for name, dataset in self.datasets.items():
@@ -222,8 +228,39 @@ class DataBundle:
                 dataset.delete_field(field_name=field_name)
             elif not ignore_miss_dataset:
                 raise KeyError(f"{field_name} not found DataSet:{name}.")
+        if delete_vocab:
+            if field_name in self.vocabs:
+                self.vocabs.pop(field_name)
         return self
+    def iter_datasets(self)->Union[str, DataSet]:
+        """
+        Iterate over the DataSets in the data_bundle.
+        Example::
+            for name, dataset in data_bundle.iter_datasets():
+                pass
+        :return:
+        """
+        for name, dataset in self.datasets.items():
+            yield name, dataset
+    def iter_vocabs(self)->Union[str, Vocabulary]:
+        """
+        Iterate over the Vocabulary objects in the data_bundle.
+        Example::
+            for field_name, vocab in data_bundle.iter_vocabs():
+                pass
+        :return:
+        """
+        for field_name, vocab in self.vocabs.items():
+            yield field_name, vocab
     def apply_field(self, func, field_name:str, new_field_name:str, ignore_miss_dataset=True, **kwargs):
         """
         Call apply_field on every dataset in this DataBundle.
@@ -193,7 +193,7 @@ class OntoNotesNERPipe(_NERPipe):
     """
     Processes OntoNotes NER data. After processing, the fields in the DataSet are laid out as follows:
-    .. csv-table:: Following is a demo layout of DataSet returned by Conll2003Loader
+    .. csv-table::
        :header: "raw_words", "words", "target", "seq_len"
        "[Nadim, Ladki]", "[2, 3]", "[1, 2]", 2
@@ -207,7 +207,7 @@ class ArcBiaffine(nn.Module):
         output = dep.matmul(self.U)
         output = output.bmm(head.transpose(-1, -2))
         if self.has_bias:
-            output += head.matmul(self.bias).unsqueeze(1)
+            output = output + head.matmul(self.bias).unsqueeze(1)
         return output
@@ -234,7 +234,7 @@ class LabelBilinear(nn.Module):
        :return output: [batch, seq_len, num_cls] probability map over the classes for each element
        """
        output = self.bilinear(x1, x2)
-        output += self.lin(torch.cat([x1, x2], dim=2))
+        output = output + self.lin(torch.cat([x1, x2], dim=2))
        return output
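The switch away from += and masked_fill_ in these hunks is the standard fix for autograd's in-place version check; a minimal reproduction in plain PyTorch of the failure that the out-of-place form avoids:

    import torch

    x = torch.randn(3, requires_grad=True)
    y = torch.sigmoid(x)   # sigmoid's backward pass reuses its output y
    y += 1                 # in-place update bumps y's version counter
    y.sum().backward()     # RuntimeError: ... modified by an inplace operation
    # Writing y = y + 1 instead allocates a fresh tensor and backward() succeeds.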
@@ -363,7 +363,7 @@ class BiaffineParser(GraphParser):
         # print('forward {} {}'.format(batch_size, seq_len))
         # get sequence mask
-        mask = seq_len_to_mask(seq_len).long()
+        mask = seq_len_to_mask(seq_len, max_len=length).long()
         word = self.word_embedding(words1)  # [N,L] -> [N,L,C_0]
         pos = self.pos_embedding(words2)  # [N,L] -> [N,L,C_1]
@@ -435,10 +435,10 @@ class BiaffineParser(GraphParser):
         """
         batch_size, length, _ = pred1.shape
-        mask = seq_len_to_mask(seq_len)
+        mask = seq_len_to_mask(seq_len, max_len=length)
         flip_mask = (mask == 0)
         _arc_pred = pred1.clone()
-        _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf'))
+        _arc_pred = _arc_pred.masked_fill(flip_mask.unsqueeze(1), -float('inf'))
         arc_logits = F.log_softmax(_arc_pred, dim=2)
         label_logits = F.log_softmax(pred2, dim=2)
         batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1)
@@ -446,9 +446,8 @@ class BiaffineParser(GraphParser):
         arc_loss = arc_logits[batch_index, child_index, target1]
         label_loss = label_logits[batch_index, child_index, target2]
-        byte_mask = flip_mask.byte()
-        arc_loss.masked_fill_(byte_mask, 0)
-        label_loss.masked_fill_(byte_mask, 0)
+        arc_loss = arc_loss.masked_fill(flip_mask, 0)
+        label_loss = label_loss.masked_fill(flip_mask, 0)
         arc_nll = -arc_loss.mean()
         label_nll = -label_loss.mean()
         return arc_nll + label_nll
@@ -10,33 +10,45 @@ from torch import nn
 from ..utils import initial_parameter
 from ...core.vocabulary import Vocabulary
+from ...core.metrics import _get_encoding_type_from_tag_vocab, _check_tag_vocab_and_encoding_type
+from typing import Union
-def allowed_transitions(id2target, encoding_type='bio', include_start_end=False):
+def allowed_transitions(tag_vocab:Union[Vocabulary, dict], encoding_type=None, include_start_end=False):
     """
     Alias: :class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.allowed_transitions`
     Given a mapping from id to label, return the list of all allowed (from_tag_id, to_tag_id) transitions.
-    :param dict, ~fastNLP.Vocabulary id2target: keys are label indices, values are str tags or tag-labels. A value may be a bare tag such as "B", "M", or
-        "B-NN", "M-NN", where tag and label must be separated by "-". id2label can usually be obtained via Vocabulary.idx2word.
-    :param str encoding_type: supports "bio", "bmes", "bmeso", "bioes".
+    :param ~fastNLP.Vocabulary,dict tag_vocab: values may be bare tags such as "B", "M", or tag-labels such as "B-NN", "M-NN",
+        where tag and label must be separated by "-". If a dict is passed, it must look like {0:"O", 1:"B-tag1"}, i.e. index first, tag second.
+    :param str encoding_type: supports "bio", "bmes", "bmeso", "bioes". Defaults to None, in which case it is inferred from the vocab.
     :param bool include_start_end: whether to include transitions involving start and end. For example, in bio, b/o may appear at the start but i may not;
         if True, the result includes (start_idx, b_idx), (start_idx, o_idx) but not (start_idx, i_idx);
         start_idx=len(id2label), end_idx=len(id2label)+1. If False, the result contains nothing related to start or end.
     :return: List[Tuple[int, int]], each inner Tuple is an allowed (from_tag_id, to_tag_id) transition.
     """
-    if isinstance(id2target, Vocabulary):
-        id2target = id2target.idx2word
-    num_tags = len(id2target)
+    if encoding_type is None:
+        encoding_type = _get_encoding_type_from_tag_vocab(tag_vocab)
+    else:
+        encoding_type = encoding_type.lower()
+        _check_tag_vocab_and_encoding_type(tag_vocab, encoding_type)
+    pad_token = '<pad>'
+    unk_token = '<unk>'
+    if isinstance(tag_vocab, Vocabulary):
+        id_label_lst = list(tag_vocab.idx2word.items())
+        pad_token = tag_vocab.padding
+        unk_token = tag_vocab.unknown
+    else:
+        id_label_lst = list(tag_vocab.items())
+    num_tags = len(tag_vocab)
     start_idx = num_tags
     end_idx = num_tags + 1
-    encoding_type = encoding_type.lower()
     allowed_trans = []
-    id_label_lst = list(id2target.items())
     if include_start_end:
         id_label_lst += [(start_idx, 'start'), (end_idx, 'end')]
     def split_tag_label(from_label):
         from_label = from_label.lower()
         if from_label in ['start', 'end']:
@@ -48,11 +60,11 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=False)
         return from_tag, from_label
     for from_id, from_label in id_label_lst:
-        if from_label in ['<pad>', '<unk>']:
+        if from_label in [pad_token, unk_token]:
             continue
         from_tag, from_label = split_tag_label(from_label)
         for to_id, to_label in id_label_lst:
-            if to_label in ['<pad>', '<unk>']:
+            if to_label in [pad_token, unk_token]:
                 continue
             to_tag, to_label = split_tag_label(to_label)
             if _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label):
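A short usage sketch of the relaxed signature: a plain dict works, and encoding_type may be omitted and inferred (here as bio):

    from fastNLP.modules.decoder.crf import allowed_transitions

    id2label = {0: 'O', 1: 'B-PER', 2: 'I-PER'}
    trans = allowed_transitions(id2label, include_start_end=False)
    # Each pair is an allowed (from_tag_id, to_tag_id); B-PER -> I-PER is legal in BIO.
    assert (1, 2) in trans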
@@ -11,6 +11,12 @@ from fastNLP.core.metrics import SpanFPreRecMetric, ExtractiveQAMetric
 def _generate_tags(encoding_type, number_labels=4):
+    """
+    :param encoding_type: e.g. BIOES, BMES, BIO
+    :param number_labels: how many labels; must be greater than 1
+    :return:
+    """
     vocab = {}
     for i in range(number_labels):
         label = str(i)
@@ -184,7 +190,7 @@ class TestAccuracyMetric(unittest.TestCase):
         self.assertDictEqual(metric.get_metric(), {'acc': 1.})
-class SpanF1PreRecMetric(unittest.TestCase):
+class SpanFPreRecMetricTest(unittest.TestCase):
     def test_case1(self):
         from fastNLP.core.metrics import _bmes_tag_to_spans
         from fastNLP.core.metrics import _bio_tag_to_spans
@@ -338,6 +344,39 @@ class SpanFPreRecMetricTest(unittest.TestCase):
         for key, value in expected_metric.items():
             self.assertAlmostEqual(value, metric_value[key], places=5)
+    def test_auto_encoding_type_infer(self):
+        # check that the encoding type can be inferred automatically
+        vocabs = {}
+        import random
+        for encoding_type in ['bio', 'bioes', 'bmeso']:
+            vocab = Vocabulary(unknown=None, padding=None)
+            for i in range(random.randint(10, 100)):
+                label = str(random.randint(1, 10))
+                for tag in encoding_type:
+                    if tag != 'o':
+                        vocab.add_word(f'{tag}-{label}')
+                    else:
+                        vocab.add_word('o')
+            vocabs[encoding_type] = vocab
+        for e in ['bio', 'bioes', 'bmeso']:
+            with self.subTest(e=e):
+                metric = SpanFPreRecMetric(tag_vocab=vocabs[e])
+                assert metric.encoding_type == e
+        bmes_vocab = _generate_tags('bmes')
+        vocab = Vocabulary()
+        for tag, index in bmes_vocab.items():
+            vocab.add_word(tag)
+        metric = SpanFPreRecMetric(vocab)
+        assert metric.encoding_type == 'bmes'
+        # some cases where inference is impossible
+        vocab = Vocabulary()
+        for i in range(10):
+            vocab.add_word(str(i))
+        with self.assertRaises(Exception):
+            metric = SpanFPreRecMetric(vocab)
     def test_encoding_type(self):
         # check that an error is raised when the given tag_vocab does not match the encoding_type
         vocabs = {}
@@ -1,6 +1,6 @@
 import unittest
+from fastNLP import Vocabulary
 class TestCRF(unittest.TestCase):
     def test_case1(self):
@@ -14,7 +14,8 @@ class TestCRF(unittest.TestCase):
         id2label = {0: 'B', 1:'M', 2:'E', 3:'S'}
         expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)}
-        self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES', include_start_end=True)))
+        self.assertSetEqual(expected_res, set(
+            allowed_transitions(id2label, encoding_type='BMES', include_start_end=True)))
         id2label = {0: 'B', 1: 'I', 2:'O', 3: '<pad>', 4:"<unk>"}
         allowed_transitions(id2label, include_start_end=True)
@@ -37,7 +38,100 @@ class TestCRF(unittest.TestCase):
         expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4),
                         (3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0),
                         (7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)}
-        self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES', include_start_end=True)))
+        self.assertSetEqual(expected_res, set(
+            allowed_transitions(id2label, include_start_end=True)))
+    def test_case11(self):
+        # test automatic inference of the encoding type
+        from fastNLP.modules.decoder.crf import allowed_transitions
+        id2label = {0: 'B', 1: 'I', 2: 'O'}
+        expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2),
+                        (2, 4), (3, 0), (3, 2)}
+        self.assertSetEqual(expected_res, set(allowed_transitions(id2label, include_start_end=True)))
+        id2label = {0: 'B', 1: 'M', 2: 'E', 3: 'S'}
+        expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)}
+        self.assertSetEqual(expected_res, set(
+            allowed_transitions(id2label, include_start_end=True)))
+        id2label = {0: 'B', 1: 'I', 2: 'O', 3: '<pad>', 4: "<unk>"}
+        allowed_transitions(id2label, include_start_end=True)
+        labels = ['O']
+        for label in ['X', 'Y']:
+            for tag in 'BI':
+                labels.append('{}-{}'.format(tag, label))
+        id2label = {idx: label for idx, label in enumerate(labels)}
+        expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1),
+                        (2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3),
+                        (4, 4), (4, 6), (5, 0), (5, 1), (5, 3)}
+        self.assertSetEqual(expected_res, set(allowed_transitions(id2label, include_start_end=True)))
+        labels = []
+        for label in ['X', 'Y']:
+            for tag in 'BMES':
+                labels.append('{}-{}'.format(tag, label))
+        id2label = {idx: label for idx, label in enumerate(labels)}
+        expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4),
+                        (3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0),
+                        (7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)}
+        self.assertSetEqual(expected_res, set(
+            allowed_transitions(id2label, include_start_end=True)))
+    def test_case12(self):
+        # test that the transition list can be generated from a vocab
+        from fastNLP.modules.decoder.crf import allowed_transitions
+        id2label = {0: 'B', 1: 'I', 2: 'O'}
+        vocab = Vocabulary(unknown=None, padding=None)
+        for idx, tag in id2label.items():
+            vocab.add_word(tag)
+        expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2),
+                        (2, 4), (3, 0), (3, 2)}
+        self.assertSetEqual(expected_res, set(allowed_transitions(vocab, include_start_end=True)))
+        id2label = {0: 'B', 1: 'M', 2: 'E', 3: 'S'}
+        vocab = Vocabulary(unknown=None, padding=None)
+        for idx, tag in id2label.items():
+            vocab.add_word(tag)
+        expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)}
+        self.assertSetEqual(expected_res, set(
+            allowed_transitions(vocab, include_start_end=True)))
+        id2label = {0: 'B', 1: 'I', 2: 'O', 3: '<pad>', 4: "<unk>"}
+        vocab = Vocabulary()
+        for idx, tag in id2label.items():
+            vocab.add_word(tag)
+        allowed_transitions(vocab, include_start_end=True)
+        labels = ['O']
+        for label in ['X', 'Y']:
+            for tag in 'BI':
+                labels.append('{}-{}'.format(tag, label))
+        id2label = {idx: label for idx, label in enumerate(labels)}
+        expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1),
+                        (2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3),
+                        (4, 4), (4, 6), (5, 0), (5, 1), (5, 3)}
+        vocab = Vocabulary(unknown=None, padding=None)
+        for idx, tag in id2label.items():
+            vocab.add_word(tag)
+        self.assertSetEqual(expected_res, set(allowed_transitions(vocab, include_start_end=True)))
+        labels = []
+        for label in ['X', 'Y']:
+            for tag in 'BMES':
+                labels.append('{}-{}'.format(tag, label))
+        id2label = {idx: label for idx, label in enumerate(labels)}
+        vocab = Vocabulary(unknown=None, padding=None)
+        for idx, tag in id2label.items():
+            vocab.add_word(tag)
+        expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4),
+                        (3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0),
+                        (7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)}
+        self.assertSetEqual(expected_res, set(
+            allowed_transitions(vocab, include_start_end=True)))
     def test_case2(self):
         # test that the CRF avoids decoding illegal transitions; verified against allennlp