diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 0dc601a3..b06e5459 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -24,7 +24,7 @@ from .utils import seq_len_to_mask from .vocabulary import Vocabulary from abc import abstractmethod import warnings - +from typing import Union class MetricBase(object): """ @@ -337,15 +337,18 @@ class AccuracyMetric(MetricBase): raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," f"got {type(seq_len)}.") - if seq_len is not None: - masks = seq_len_to_mask(seq_len=seq_len) + if seq_len is not None and target.dim()>1: + max_len = target.size(1) + masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len) else: masks = None - if pred.size() == target.size(): + if pred.dim() == target.dim(): pass - elif len(pred.size()) == len(target.size()) + 1: + elif pred.dim() == target.dim() + 1: pred = pred.argmax(dim=-1) + if seq_len is None: + warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") else: raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " f"size:{pred.size()}, target should have size: {pred.size()} or " @@ -493,20 +496,63 @@ def _bio_tag_to_spans(tags, ignore_labels=None): return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels] -def _check_tag_vocab_and_encoding_type(vocab:Vocabulary, encoding_type:str): +def _get_encoding_type_from_tag_vocab(tag_vocab:Union[Vocabulary, dict])->str: + """ + 给定Vocabulary自动判断是哪种类型的encoding, 支持判断bmes, bioes, bmeso, bio + + :param tag_vocab: 支持传入tag Vocabulary; 或者传入形如{0:"O", 1:"B-tag1"},即index在前,tag在后的dict。 + :return: + """ + tag_set = set() + unk_token = '' + pad_token = '' + if isinstance(tag_vocab, Vocabulary): + unk_token = tag_vocab.unknown + pad_token = tag_vocab.padding + tag_vocab = tag_vocab.idx2word + for idx, tag in tag_vocab.items(): + if tag in (unk_token, pad_token): + continue + tag = tag[:1].lower() + tag_set.add(tag) + + bmes_tag_set = set('bmes') + if tag_set == bmes_tag_set: + return 'bmes' + bio_tag_set = set('bio') + if tag_set == bio_tag_set: + return 'bio' + bmeso_tag_set = set('bmeso') + if tag_set == bmeso_tag_set: + return 'bmeso' + bioes_tag_set = set('bioes') + if tag_set == bioes_tag_set: + return 'bioes' + raise RuntimeError("encoding_type cannot be inferred automatically. Only support " + "'bio', 'bmes', 'bmeso', 'bioes' type.") + + +def _check_tag_vocab_and_encoding_type(tag_vocab:Union[Vocabulary, dict], encoding_type:str): """ 检查vocab中的tag是否与encoding_type是匹配的 - :param vocab: target的Vocabulary + :param tag_vocab: 支持传入tag Vocabulary; 或者传入形如{0:"O", 1:"B-tag1"},即index在前,tag在后的dict。 :param encoding_type: bio, bmes, bioes, bmeso :return: """ tag_set = set() - for tag, idx in vocab: - if idx in (vocab.unknown_idx, vocab.padding_idx): + unk_token = '' + pad_token = '' + if isinstance(tag_vocab, Vocabulary): + unk_token = tag_vocab.unknown + pad_token = tag_vocab.padding + tag_vocab = tag_vocab.idx2word + for idx, tag in tag_vocab.items(): + if tag in (unk_token, pad_token): continue tag = tag[:1].lower() tag_set.add(tag) + tags = encoding_type for tag in tag_set: assert tag in tags, f"{tag} is not a valid tag in encoding type:{encoding_type}. Please check your " \ @@ -549,7 +595,7 @@ class SpanFPreRecMetric(MetricBase): :param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用 `pred` 取数据 :param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用 `target` 取数据 :param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用 `seq_len` 取数据。 - :param str encoding_type: 目前支持bio, bmes, bmeso, bioes + :param str encoding_type: 目前支持bio, bmes, bmeso, bioes。默认为None,通过tag_vocab自动判断. :param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'这 个label :param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个 @@ -560,18 +606,21 @@ class SpanFPreRecMetric(MetricBase): 常用为beta=0.5, 1, 2. 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 """ - def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type='bio', ignore_labels=None, + def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type=None, ignore_labels=None, only_gross=True, f_type='micro', beta=1): - - encoding_type = encoding_type.lower() - + if not isinstance(tag_vocab, Vocabulary): raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) if f_type not in ('micro', 'macro'): raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) - - self.encoding_type = encoding_type - _check_tag_vocab_and_encoding_type(tag_vocab, encoding_type) + + if encoding_type: + encoding_type = encoding_type.lower() + _check_tag_vocab_and_encoding_type(tag_vocab, encoding_type) + self.encoding_type = encoding_type + else: + self.encoding_type = _get_encoding_type_from_tag_vocab(tag_vocab) + if self.encoding_type == 'bmes': self.tag_to_span_func = _bmes_tag_to_spans elif self.encoding_type == 'bio': @@ -581,7 +630,7 @@ class SpanFPreRecMetric(MetricBase): elif self.encoding_type == 'bioes': self.tag_to_span_func = _bioes_tag_to_spans else: - raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.") + raise ValueError("Only support 'bio', 'bmes', 'bmeso', 'bioes' type.") self.ignore_labels = ignore_labels self.f_type = f_type diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index cd4f2c0f..b0f9650a 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -39,7 +39,7 @@ def _check_build_vocab(func): @wraps(func) # to solve missing docstring def _wrapper(self, *args, **kwargs): - if self.word2idx is None or self.rebuild is True: + if self._word2idx is None or self.rebuild is True: self.build_vocab() return func(self, *args, **kwargs) @@ -95,12 +95,30 @@ class Vocabulary(object): self.word_count = Counter() self.unknown = unknown self.padding = padding - self.word2idx = None - self.idx2word = None + self._word2idx = None + self._idx2word = None self.rebuild = True # 用于承载不需要单独创建entry的词语,具体见from_dataset()方法 self._no_create_word = Counter() - + + @property + @_check_build_vocab + def word2idx(self): + return self._word2idx + + @word2idx.setter + def word2idx(self, value): + self._word2idx = value + + @property + @_check_build_vocab + def idx2word(self): + return self._idx2word + + @idx2word.setter + def idx2word(self, value): + self._word2idx = value + @_check_build_status def update(self, word_lst, no_create_entry=False): """依次增加序列中词在词典中的出现频率 @@ -187,21 +205,21 @@ class Vocabulary(object): 但已经记录在词典中的词, 不会改变对应的 `int` """ - if self.word2idx is None: - self.word2idx = {} + if self._word2idx is None: + self._word2idx = {} if self.padding is not None: - self.word2idx[self.padding] = len(self.word2idx) + self._word2idx[self.padding] = len(self._word2idx) if self.unknown is not None: - self.word2idx[self.unknown] = len(self.word2idx) + self._word2idx[self.unknown] = len(self._word2idx) max_size = min(self.max_size, len(self.word_count)) if self.max_size else None words = self.word_count.most_common(max_size) if self.min_freq is not None: words = filter(lambda kv: kv[1] >= self.min_freq, words) - if self.word2idx is not None: - words = filter(lambda kv: kv[0] not in self.word2idx, words) - start_idx = len(self.word2idx) - self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) + if self._word2idx is not None: + words = filter(lambda kv: kv[0] not in self._word2idx, words) + start_idx = len(self._word2idx) + self._word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) self.build_reverse_vocab() self.rebuild = False return self @@ -211,12 +229,12 @@ class Vocabulary(object): 基于 `word to index` dict, 构建 `index to word` dict. """ - self.idx2word = {i: w for w, i in self.word2idx.items()} + self._idx2word = {i: w for w, i in self._word2idx.items()} return self @_check_build_vocab def __len__(self): - return len(self.word2idx) + return len(self._word2idx) @_check_build_vocab def __contains__(self, item): @@ -226,7 +244,7 @@ class Vocabulary(object): :param item: the word :return: True or False """ - return item in self.word2idx + return item in self._word2idx def has_word(self, w): """ @@ -248,10 +266,10 @@ class Vocabulary(object): vocab[w] """ - if w in self.word2idx: - return self.word2idx[w] + if w in self._word2idx: + return self._word2idx[w] if self.unknown is not None: - return self.word2idx[self.unknown] + return self._word2idx[self.unknown] else: raise ValueError("word `{}` not in vocabulary".format(w)) @@ -405,7 +423,7 @@ class Vocabulary(object): """ if self.unknown is None: return None - return self.word2idx[self.unknown] + return self._word2idx[self.unknown] @property @_check_build_vocab @@ -415,7 +433,7 @@ class Vocabulary(object): """ if self.padding is None: return None - return self.word2idx[self.padding] + return self._word2idx[self.padding] @_check_build_vocab def to_word(self, idx): @@ -425,7 +443,7 @@ class Vocabulary(object): :param int idx: the index :return str word: the word """ - return self.idx2word[idx] + return self._idx2word[idx] def clear(self): """ @@ -434,8 +452,8 @@ class Vocabulary(object): :return: """ self.word_count.clear() - self.word2idx = None - self.idx2word = None + self._word2idx = None + self._idx2word = None self.rebuild = True self._no_create_word.clear() return self @@ -446,8 +464,8 @@ class Vocabulary(object): """ len(self) # make sure vocab has been built state = self.__dict__.copy() - # no need to pickle idx2word as it can be constructed from word2idx - del state['idx2word'] + # no need to pickle _idx2word as it can be constructed from _word2idx + del state['_idx2word'] return state def __setstate__(self, state): @@ -462,5 +480,5 @@ class Vocabulary(object): @_check_build_vocab def __iter__(self): - for word, index in self.word2idx.items(): + for word, index in self._word2idx.items(): yield word, index diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py index f30add34..3e7f39d3 100644 --- a/fastNLP/io/data_bundle.py +++ b/fastNLP/io/data_bundle.py @@ -8,7 +8,7 @@ __all__ = [ from ..core.dataset import DataSet from ..core.vocabulary import Vocabulary - +from typing import Union class DataBundle: """ @@ -191,7 +191,7 @@ class DataBundle: raise KeyError(f"{field_name} not found DataSet:{name}.") return self - def rename_field(self, field_name, new_field_name, ignore_miss_dataset=True): + def rename_field(self, field_name, new_field_name, ignore_miss_dataset=True, rename_vocab=True): """ 将DataBundle中所有DataSet中名为field_name的field重命名为new_field_name. @@ -199,6 +199,7 @@ class DataBundle: :param str new_field_name: :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; 如果为False,则报错 + :param bool rename_vocab: 如果该field同时也存在于vocabs中,会将该field的名称对应修改 :return: self """ for name, dataset in self.datasets.items(): @@ -206,15 +207,20 @@ class DataBundle: dataset.rename_field(field_name=field_name, new_field_name=new_field_name) elif not ignore_miss_dataset: raise KeyError(f"{field_name} not found DataSet:{name}.") + if rename_vocab: + if field_name in self.vocabs: + self.vocabs[new_field_name] = self.vocabs.pop(field_name) + return self - def delete_field(self, field_name, ignore_miss_dataset=True): + def delete_field(self, field_name, ignore_miss_dataset=True, delete_vocab=True): """ 将DataBundle中所有DataSet中名为field_name的field删除掉. :param str field_name: :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; 如果为False,则报错 + :param bool delete_vocab: 如果该field也在vocabs中存在,将该值也一并删除 :return: self """ for name, dataset in self.datasets.items(): @@ -222,8 +228,39 @@ class DataBundle: dataset.delete_field(field_name=field_name) elif not ignore_miss_dataset: raise KeyError(f"{field_name} not found DataSet:{name}.") + if delete_vocab: + if field_name in self.vocabs: + self.vocabs.pop(field_name) return self + def iter_datasets(self)->Union[str, DataSet]: + """ + 迭代data_bundle中的DataSet + + Example:: + + for name, dataset in data_bundle.iter_datasets(): + pass + + :return: + """ + for name, dataset in self.datasets.items(): + yield name, dataset + + def iter_vocabs(self)->Union[str, Vocabulary]: + """ + 迭代data_bundle中的DataSet + + Example: + + for field_name, vocab in data_bundle.iter_vocabs(): + pass + + :return: + """ + for field_name, vocab in self.vocabs.items(): + yield field_name, vocab + def apply_field(self, func, field_name:str, new_field_name:str, ignore_miss_dataset=True, **kwargs): """ 对DataBundle中所有的dataset使用apply_field方法 diff --git a/fastNLP/io/pipe/conll.py b/fastNLP/io/pipe/conll.py index eb7d4909..2edc9008 100644 --- a/fastNLP/io/pipe/conll.py +++ b/fastNLP/io/pipe/conll.py @@ -193,7 +193,7 @@ class OntoNotesNERPipe(_NERPipe): """ 处理OntoNotes的NER数据,处理之后DataSet中的field情况为 - .. csv-table:: Following is a demo layout of DataSet returned by Conll2003Loader + .. csv-table:: :header: "raw_words", "words", "target", "seq_len" "[Nadim, Ladki]", "[2, 3]", "[1, 2]", 2 diff --git a/fastNLP/models/biaffine_parser.py b/fastNLP/models/biaffine_parser.py index bead09fc..6b0829bd 100644 --- a/fastNLP/models/biaffine_parser.py +++ b/fastNLP/models/biaffine_parser.py @@ -207,7 +207,7 @@ class ArcBiaffine(nn.Module): output = dep.matmul(self.U) output = output.bmm(head.transpose(-1, -2)) if self.has_bias: - output += head.matmul(self.bias).unsqueeze(1) + output = output + head.matmul(self.bias).unsqueeze(1) return output @@ -234,7 +234,7 @@ class LabelBilinear(nn.Module): :return output: [batch, seq_len, num_cls] 每个元素对应类别的概率图 """ output = self.bilinear(x1, x2) - output += self.lin(torch.cat([x1, x2], dim=2)) + output = output + self.lin(torch.cat([x1, x2], dim=2)) return output @@ -363,7 +363,7 @@ class BiaffineParser(GraphParser): # print('forward {} {}'.format(batch_size, seq_len)) # get sequence mask - mask = seq_len_to_mask(seq_len).long() + mask = seq_len_to_mask(seq_len, max_len=length).long() word = self.word_embedding(words1) # [N,L] -> [N,L,C_0] pos = self.pos_embedding(words2) # [N,L] -> [N,L,C_1] @@ -435,10 +435,10 @@ class BiaffineParser(GraphParser): """ batch_size, length, _ = pred1.shape - mask = seq_len_to_mask(seq_len) + mask = seq_len_to_mask(seq_len, max_len=length) flip_mask = (mask == 0) _arc_pred = pred1.clone() - _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf')) + _arc_pred = _arc_pred.masked_fill(flip_mask.unsqueeze(1), -float('inf')) arc_logits = F.log_softmax(_arc_pred, dim=2) label_logits = F.log_softmax(pred2, dim=2) batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1) @@ -446,9 +446,8 @@ class BiaffineParser(GraphParser): arc_loss = arc_logits[batch_index, child_index, target1] label_loss = label_logits[batch_index, child_index, target2] - byte_mask = flip_mask.byte() - arc_loss.masked_fill_(byte_mask, 0) - label_loss.masked_fill_(byte_mask, 0) + arc_loss = arc_loss.masked_fill(flip_mask, 0) + label_loss = label_loss.masked_fill(flip_mask, 0) arc_nll = -arc_loss.mean() label_nll = -label_loss.mean() return arc_nll + label_nll diff --git a/fastNLP/modules/decoder/crf.py b/fastNLP/modules/decoder/crf.py index f63d46e3..c13ea50c 100644 --- a/fastNLP/modules/decoder/crf.py +++ b/fastNLP/modules/decoder/crf.py @@ -10,33 +10,45 @@ from torch import nn from ..utils import initial_parameter from ...core.vocabulary import Vocabulary +from ...core.metrics import _get_encoding_type_from_tag_vocab, _check_tag_vocab_and_encoding_type +from typing import Union - -def allowed_transitions(id2target, encoding_type='bio', include_start_end=False): +def allowed_transitions(tag_vocab:Union[Vocabulary, dict], encoding_type=None, include_start_end=False): """ 别名::class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.allowed_transitions` 给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。 - :param dict, ~fastNLP.Vocabulary id2target: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 - "B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.idx2word得到id2label。 - :param str encoding_type: 支持"bio", "bmes", "bmeso", "bioes"。 + :param ~fastNLP.Vocabulary,dict tag_vocab: 支持类型为tag或tag-label。只有tag的,比如"B", "M"; 也可以是"B-NN", "M-NN", + tag和label之间一定要用"-"隔开。如果传入dict,格式需要形如{0:"O", 1:"B-tag1"},即index在前,tag在后。 + :param str encoding_type: 支持"bio", "bmes", "bmeso", "bioes"。默认为None,通过vocab自动推断 :param bool include_start_end: 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; 为True,返回的结果中会包含(start_idx, b_idx), (start_idx, o_idx), 但是不包含(start_idx, i_idx); start_idx=len(id2label), end_idx=len(id2label)+1。为False, 返回的结果中不含与开始结尾相关的内容 :return: List[Tuple(int, int)]], 内部的Tuple是可以进行跳转的(from_tag_id, to_tag_id)。 """ - if isinstance(id2target, Vocabulary): - id2target = id2target.idx2word - num_tags = len(id2target) + if encoding_type is None: + encoding_type = _get_encoding_type_from_tag_vocab(tag_vocab) + else: + encoding_type = encoding_type.lower() + _check_tag_vocab_and_encoding_type(tag_vocab, encoding_type) + + pad_token = '' + unk_token = '' + + if isinstance(tag_vocab, Vocabulary): + id_label_lst = list(tag_vocab.idx2word.items()) + pad_token = tag_vocab.padding + unk_token = tag_vocab.unknown + else: + id_label_lst = list(tag_vocab.items()) + + num_tags = len(tag_vocab) start_idx = num_tags end_idx = num_tags + 1 - encoding_type = encoding_type.lower() allowed_trans = [] - id_label_lst = list(id2target.items()) if include_start_end: id_label_lst += [(start_idx, 'start'), (end_idx, 'end')] - def split_tag_label(from_label): from_label = from_label.lower() if from_label in ['start', 'end']: @@ -48,11 +60,11 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=False) return from_tag, from_label for from_id, from_label in id_label_lst: - if from_label in ['', '']: + if from_label in [pad_token, unk_token]: continue from_tag, from_label = split_tag_label(from_label) for to_id, to_label in id_label_lst: - if to_label in ['', '']: + if to_label in [pad_token, unk_token]: continue to_tag, to_label = split_tag_label(to_label) if _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py index 5a7c55cf..8a472a62 100644 --- a/test/core/test_metrics.py +++ b/test/core/test_metrics.py @@ -11,6 +11,12 @@ from fastNLP.core.metrics import SpanFPreRecMetric, ExtractiveQAMetric def _generate_tags(encoding_type, number_labels=4): + """ + + :param encoding_type: 例如BIOES, BMES, BIO等 + :param number_labels: 多少个label,大于1 + :return: + """ vocab = {} for i in range(number_labels): label = str(i) @@ -184,7 +190,7 @@ class TestAccuracyMetric(unittest.TestCase): self.assertDictEqual(metric.get_metric(), {'acc': 1.}) -class SpanF1PreRecMetric(unittest.TestCase): +class SpanFPreRecMetricTest(unittest.TestCase): def test_case1(self): from fastNLP.core.metrics import _bmes_tag_to_spans from fastNLP.core.metrics import _bio_tag_to_spans @@ -338,6 +344,39 @@ class SpanF1PreRecMetric(unittest.TestCase): for key, value in expected_metric.items(): self.assertAlmostEqual(value, metric_value[key], places=5) + def test_auto_encoding_type_infer(self): + # 检查是否可以自动check encode的类型 + vocabs = {} + import random + for encoding_type in ['bio', 'bioes', 'bmeso']: + vocab = Vocabulary(unknown=None, padding=None) + for i in range(random.randint(10, 100)): + label = str(random.randint(1, 10)) + for tag in encoding_type: + if tag!='o': + vocab.add_word(f'{tag}-{label}') + else: + vocab.add_word('o') + vocabs[encoding_type] = vocab + for e in ['bio', 'bioes', 'bmeso']: + with self.subTest(e=e): + metric = SpanFPreRecMetric(tag_vocab=vocabs[e]) + assert metric.encoding_type == e + + bmes_vocab = _generate_tags('bmes') + vocab = Vocabulary() + for tag, index in bmes_vocab.items(): + vocab.add_word(tag) + metric = SpanFPreRecMetric(vocab) + assert metric.encoding_type == 'bmes' + + # 一些无法check的情况 + vocab = Vocabulary() + for i in range(10): + vocab.add_word(str(i)) + with self.assertRaises(Exception): + metric = SpanFPreRecMetric(vocab) + def test_encoding_type(self): # 检查传入的tag_vocab与encoding_type不符合时,是否会报错 vocabs = {} diff --git a/test/modules/decoder/test_CRF.py b/test/modules/decoder/test_CRF.py index 647af7d3..94b4ab7a 100644 --- a/test/modules/decoder/test_CRF.py +++ b/test/modules/decoder/test_CRF.py @@ -1,6 +1,6 @@ import unittest - +from fastNLP import Vocabulary class TestCRF(unittest.TestCase): def test_case1(self): @@ -14,7 +14,8 @@ class TestCRF(unittest.TestCase): id2label = {0: 'B', 1:'M', 2:'E', 3:'S'} expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} - self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES', include_start_end=True))) + self.assertSetEqual(expected_res, set( + allowed_transitions(id2label, encoding_type='BMES', include_start_end=True))) id2label = {0: 'B', 1: 'I', 2:'O', 3: '', 4:""} allowed_transitions(id2label, include_start_end=True) @@ -37,7 +38,100 @@ class TestCRF(unittest.TestCase): expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), (3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), (7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} - self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES', include_start_end=True))) + self.assertSetEqual(expected_res, set( + allowed_transitions(id2label, include_start_end=True))) + + def test_case11(self): + # 测试自动推断encoding类型 + from fastNLP.modules.decoder.crf import allowed_transitions + + id2label = {0: 'B', 1: 'I', 2: 'O'} + expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), + (2, 4), (3, 0), (3, 2)} + self.assertSetEqual(expected_res, set(allowed_transitions(id2label, include_start_end=True))) + + id2label = {0: 'B', 1: 'M', 2: 'E', 3: 'S'} + expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} + self.assertSetEqual(expected_res, set( + allowed_transitions(id2label, include_start_end=True))) + + id2label = {0: 'B', 1: 'I', 2: 'O', 3: '', 4: ""} + allowed_transitions(id2label, include_start_end=True) + + labels = ['O'] + for label in ['X', 'Y']: + for tag in 'BI': + labels.append('{}-{}'.format(tag, label)) + id2label = {idx: label for idx, label in enumerate(labels)} + expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1), + (2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3), + (4, 4), (4, 6), (5, 0), (5, 1), (5, 3)} + self.assertSetEqual(expected_res, set(allowed_transitions(id2label, include_start_end=True))) + + labels = [] + for label in ['X', 'Y']: + for tag in 'BMES': + labels.append('{}-{}'.format(tag, label)) + id2label = {idx: label for idx, label in enumerate(labels)} + expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), + (3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), + (7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} + self.assertSetEqual(expected_res, set( + allowed_transitions(id2label, include_start_end=True))) + + def test_case12(self): + # 测试能否通过vocab生成转移矩阵 + from fastNLP.modules.decoder.crf import allowed_transitions + + id2label = {0: 'B', 1: 'I', 2: 'O'} + vocab = Vocabulary(unknown=None, padding=None) + for idx, tag in id2label.items(): + vocab.add_word(tag) + expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), + (2, 4), (3, 0), (3, 2)} + self.assertSetEqual(expected_res, set(allowed_transitions(vocab, include_start_end=True))) + + id2label = {0: 'B', 1: 'M', 2: 'E', 3: 'S'} + vocab = Vocabulary(unknown=None, padding=None) + for idx, tag in id2label.items(): + vocab.add_word(tag) + expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} + self.assertSetEqual(expected_res, set( + allowed_transitions(vocab, include_start_end=True))) + + id2label = {0: 'B', 1: 'I', 2: 'O', 3: '', 4: ""} + vocab = Vocabulary() + for idx, tag in id2label.items(): + vocab.add_word(tag) + allowed_transitions(vocab, include_start_end=True) + + labels = ['O'] + for label in ['X', 'Y']: + for tag in 'BI': + labels.append('{}-{}'.format(tag, label)) + id2label = {idx: label for idx, label in enumerate(labels)} + expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1), + (2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3), + (4, 4), (4, 6), (5, 0), (5, 1), (5, 3)} + vocab = Vocabulary(unknown=None, padding=None) + for idx, tag in id2label.items(): + vocab.add_word(tag) + self.assertSetEqual(expected_res, set(allowed_transitions(vocab, include_start_end=True))) + + labels = [] + for label in ['X', 'Y']: + for tag in 'BMES': + labels.append('{}-{}'.format(tag, label)) + id2label = {idx: label for idx, label in enumerate(labels)} + vocab = Vocabulary(unknown=None, padding=None) + for idx, tag in id2label.items(): + vocab.add_word(tag) + expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), + (3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), + (7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} + self.assertSetEqual(expected_res, set( + allowed_transitions(vocab, include_start_end=True))) + def test_case2(self): # 测试CRF能否避免解码出非法跃迁, 使用allennlp做了验证。