diff --git a/fastNLP/api/api.py b/fastNLP/api/api.py index 5ae05dac..0ac5c503 100644 --- a/fastNLP/api/api.py +++ b/fastNLP/api/api.py @@ -9,8 +9,8 @@ from fastNLP.core.dataset import DataSet from fastNLP.api.model_zoo import load_url from fastNLP.api.processor import ModelProcessor -from reproduction.chinese_word_segment.cws_io.cws_reader import ConlluCWSReader -from reproduction.pos_tag_model.pos_io.pos_reader import ConlluPOSReader +from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader +from reproduction.pos_tag_model.pos_io.pos_reader import ConllPOSReader from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag from fastNLP.core.instance import Instance from fastNLP.core.sampler import SequentialSampler @@ -95,7 +95,7 @@ class POS(API): pipeline.append(tag_proc) pp = Pipeline(pipeline) - reader = ConlluPOSReader() + reader = ConllPOSReader() te_dataset = reader.load(filepath) evaluator = SeqLabelEvaluator2('word_seq_origin_len') @@ -168,7 +168,7 @@ class CWS(API): pipeline.insert(1, tag_proc) pp = Pipeline(pipeline) - reader = ConlluCWSReader() + reader = ConllCWSReader() # te_filename = '/home/hyan/ctb3/test.conllx' te_dataset = reader.load(filepath) diff --git a/fastNLP/api/processor.py b/fastNLP/api/processor.py index b495ea70..afa8775b 100644 --- a/fastNLP/api/processor.py +++ b/fastNLP/api/processor.py @@ -11,6 +11,11 @@ from fastNLP.core.vocabulary import Vocabulary class Processor(object): def __init__(self, field_name, new_added_field_name): + """ + + :param field_name: 处理哪个field + :param new_added_field_name: 如果为None,则认为是field_name,即覆盖原有的field + """ self.field_name = field_name if new_added_field_name is None: self.new_added_field_name = field_name @@ -92,6 +97,11 @@ class FullSpaceToHalfSpaceProcessor(Processor): class PreAppendProcessor(Processor): + """ + 向某个field的起始增加data(应该为str类型)。该field需要为list类型。即新增的field为 + [data] + instance[field_name] + + """ def __init__(self, data, field_name, new_added_field_name=None): super(PreAppendProcessor, self).__init__(field_name, new_added_field_name) self.data = data @@ -102,6 +112,10 @@ class PreAppendProcessor(Processor): class SliceProcessor(Processor): + """ + 从某个field中只取部分内容。等价于instance[field_name][start:end:step] + + """ def __init__(self, start, end, step, field_name, new_added_field_name=None): super(SliceProcessor, self).__init__(field_name, new_added_field_name) for o in (start, end, step): @@ -114,7 +128,17 @@ class SliceProcessor(Processor): class Num2TagProcessor(Processor): + """ + 将一句话中的数字转换为某个tag。 + + """ def __init__(self, tag, field_name, new_added_field_name=None): + """ + + :param tag: str, 将数字转换为该tag + :param field_name: + :param new_added_field_name: + """ super(Num2TagProcessor, self).__init__(field_name, new_added_field_name) self.tag = tag self.pattern = r'[-+]?([0-9]+[.]?[0-9]*)+[/eE]?[-+]?([0-9]+[.]?[0-9]*)' @@ -135,6 +159,10 @@ class Num2TagProcessor(Processor): class IndexerProcessor(Processor): + """ + 给定一个vocabulary , 将指定field转换为index形式。指定field应该是一维的list,比如 + ['我', '是', xxx] + """ def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False, is_input=True): assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) @@ -163,19 +191,19 @@ class IndexerProcessor(Processor): class VocabProcessor(Processor): - """Build vocabulary with a field in the data set. 
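(Usage sketch for the reworked VocabProcessor in this hunk; the field names and toy data below are illustrative and not part of this patch. min_freq and max_size are simply forwarded to the underlying Vocabulary.)

from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.api.processor import VocabProcessor

# two toy DataSets that share a 'words' field
train_ds, dev_ds = DataSet(), DataSet()
train_ds.append(Instance(words=['我', '是', '学', '生']))
dev_ds.append(Instance(words=['我', '来', '了']))

proc = VocabProcessor(field_name='words', min_freq=1, max_size=None)
proc.process(train_ds, dev_ds)  # accepts several DataSets in one call
vocab = proc.get_vocab()        # builds and returns the Vocabulary
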
+ """ + 传入若干个DataSet以建立vocabulary。 """ - def __init__(self, field_name): + def __init__(self, field_name, min_freq=1, max_size=None): super(VocabProcessor, self).__init__(field_name, None) - self.vocab = Vocabulary() + self.vocab = Vocabulary(min_freq=min_freq, max_size=max_size) def process(self, *datasets): for dataset in datasets: assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) - for ins in dataset: - self.vocab.update(ins[self.field_name]) + dataset.apply(lambda ins: self.vocab.update(ins[self.field_name])) def get_vocab(self): self.vocab.build_vocab() @@ -183,6 +211,10 @@ class VocabProcessor(Processor): class SeqLenProcessor(Processor): + """ + 根据某个field新增一个sequence length的field。取该field的第一维 + + """ def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True): super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) self.is_input = is_input @@ -195,10 +227,15 @@ class SeqLenProcessor(Processor): return dataset +from fastNLP.core.utils import _build_args + class ModelProcessor(Processor): def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32): """ - 迭代模型并将结果的padding drop掉 + 传入一个model,在process()时传入一个dataset,该processor会通过Batch将DataSet的内容输出给model.predict或者model.forward. + model输出的内容会被增加到dataset中,field_name由model输出决定。如果生成的内容维度不是(Batch_size, )与 + (Batch_size, 1),则使用seqence length这个field进行unpad + TODO 这个类需要删除对seq_lens的依赖。 :param seq_len_field_name: :param batch_size: @@ -211,13 +248,18 @@ class ModelProcessor(Processor): def process(self, dataset): self.model.eval() assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) - data_iterator = Batch(dataset, batch_size=self.batch_size, sampler=SequentialSampler(), use_cuda=False) + data_iterator = Batch(dataset, batch_size=self.batch_size, sampler=SequentialSampler()) batch_output = defaultdict(list) + if hasattr(self.model, "predict"): + predict_func = self.model.predict + else: + predict_func = self.model.forward with torch.no_grad(): for batch_x, _ in data_iterator: - prediction = self.model.predict(**batch_x) - seq_lens = batch_x[self.seq_len_field_name].cpu().numpy().tolist() + refined_batch_x = _build_args(predict_func, **batch_x) + prediction = predict_func(**refined_batch_x) + seq_lens = batch_x[self.seq_len_field_name].tolist() for key, value in prediction.items(): tmp_batch = [] @@ -246,6 +288,10 @@ class ModelProcessor(Processor): class Index2WordProcessor(Processor): + """ + 将DataSet中某个为index的field根据vocab转换为str + + """ def __init__(self, vocab, field_name, new_added_field_name): super(Index2WordProcessor, self).__init__(field_name, new_added_field_name) self.vocab = vocab @@ -266,5 +312,5 @@ class SetIsTargetProcessor(Processor): def process(self, dataset): set_dict = {name: self.default for name in dataset.get_all_fields().keys()} set_dict.update(self.field_dict) - dataset.set_target(**set_dict) + dataset.set_target(*set_dict.keys()) return dataset diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 52dac2fc..a80e84de 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -254,7 +254,7 @@ class DataSet(object): :return results: if new_field_name is not passed, returned values of the function over all instances. 
""" results = [func(ins) for ins in self._inner_iter()] - if len(list(filter(lambda x: x is not None, results))) == 0: # all None + if len(list(filter(lambda x: x is not None, results))) == 0 and not (new_field_name is None): # all None raise ValueError("{} always return None.".format(get_func_signature(func=func))) extra_param = {} diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 07ebe3fe..2619a155 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -10,6 +10,7 @@ from fastNLP.core.utils import _build_args from fastNLP.core.utils import _check_arg_dict_list from fastNLP.core.utils import get_func_signature from fastNLP.core.utils import seq_lens_to_masks +from fastNLP.core.vocabulary import Vocabulary class MetricBase(object): @@ -62,11 +63,6 @@ class MetricBase(object): f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the " f"initialization parameters, or change its signature.") - # evaluate should not have varargs. - # if func_spect.varargs: - # raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.evaluate)}(Do not use " - # f"positional argument.).") - def get_metric(self, reset=True): raise NotImplemented @@ -91,10 +87,9 @@ class MetricBase(object): This method will call self.evaluate method. Before calling self.evaluate, it will first check the validity of output_dict, target_dict - (1) whether self.evaluate has varargs, which is not supported. - (2) whether params needed by self.evaluate is not included in output_dict,target_dict. - (3) whether params needed by self.evaluate duplicate in pred_dict, target_dict - (4) whether params in output_dict, target_dict are not used by evaluate.(Might cause warning) + (1) whether params needed by self.evaluate is not included in output_dict,target_dict. + (2) whether params needed by self.evaluate duplicate in pred_dict, target_dict + (3) whether params in output_dict, target_dict are not used by evaluate.(Might cause warning) Besides, before passing params into self.evaluate, this function will filter out params from output_dict and target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering will be conducted.) @@ -275,6 +270,369 @@ class AccuracyMetric(MetricBase): self.total = 0 return evaluate_result +def bmes_tag_to_spans(tags, ignore_labels=None): + """ + + :param tags: List[str], + :param ignore_labels: List[str], 在该list中的label将被忽略 + :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] + """ + ignore_labels = set(ignore_labels) if ignore_labels else set() + + spans = [] + prev_bmes_tag = None + for idx, tag in enumerate(tags): + tag = tag.lower() + bmes_tag, label = tag[:1], tag[2:] + if bmes_tag in ('b', 's'): + spans.append((label, [idx, idx])) + elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label==spans[-1][0]: + spans[-1][1][1] = idx + else: + spans.append((label, [idx, idx])) + prev_bmes_tag = bmes_tag + return [(span[0], (span[1][0], span[1][1])) + for span in spans + if span[0] not in ignore_labels + ] + +def bio_tag_to_spans(tags, ignore_labels=None): + """ + + :param tags: List[str], + :param ignore_labels: List[str], 在该list中的label将被忽略 + :return: List[Tuple[str, List[int, int]]]. 
[(label,[start, end])] + """ + ignore_labels = set(ignore_labels) if ignore_labels else set() + + spans = [] + prev_bio_tag = None + for idx, tag in enumerate(tags): + tag = tag.lower() + bio_tag, label = tag[:1], tag[2:] + if bio_tag == 'b': + spans.append((label, [idx, idx])) + elif bio_tag == 'i' and prev_bio_tag in ('b', 'i') and label==spans[-1][0]: + spans[-1][1][1] = idx + elif bio_tag == 'o': # o tag does not count + pass + else: + spans.append((label, [idx, idx])) + prev_bio_tag = bio_tag + return [(span[0], (span[1][0], span[1][1])) + for span in spans + if span[0] not in ignore_labels + ] + + +class SpanFPreRecMetric(MetricBase): + """ + 在序列标注问题中,以span的方式计算F, pre, rec. + 最后得到的metric结果为 + { + 'f': xxx, # 这里使用f考虑以后可以计算f_beta值 + 'pre': xxx, + 'rec':xxx + } + 若only_gross=False, 即还会返回各个label的metric统计值 + { + 'f': xxx, + 'pre': xxx, + 'rec':xxx, + 'f-label': xxx, + 'pre-label': xxx, + 'rec-label':xxx, + ... + } + + """ + def __init__(self, tag_vocab, pred=None, target=None, seq_lens=None, encoding_type='bio', ignore_labels=None, + only_gross=True, f_type='micro', beta=1): + """ + + :param tag_vocab: Vocabulary, 标签的vocabulary。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN), + 在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'. + :param pred: str, 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据 + :param target: str, 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据 + :param seq_lens: str, 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用'seq_lens'取数据。 + :param encoding_type: str, 目前支持bio, bmes + :param ignore_labels, List[str]. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'这 + 个label + :param only_gross, bool. 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个 + label的f1, pre, rec + :param f_type, str. 'micro'或'macro'. 'micro':通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; 'macro': + 分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同) + :param beta, float. f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 
若为0.5
+ 则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。
+ """
+ encoding_type = encoding_type.lower()
+ if encoding_type not in ('bio', 'bmes'):
+ raise ValueError("Only support 'bio' or 'bmes' type.")
+ if not isinstance(tag_vocab, Vocabulary):
+ raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab)))
+ if f_type not in ('micro', 'macro'):
+ raise ValueError("f_type only supports `micro` or `macro`, got {}.".format(f_type))
+
+ self.encoding_type = encoding_type
+ if self.encoding_type == 'bmes':
+ self.tag_to_span_func = bmes_tag_to_spans
+ elif self.encoding_type == 'bio':
+ self.tag_to_span_func = bio_tag_to_spans
+ self.ignore_labels = ignore_labels
+ self.f_type = f_type
+ self.beta = beta
+ self.beta_square = self.beta**2
+ self.only_gross = only_gross
+
+ super().__init__()
+ self._init_param_map(pred=pred, target=target, seq_lens=seq_lens)
+
+ self.tag_vocab = tag_vocab
+
+ self._true_positives = defaultdict(int)
+ self._false_positives = defaultdict(int)
+ self._false_negatives = defaultdict(int)
+
+ def evaluate(self, pred, target, seq_lens):
+ """
+ A lot of the design ideas here come from allennlp's span-based measure.
+ :param pred:
+ :param target:
+ :param seq_lens:
+ :return:
+ """
+ if not isinstance(pred, torch.Tensor):
+ raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+ f"got {type(pred)}.")
+ if not isinstance(target, torch.Tensor):
+ raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+ f"got {type(target)}.")
+
+ if not isinstance(seq_lens, torch.Tensor):
+ raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+ f"got {type(seq_lens)}.")
+
+ num_classes = pred.size(-1)
+ if (target >= num_classes).any():
+ raise ValueError("A gold label passed to SpanFPreRecMetric contains an "
+ "id >= {}, the number of classes.".format(num_classes))
+
+ if pred.size() == target.size() and len(target.size()) == 2:
+ pass
+ elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2:
+ pred = pred.argmax(dim=-1)
+ else:
+ raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have "
+ f"size:{pred.size()}, target should have size: {pred.size()} or "
+ f"{pred.size()[:-1]}, got {target.size()}.")
+
+ batch_size = pred.size(0)
+ for i in range(batch_size):
+ pred_tags = pred[i, :seq_lens[i]].tolist()
+ gold_tags = target[i, :seq_lens[i]].tolist()
+
+ pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags]
+ gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags]
+
+ pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels)
+ gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels)
+
+ for span in pred_spans:
+ if span in gold_spans:
+ self._true_positives[span[0]] += 1
+ gold_spans.remove(span)
+ else:
+ self._false_positives[span[0]] += 1
+ for span in gold_spans:
+ self._false_negatives[span[0]] += 1
+
+ def get_metric(self, reset=True):
+ evaluate_result = {}
+ if not self.only_gross or self.f_type=='macro':
+ tags = set(self._false_negatives.keys())
+ tags.update(set(self._false_positives.keys()))
+ tags.update(set(self._true_positives.keys()))
+ f_sum = 0
+ pre_sum = 0
+ rec_sum = 0
+ for tag in tags:
+ tp = self._true_positives[tag]
+ fn = self._false_negatives[tag]
+ fp = self._false_positives[tag]
+ f, pre, rec = self._compute_f_pre_rec(tp, fn, fp)
+ f_sum += f
+ pre_sum += pre
+ rec_sum += rec
+ if not self.only_gross and tag!='': # tag!='' guards against spans that carry no label
+ 
f_key = 'f-{}'.format(tag) + pre_key = 'pre-{}'.format(tag) + rec_key = 'rec-{}'.format(tag) + evaluate_result[f_key] = f + evaluate_result[pre_key] = pre + evaluate_result[rec_key] = rec + + if self.f_type == 'macro': + evaluate_result['f'] = f_sum/len(tags) + evaluate_result['pre'] = pre_sum/len(tags) + evaluate_result['rec'] = rec_sum/len(tags) + + if self.f_type == 'micro': + f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()), + sum(self._false_negatives.values()), + sum(self._false_positives.values())) + evaluate_result['f'] = f + evaluate_result['pre'] = pre + evaluate_result['rec'] = rec + + if reset: + self._true_positives = defaultdict(int) + self._false_positives = defaultdict(int) + self._false_negatives = defaultdict(int) + + return evaluate_result + + def _compute_f_pre_rec(self, tp, fn, fp): + """ + + :param tp: int, true positive + :param fn: int, false negative + :param fp: int, false positive + :return: (f, pre, rec) + """ + pre = tp / (fp + tp + 1e-13) + rec = tp / (fn + tp + 1e-13) + f = (1 + self.beta_square) * pre * rec / (self.beta_square * pre + rec + 1e-13) + + return f, pre, rec + +class BMESF1PreRecMetric(MetricBase): + """ + 按照BMES标注方式计算f1, precision, recall。由于可能存在非法tag,比如"BS",所以需要用以下的表格做转换,cur_B意思是当前tag是B, + next_B意思是后一个tag是B。则cur_B=S,即将当前被predict是B的tag标为S;next_M=B, 即将后一个被predict是M的tag标为B + | | next_B | next_M | next_E | next_S | end | + |:-----:|:-------:|:--------:|:--------:|:-------:|:-------:| + | start | 合法 | next_M=B | next_E=S | 合法 | - | + | cur_B | cur_B=S | 合法 | 合法 | cur_B=S | cur_B=S | + | cur_M | cur_M=E | 合法 | 合法 | cur_M=E | cur_M=E | + | cur_E | 合法 | next_M=B | next_E=S | 合法 | 合法 | + | cur_S | 合法 | next_M=B | next_E=S | 合法 | 合法 | + 举例: + prediction为BSEMS,会被认为是SSSSS. + + 本Metric不检验target的合法性,请务必保证target的合法性。 + pred的形状应该为(batch_size, max_len) 或 (batch_size, max_len, 4)。 + target形状为 (batch_size, max_len) + seq_lens形状为 (batch_size, ) + + """ + + def __init__(self, b_idx=0, m_idx=1, e_idx=2, s_idx=3, pred=None, target=None, seq_lens=None): + """ + 需要申明BMES这四种tag中,各种tag对应的idx。所有不为b_idx, m_idx, e_idx, s_idx的数字都认为是s_idx。 + + :param b_idx: int, Begin标签所对应的tag idx. + :param m_idx: int, Middle标签所对应的tag idx. + :param e_idx: int, End标签所对应的tag idx. 
+ :param s_idx: int, Single标签所对应的tag idx + :param pred: str, 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据 + :param target: str, 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据 + :param seq_lens: str, 用该key在evaluate()时从传入dict中取出seqence length数据。为None,则使用'seq_lens'取数据。 + """ + super().__init__() + + self._init_param_map(pred=pred, target=target, seq_lens=seq_lens) + + self.yt_wordnum = 0 + self.yp_wordnum = 0 + self.corr_num = 0 + + self.b_idx = b_idx + self.m_idx = m_idx + self.e_idx = e_idx + self.s_idx = s_idx + # 还原init处介绍的矩阵 + self._valida_matrix = { + -1: [(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1)], # magic start idx + self.b_idx:[(0, self.s_idx), (-1, -1), (-1, -1), (0, self.s_idx), (0, self.s_idx)], + self.m_idx:[(0, self.e_idx), (-1, -1), (-1, -1), (0, self.e_idx), (0, self.e_idx)], + self.e_idx:[(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)], + self.s_idx:[(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)], + } + + def _validate_tags(self, tags): + """ + 给定一个tag的Tensor,返回合法tag + + :param tags: Tensor, shape: (seq_len, ) + :return: 返回修改为合法tag的list + """ + assert len(tags)!=0 + assert isinstance(tags, torch.Tensor) and len(tags.size())==1 + padded_tags = [-1, *tags.tolist(), -1] + for idx in range(len(padded_tags)-1): + cur_tag = padded_tags[idx] + if cur_tag not in self._valida_matrix: + cur_tag = self.s_idx + if padded_tags[idx+1] not in self._valida_matrix: + padded_tags[idx+1] = self.s_idx + next_tag = padded_tags[idx+1] + shift_tag = self._valida_matrix[cur_tag][next_tag] + if shift_tag[0]!=-1: + padded_tags[idx+shift_tag[0]] = shift_tag[1] + + return padded_tags[1:-1] + + def evaluate(self, pred, target, seq_lens): + if not isinstance(pred, torch.Tensor): + raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor," + f"got {type(pred)}.") + if not isinstance(target, torch.Tensor): + raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor," + f"got {type(target)}.") + + if not isinstance(seq_lens, torch.Tensor): + raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor," + f"got {type(seq_lens)}.") + + if pred.size() == target.size() and len(target.size()) == 2: + pass + elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2: + pred = pred.argmax(dim=-1) + else: + raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have " + f"size:{pred.size()}, target should have size: {pred.size()} or " + f"{pred.size()[:-1]}, got {target.size()}.") + + for idx in range(len(pred)): + seq_len = seq_lens[idx] + target_tags = target[idx][:seq_len].tolist() + pred_tags = pred[idx][:seq_len] + pred_tags = self._validate_tags(pred_tags) + start_idx = 0 + for t_idx, (t_tag, p_tag) in enumerate(zip(target_tags, pred_tags)): + if t_tag in (self.s_idx, self.e_idx): + self.yt_wordnum += 1 + corr_flag = True + for i in range(start_idx, t_idx+1): + if target_tags[i]!=pred_tags[i]: + corr_flag = False + if corr_flag: + self.corr_num += 1 + start_idx = t_idx + 1 + if p_tag in (self.s_idx, self.e_idx): + self.yp_wordnum += 1 + + def get_metric(self, reset=True): + P = self.corr_num / (self.yp_wordnum + 1e-12) + R = self.corr_num / (self.yt_wordnum + 1e-12) + F = 2 * P * R / (P + R + 1e-12) + evaluate_result = {'f': round(F, 6), 'pre':round(P, 6), 'rec': round(R, 6)} + if reset: + self.yp_wordnum = 0 + self.yt_wordnum = 0 + self.corr_num = 0 + return evaluate_result + def _prepare_metrics(metrics): """ diff --git 
a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index bf32fa6c..ec2772e8 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -31,9 +31,8 @@ class Trainer(object):
 """
 def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
- validate_every=-1, dev_data=None, use_cuda=False, save_path=None,
- optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0,
- metric_key=None, sampler=RandomSampler(), use_tqdm=True):
+ validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0),
+ check_code_level=0, metric_key=None, sampler=RandomSampler(), use_tqdm=True, use_cuda=False):
 """
 :param DataSet train_data: the training data
diff --git a/fastNLP/modules/decoder/CRF.py b/fastNLP/modules/decoder/CRF.py index 55d3faa4..baa8c403 100644 --- a/fastNLP/modules/decoder/CRF.py +++ b/fastNLP/modules/decoder/CRF.py @@ -19,26 +19,149 @@ def seq_len_to_byte_mask(seq_lens):
 mask = broadcast_arange.lt(seq_lens.float().view(-1, 1))
 return mask

+def allowed_transitions(id2label, encoding_type='bio'):
+ """
+
+ :param id2label: dict, key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是
+ "B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()得到id2label。
+ :param encoding_type: str, 支持"bio", "bmes"。
+ :return: List[Tuple[int, int]], 内部的Tuple是(from_tag_id, to_tag_id)。 返回的结果考虑了start和end,比如"BIO"中,B、O可以
+ 位于序列的开端,而I不行。所以返回的结果中会包含(start_idx, B_idx), (start_idx, O_idx), 但是不包含(start_idx, I_idx).
+ start_idx=len(id2label), end_idx=len(id2label)+1。
+ """
+ num_tags = len(id2label)
+ start_idx = num_tags
+ end_idx = num_tags + 1
+ encoding_type = encoding_type.lower()
+ allowed_trans = []
+ id_label_lst = list(id2label.items()) + [(start_idx, 'start'), (end_idx, 'end')]
+ def split_tag_label(from_label):
+ from_label = from_label.lower()
+ if from_label in ['start', 'end']:
+ from_tag = from_label
+ from_label = ''
+ else:
+ from_tag = from_label[:1]
+ from_label = from_label[2:]
+ return from_tag, from_label
+
+ for from_id, from_label in id_label_lst:
+ if from_label in ['<pad>', '<unk>']:
+ continue
+ from_tag, from_label = split_tag_label(from_label)
+ for to_id, to_label in id_label_lst:
+ if to_label in ['<pad>', '<unk>']:
+ continue
+ to_tag, to_label = split_tag_label(to_label)
+ if is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label):
+ allowed_trans.append((from_id, to_id))
+ return allowed_trans

+def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label):
+ """
+
+ :param encoding_type: str, 支持"BIO", "BMES"。
+ :param from_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag
+ :param from_label: str, 比如"PER", "LOC"等label
+ :param to_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag
+ :param to_label: str, 比如"PER", "LOC"等label
+ :return: bool,能否跃迁
+ """
+ if to_tag=='start' or from_tag=='end':
+ return False
+ encoding_type = encoding_type.lower()
+ if encoding_type == 'bio':
+ """
+ 第一行是to_tag, 第一列是from_tag. 
y任意条件下可转,-只有在label相同时可转,n不可转 + +-------+---+---+---+-------+-----+ + | | B | I | O | start | end | + +-------+---+---+---+-------+-----+ + | B | y | - | y | n | y | + +-------+---+---+---+-------+-----+ + | I | y | - | y | n | y | + +-------+---+---+---+-------+-----+ + | O | y | n | y | n | y | + +-------+---+---+---+-------+-----+ + | start | y | n | y | n | n | + +-------+---+---+---+-------+-----+ + | end | n | n | n | n | n | + +-------+---+---+---+-------+-----+ + """ + if from_tag == 'start': + return to_tag in ('b', 'o') + elif from_tag in ['b', 'i']: + return any([to_tag in ['end', 'b', 'o'], to_tag=='i' and from_label==to_label]) + elif from_tag == 'o': + return to_tag in ['end', 'b', 'o'] + else: + raise ValueError("Unexpect tag {}. Expect only 'B', 'I', 'O'.".format(from_tag)) + + elif encoding_type == 'bmes': + """ + 第一行是to_tag, 第一列是from_tag,y任意条件下可转,-只有在label相同时可转,n不可转 + +-------+---+---+---+---+-------+-----+ + | | B | M | E | S | start | end | + +-------+---+---+---+---+-------+-----+ + | B | n | - | - | n | n | n | + +-------+---+---+---+---+-------+-----+ + | M | n | - | - | n | n | n | + +-------+---+---+---+---+-------+-----+ + | E | y | n | n | y | n | y | + +-------+---+---+---+---+-------+-----+ + | S | y | n | n | y | n | y | + +-------+---+---+---+---+-------+-----+ + | start | y | n | n | y | n | n | + +-------+---+---+---+---+-------+-----+ + | end | n | n | n | n | n | n | + +-------+---+---+---+---+-------+-----+ + """ + if from_tag == 'start': + return to_tag in ['b', 's'] + elif from_tag == 'b': + return to_tag in ['m', 'e'] and from_label==to_label + elif from_tag == 'm': + return to_tag in ['m', 'e'] and from_label==to_label + elif from_tag in ['e', 's']: + return to_tag in ['b', 's', 'end'] + else: + raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S'.".format(from_tag)) + else: + raise ValueError("Only support BIO, BMES encoding type, got {}.".format(encoding_type)) + class ConditionalRandomField(nn.Module): - def __init__(self, tag_size, include_start_end_trans=False ,initial_method = None): + def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None, initial_method=None): """ - :param tag_size: int, num of tags - :param include_start_end_trans: bool, whether to include start/end tag + + :param num_tags: int, 标签的数量。 + :param include_start_end_trans: bool, 是否包含起始tag + :param allowed_transitions: List[Tuple[from_tag_id(int), to_tag_id(int)]]. 
允许的跃迁,可以通过allowed_transitions()得到。 + 如果为None,则所有跃迁均为合法 + :param initial_method: """ + super(ConditionalRandomField, self).__init__() self.include_start_end_trans = include_start_end_trans - self.tag_size = tag_size + self.num_tags = num_tags # the meaning of entry in this matrix is (from_tag_id, to_tag_id) score - self.trans_m = nn.Parameter(torch.randn(tag_size, tag_size)) + self.trans_m = nn.Parameter(torch.randn(num_tags, num_tags)) if self.include_start_end_trans: - self.start_scores = nn.Parameter(torch.randn(tag_size)) - self.end_scores = nn.Parameter(torch.randn(tag_size)) + self.start_scores = nn.Parameter(torch.randn(num_tags)) + self.end_scores = nn.Parameter(torch.randn(num_tags)) + + if allowed_transitions is None: + constrain = torch.zeros(num_tags + 2, num_tags + 2) + else: + constrain = torch.ones(num_tags + 2, num_tags + 2) * -1000 + for from_tag_id, to_tag_id in allowed_transitions: + constrain[from_tag_id, to_tag_id] = 0 + self._constrain = nn.Parameter(constrain, requires_grad=False) # self.reset_parameter() initial_parameter(self, initial_method) + def reset_parameter(self): nn.init.xavier_normal_(self.trans_m) if self.include_start_end_trans: @@ -49,7 +172,7 @@ class ConditionalRandomField(nn.Module): """ Computes the (batch_size,) denominator term for the log-likelihood, which is the sum of the likelihoods across all possible state sequences. - :param logits:FloatTensor, max_len x batch_size x tag_size + :param logits:FloatTensor, max_len x batch_size x num_tags :param mask:ByteTensor, max_len x batch_size :return:FloatTensor, batch_size """ @@ -72,7 +195,7 @@ class ConditionalRandomField(nn.Module): def _glod_score(self, logits, tags, mask): """ Compute the score for the gold path. - :param logits: FloatTensor, max_len x batch_size x tag_size + :param logits: FloatTensor, max_len x batch_size x num_tags :param tags: LongTensor, max_len x batch_size :param mask: ByteTensor, max_len x batch_size :return:FloatTensor, batch_size @@ -99,7 +222,7 @@ class ConditionalRandomField(nn.Module): def forward(self, feats, tags, mask): """ Calculate the neg log likelihood - :param feats:FloatTensor, batch_size x max_len x tag_size + :param feats:FloatTensor, batch_size x max_len x num_tags :param tags:LongTensor, batch_size x max_len :param mask:ByteTensor batch_size x max_len :return:FloatTensor, batch_size @@ -112,13 +235,20 @@ class ConditionalRandomField(nn.Module): return all_path_score - gold_path_score - def viterbi_decode(self, data, mask, get_score=False): + def viterbi_decode(self, data, mask, get_score=False, unpad=False): """ Given a feats matrix, return best decode path and best score. - :param data:FloatTensor, batch_size x max_len x tag_size + :param data:FloatTensor, batch_size x max_len x num_tags :param mask:ByteTensor batch_size x max_len :param get_score: bool, whether to output the decode score. 
- :return: scores, paths + :param unpad: bool, 是否将结果unpad, + 如果False, 返回的是batch_size x max_len的tensor, + 如果True,返回的是List[List[int]], List[int]为每个sequence的label,已经unpadding了,即每个 + List[int]的长度是这个sample的有效长度 + :return: 如果get_score为False,返回结果根据unpadding变动 + 如果get_score为True, 返回 (paths, List[float], )。第一个仍然是解码后的路径(根据unpad变化),第二个List[Float] + 为每个seqence的解码分数。 + """ batch_size, seq_len, n_tags = data.size() data = data.transpose(0, 1).data # L, B, H @@ -127,19 +257,23 @@ class ConditionalRandomField(nn.Module): # dp vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) vscore = data[0] + transitions = self._constrain.data.clone() + transitions[:n_tags, :n_tags] += self.trans_m.data if self.include_start_end_trans: - vscore += self.start_scores.view(1, -1) + transitions[n_tags, :n_tags] += self.start_scores.data + transitions[:n_tags, n_tags+1] += self.end_scores.data + + vscore += transitions[n_tags, :n_tags] + trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data for i in range(1, seq_len): prev_score = vscore.view(batch_size, n_tags, 1) cur_score = data[i].view(batch_size, 1, n_tags) - trans_score = self.trans_m.view(1, n_tags, n_tags).data score = prev_score + trans_score + cur_score best_score, best_dst = score.max(1) vpath[i] = best_dst vscore = best_score * mask[i].view(batch_size, 1) + vscore * (1 - mask[i]).view(batch_size, 1) - if self.include_start_end_trans: - vscore += self.end_scores.view(1, -1) + vscore += transitions[:n_tags, n_tags+1].view(1, -1) # backtrace batch_idx = torch.arange(batch_size, dtype=torch.long, device=data.device) @@ -154,7 +288,13 @@ class ConditionalRandomField(nn.Module): for i in range(seq_len - 1): last_tags = vpath[idxes[i], batch_idx, last_tags] ans[idxes[i+1], batch_idx] = last_tags - + ans = ans.transpose(0, 1) + if unpad: + paths = [] + for idx, seq_len in enumerate(lens): + paths.append(ans[idx, :seq_len+1].tolist()) + else: + paths = ans if get_score: - return ans_score, ans.transpose(0, 1) - return ans.transpose(0, 1) + return paths, ans_score.tolist() + return paths diff --git a/reproduction/chinese_word_segment/cws_io/cws_reader.py b/reproduction/chinese_word_segment/cws_io/cws_reader.py index 56a73351..34bcf7dd 100644 --- a/reproduction/chinese_word_segment/cws_io/cws_reader.py +++ b/reproduction/chinese_word_segment/cws_io/cws_reader.py @@ -6,6 +6,13 @@ from fastNLP.io.dataset_loader import DataSetLoader def cut_long_sentence(sent, max_sample_length=200): + """ + 将长于max_sample_length的sentence截成多段,只会在有空格的地方发生截断。所以截取的句子可能长于或者短于max_sample_length + + :param sent: str. + :param max_sample_length: int. + :return: list of str. 
+ """ sent_no_space = sent.replace(' ', '') cutted_sentence = [] if len(sent_no_space) > max_sample_length: @@ -127,12 +134,26 @@ class POSCWSReader(DataSetLoader): return dataset -class ConlluCWSReader(object): - # 返回的Dataset包含words(list of list, 里层的list是character), tag两个field(list of str, str是标有BMES的tag)。 +class ConllCWSReader(object): def __init__(self): pass def load(self, path, cut_long_sent=False): + """ + 返回的DataSet只包含raw_sentence这个field,内容为str。 + 假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 + 1 编者按 编者按 NN O 11 nmod:topic + 2 : : PU O 11 punct + 3 7月 7月 NT DATE 4 compound:nn + 4 12日 12日 NT DATE 11 nmod:tmod + 5 , , PU O 11 punct + + 1 这 这 DT O 3 det + 2 款 款 M O 1 mark:clf + 3 飞行 飞行 NN O 8 nsubj + 4 从 从 P O 5 case + 5 外型 外型 NN O 8 nmod:prep + """ datalist = [] with open(path, 'r', encoding='utf-8') as f: sample = [] @@ -150,10 +171,10 @@ class ConlluCWSReader(object): ds = DataSet() for sample in datalist: # print(sample) - res = self.get_one(sample) + res = self.get_char_lst(sample) if res is None: continue - line = ' '.join(res) + line = ' '.join(res) if cut_long_sent: sents = cut_long_sentence(line) else: @@ -163,7 +184,7 @@ class ConlluCWSReader(object): return ds - def get_one(self, sample): + def get_char_lst(self, sample): if len(sample)==0: return None text = [] diff --git a/reproduction/chinese_word_segment/models/cws_model.py b/reproduction/chinese_word_segment/models/cws_model.py index 4f81fea3..e31824e1 100644 --- a/reproduction/chinese_word_segment/models/cws_model.py +++ b/reproduction/chinese_word_segment/models/cws_model.py @@ -9,7 +9,7 @@ from reproduction.chinese_word_segment.utils import seq_lens_to_mask class CWSBiLSTMEncoder(BaseModel): def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, - hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1): + hidden_size=200, bidirectional=True, embed_drop_p=0.2, num_layers=1): super().__init__() self.input_size = 0 @@ -68,6 +68,7 @@ class CWSBiLSTMEncoder(BaseModel): if not bigrams is None: bigram_tensor = self.bigram_embedding(bigrams).view(batch_size, max_len, -1) x_tensor = torch.cat([x_tensor, bigram_tensor], dim=2) + x_tensor = self.embedding_drop(x_tensor) sorted_lens, sorted_indices = torch.sort(seq_lens, descending=True) packed_x = nn.utils.rnn.pack_padded_sequence(x_tensor[sorted_indices], sorted_lens, batch_first=True) @@ -120,10 +121,24 @@ class CWSBiLSTMSegApp(BaseModel): from fastNLP.modules.decoder.CRF import ConditionalRandomField +from fastNLP.modules.decoder.CRF import allowed_transitions class CWSBiLSTMCRF(BaseModel): def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, - hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1, tag_size=4): + hidden_size=200, bidirectional=True, embed_drop_p=0.2, num_layers=1, tag_size=4): + """ + 默认使用BMES的标注方式 + :param vocab_num: + :param embed_dim: + :param bigram_vocab_num: + :param bigram_embed_dim: + :param num_bigram_per_char: + :param hidden_size: + :param bidirectional: + :param embed_drop_p: + :param num_layers: + :param tag_size: + """ super(CWSBiLSTMCRF, self).__init__() self.tag_size = tag_size @@ -133,10 +148,12 @@ class CWSBiLSTMCRF(BaseModel): size_layer = [hidden_size, 200, tag_size] self.decoder_model = MLP(size_layer) - self.crf = ConditionalRandomField(tag_size=tag_size, include_start_end_trans=False) + allowed_trans = allowed_transitions({0:'b', 1:'m', 2:'e', 3:'s'}, encoding_type='bmes') + self.crf = 
ConditionalRandomField(num_tags=tag_size, include_start_end_trans=False, + allowed_transitions=allowed_trans) - def forward(self, chars, tags, seq_lens, bigrams=None): + def forward(self, chars, target, seq_lens, bigrams=None): device = self.parameters().__next__().device chars = chars.to(device).long() if not bigrams is None: @@ -147,7 +164,7 @@ class CWSBiLSTMCRF(BaseModel): masks = seq_lens_to_mask(seq_lens) feats = self.encoder_model(chars, bigrams, seq_lens) feats = self.decoder_model(feats) - losses = self.crf(feats, tags, masks) + losses = self.crf(feats, target, masks) pred_dict = {} pred_dict['seq_lens'] = seq_lens @@ -168,5 +185,5 @@ class CWSBiLSTMCRF(BaseModel): feats = self.decoder_model(feats) probs = self.crf.viterbi_decode(feats, masks, get_score=False) - return {'pred_tags': probs} + return {'pred': probs} diff --git a/reproduction/chinese_word_segment/process/cws_processor.py b/reproduction/chinese_word_segment/process/cws_processor.py index e7c069f1..d2c5d1d5 100644 --- a/reproduction/chinese_word_segment/process/cws_processor.py +++ b/reproduction/chinese_word_segment/process/cws_processor.py @@ -2,7 +2,6 @@ import re -from fastNLP.core.field import SeqLabelField from fastNLP.core.vocabulary import Vocabulary from fastNLP.core.dataset import DataSet from fastNLP.api.processor import Processor @@ -11,7 +10,10 @@ from reproduction.chinese_word_segment.process.span_converter import SpanConvert _SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>' class SpeicalSpanProcessor(Processor): - # 这个类会将句子中的special span转换为对应的内容。 + """ + 将DataSet中field_name使用span_converter替换掉。 + + """ def __init__(self, field_name, new_added_field_name=None): super(SpeicalSpanProcessor, self).__init__(field_name, new_added_field_name) @@ -20,11 +22,12 @@ class SpeicalSpanProcessor(Processor): def process(self, dataset): assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) - for ins in dataset: + def inner_proc(ins): sentence = ins[self.field_name] for span_converter in self.span_converters: sentence = span_converter.find_certain_span_and_replace(sentence) - ins[self.new_added_field_name] = sentence + return sentence + dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name) return dataset @@ -34,17 +37,22 @@ class SpeicalSpanProcessor(Processor): self.span_converters.append(converter) - class CWSCharSegProcessor(Processor): + """ + 将DataSet中field_name这个field分成一个个的汉字,即原来可能为"复旦大学 fudan", 分成['复', '旦', '大', '学', + ' ', 'f', 'u', ...] 
+ + """ def __init__(self, field_name, new_added_field_name): super(CWSCharSegProcessor, self).__init__(field_name, new_added_field_name) def process(self, dataset): assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) - for ins in dataset: + def inner_proc(ins): sentence = ins[self.field_name] chars = self._split_sent_into_chars(sentence) - ins[self.new_added_field_name] = chars + return chars + dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name) return dataset @@ -73,6 +81,10 @@ class CWSCharSegProcessor(Processor): class CWSTagProcessor(Processor): + """ + 为分词生成tag。该class为Base class。 + + """ def __init__(self, field_name, new_added_field_name=None): super(CWSTagProcessor, self).__init__(field_name, new_added_field_name) @@ -107,18 +119,22 @@ class CWSTagProcessor(Processor): def process(self, dataset): assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) - for ins in dataset: + def inner_proc(ins): sentence = ins[self.field_name] tag_list = self._generate_tag(sentence) - ins[self.new_added_field_name] = tag_list - dataset.set_target(**{self.new_added_field_name:True}) - dataset._set_need_tensor(**{self.new_added_field_name:True}) + return tag_list + dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name) + dataset.set_target(self.new_added_field_name) return dataset def _tags_from_word_len(self, word_len): raise NotImplementedError class CWSBMESTagProcessor(CWSTagProcessor): + """ + 通过DataSet中的field_name这个field生成相应的BMES的tag。 + + """ def __init__(self, field_name, new_added_field_name=None): super(CWSBMESTagProcessor, self).__init__(field_name, new_added_field_name) @@ -137,6 +153,10 @@ class CWSBMESTagProcessor(CWSTagProcessor): return tag_list class CWSSegAppTagProcessor(CWSTagProcessor): + """ + 通过DataSet中的field_name这个field生成相应的SegApp的tag。 + + """ def __init__(self, field_name, new_added_field_name=None): super(CWSSegAppTagProcessor, self).__init__(field_name, new_added_field_name) @@ -151,6 +171,10 @@ class CWSSegAppTagProcessor(CWSTagProcessor): class BigramProcessor(Processor): + """ + 这是生成bigram的基类。 + + """ def __init__(self, field_name, new_added_fielf_name=None): super(BigramProcessor, self).__init__(field_name, new_added_fielf_name) @@ -158,22 +182,31 @@ class BigramProcessor(Processor): def process(self, dataset): assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) - for ins in dataset: + def inner_proc(ins): characters = ins[self.field_name] bigrams = self._generate_bigram(characters) - ins[self.new_added_field_name] = bigrams + return bigrams + dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name) return dataset - def _generate_bigram(self, characters): pass class Pre2Post2BigramProcessor(BigramProcessor): - def __init__(self, field_name, new_added_fielf_name=None): + """ + 该bigram processor生成bigram的方式如下 + 原汉字list为l = ['a', 'b', 'c'],会被padding为L=['SOS', 'SOS', 'a', 'b', 'c', 'EOS', 'EOS'],生成bigram list为 + [L[idx-2], L[idx-1], L[idx+1], L[idx+2], L[idx-2]L[idx], L[idx-1]L[idx], L[idx]L[idx+1], L[idx]L[idx+2], ....] 
+ 即每个汉字,会有八个bigram, 对于上例中'a'的bigram为 + ['SOS', 'SOS', 'b', 'c', 'SOSa', 'SOSa', 'ab', 'ac'] + 返回的bigram是一个list,但其实每8个元素是一个汉字的bigram信息。 + + """ + def __init__(self, field_name, new_added_field_name=None): - super(BigramProcessor, self).__init__(field_name, new_added_fielf_name) + super(BigramProcessor, self).__init__(field_name, new_added_field_name) def _generate_bigram(self, characters): bigrams = [] @@ -197,20 +230,102 @@ class Pre2Post2BigramProcessor(BigramProcessor): # 这里需要建立vocabulary了,但是遇到了以下的问题 # (1) 如果使用Processor的方式的话,但是在这种情况返回的不是dataset。所以建立vocabulary的工作用另外的方式实现,不借用 # Processor了 +# TODO 如何将建立vocab和index这两步统一了? + +class VocabIndexerProcessor(Processor): + """ + 根据DataSet创建Vocabulary,并将其用数字index。新生成的index的field会被放在new_added_filed_name, 如果没有提供 + new_added_field_name, 则覆盖原有的field_name. + + """ + def __init__(self, field_name, new_added_filed_name=None, min_freq=1, max_size=None, + verbose=1): + """ + + :param field_name: 从哪个field_name创建词表,以及对哪个field_name进行index操作 + :param new_added_filed_name: index时,生成的index field的名称,如果不传入,则覆盖field_name. + :param min_freq: 创建的Vocabulary允许的单词最少出现次数. + :param max_size: 创建的Vocabulary允许的最大的单词数量 + :param verbose: 0, 不输出任何信息;1,输出信息 + """ + super(VocabIndexerProcessor, self).__init__(field_name, new_added_filed_name) + self.min_freq = min_freq + self.max_size = max_size + + self.verbose =verbose + + def construct_vocab(self, *datasets): + """ + 使用传入的DataSet创建vocabulary + + :param datasets: DataSet类型的数据,用于构建vocabulary + :return: + """ + self.vocab = Vocabulary(min_freq=self.min_freq, max_size=self.max_size) + for dataset in datasets: + assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) + dataset.apply(lambda ins: self.vocab.update(ins[self.field_name])) + self.vocab.build_vocab() + if self.verbose: + print("Vocabulary Constructed, has {} items.".format(len(self.vocab))) + + def process(self, *datasets, only_index_dataset=None): + """ + 若还未建立Vocabulary,则使用dataset中的DataSet建立vocabulary;若已经有了vocabulary则使用已有的vocabulary。得到vocabulary + 后,则会index datasets与only_index_dataset。 + + :param datasets: DataSet类型的数据 + :param only_index_dataset: DataSet, or list of DataSet. 该参数中的内容只会被用于index,不会被用于生成vocabulary。 + :return: + """ + if len(datasets)==0 and not hasattr(self,'vocab'): + raise RuntimeError("You have to construct vocabulary first. 
Or you have to pass datasets to construct it.") + if not hasattr(self, 'vocab'): + self.construct_vocab(*datasets) + else: + if self.verbose: + print("Using constructed vocabulary with {} items.".format(len(self.vocab))) + to_index_datasets = [] + if len(datasets)!=0: + for dataset in datasets: + assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) + to_index_datasets.append(dataset) + + if not (only_index_dataset is None): + if isinstance(only_index_dataset, list): + for dataset in only_index_dataset: + assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) + to_index_datasets.append(dataset) + elif isinstance(only_index_dataset, DataSet): + to_index_datasets.append(only_index_dataset) + else: + raise TypeError('Only DataSet or list of DataSet is allowed, not {}.'.format(type(only_index_dataset))) + + for dataset in to_index_datasets: + assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) + dataset.apply(lambda ins: [self.vocab.to_index(token) for token in ins[self.field_name]], + new_field_name=self.new_added_field_name) + + def set_vocab(self, vocab): + assert isinstance(vocab, Vocabulary), "Only fastNLP.core.Vocabulary is allowed, not {}.".format(type(vocab)) + self.vocab = vocab + + def delete_vocab(self): + del self.vocab + + def get_vocab_size(self): + return len(self.vocab) class VocabProcessor(Processor): - def __init__(self, field_name, min_count=1, max_vocab_size=None): + def __init__(self, field_name, min_freq=1, max_size=None): super(VocabProcessor, self).__init__(field_name, None) - self.vocab = Vocabulary(min_freq=min_count, max_size=max_vocab_size) + self.vocab = Vocabulary(min_freq=min_freq, max_size=max_size) def process(self, *datasets): for dataset in datasets: assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) - for ins in dataset: - tokens = ins[self.field_name] - self.vocab.update(tokens) - + dataset.apply(lambda ins: self.vocab.update(ins[self.field_name])) def get_vocab(self): self.vocab.build_vocab() @@ -220,19 +335,6 @@ class VocabProcessor(Processor): return len(self.vocab) -class SeqLenProcessor(Processor): - def __init__(self, field_name, new_added_field_name='seq_lens'): - - super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) - - def process(self, dataset): - assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) - for ins in dataset: - length = len(ins[self.field_name]) - ins[self.new_added_field_name] = length - dataset._set_need_tensor(**{self.new_added_field_name:True}) - return dataset - class SegApp2OutputProcessor(Processor): def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output'): super(SegApp2OutputProcessor, self).__init__(None, None) @@ -258,7 +360,32 @@ class SegApp2OutputProcessor(Processor): class BMES2OutputProcessor(Processor): - def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output'): + """ + 按照BMES标注方式推测生成的tag。由于可能存在非法tag,比如"BS",所以需要用以下的表格做转换,cur_B意思是当前tag是B, + next_B意思是后一个tag是B。则cur_B=S,即将当前被predict是B的tag标为S;next_M=B, 即将后一个被predict是M的tag标为B + | | next_B | next_M | next_E | next_S | end | + |:-----:|:-------:|:--------:|:--------:|:-------:|:-------:| + | start | 合法 | next_M=B | next_E=S | 合法 | - | + | cur_B | cur_B=S | 合法 | 合法 | cur_B=S | cur_B=S | + | cur_M | cur_M=E | 合法 | 合法 | cur_M=E 
| cur_M=E | + | cur_E | 合法 | next_M=B | next_E=S | 合法 | 合法 | + | cur_S | 合法 | next_M=B | next_E=S | 合法 | 合法 | + 举例: + prediction为BSEMS,会被认为是SSSSS. + + """ + def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output', + b_idx = 0, m_idx = 1, e_idx = 2, s_idx = 3): + """ + + :param chars_field_name: character所对应的field + :param tag_field_name: 预测对应的field + :param new_added_field_name: 转换后的内容所在field + :param b_idx: int, Begin标签所对应的tag idx. + :param m_idx: int, Middle标签所对应的tag idx. + :param e_idx: int, End标签所对应的tag idx. + :param s_idx: int, Single标签所对应的tag idx + """ super(BMES2OutputProcessor, self).__init__(None, None) self.chars_field_name = chars_field_name @@ -266,19 +393,55 @@ class BMES2OutputProcessor(Processor): self.new_added_field_name = new_added_field_name + self.b_idx = b_idx + self.m_idx = m_idx + self.e_idx = e_idx + self.s_idx = s_idx + # 还原init处介绍的矩阵 + self._valida_matrix = { + -1: [(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1)], # magic start idx + self.b_idx:[(0, self.s_idx), (-1, -1), (-1, -1), (0, self.s_idx), (0, self.s_idx)], + self.m_idx:[(0, self.e_idx), (-1, -1), (-1, -1), (0, self.e_idx), (0, self.e_idx)], + self.e_idx:[(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)], + self.s_idx:[(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)], + } + + def _validate_tags(self, tags): + """ + 给定一个tag的List,返回合法tag + + :param tags: Tensor, shape: (seq_len, ) + :return: 返回修改为合法tag的list + """ + assert len(tags)!=0 + padded_tags = [-1, *tags, -1] + for idx in range(len(padded_tags)-1): + cur_tag = padded_tags[idx] + if cur_tag not in self._valida_matrix: + cur_tag = self.s_idx + if padded_tags[idx+1] not in self._valida_matrix: + padded_tags[idx+1] = self.s_idx + next_tag = padded_tags[idx+1] + shift_tag = self._valida_matrix[cur_tag][next_tag] + if shift_tag[0]!=-1: + padded_tags[idx+shift_tag[0]] = shift_tag[1] + + return padded_tags[1:-1] + def process(self, dataset): assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) - for ins in dataset: + def inner_proc(ins): pred_tags = ins[self.tag_field_name] + pred_tags = self._validate_tags(pred_tags) chars = ins[self.chars_field_name] words = [] start_idx = 0 for idx, tag in enumerate(pred_tags): - if tag==3: - # 当前没有考虑将原文替换回去 + if tag==self.s_idx: words.extend(chars[start_idx:idx+1]) start_idx = idx + 1 - elif tag==2: + elif tag==self.e_idx: words.append(''.join(chars[start_idx:idx+1])) start_idx = idx + 1 - ins[self.new_added_field_name] = ' '.join(words) \ No newline at end of file + return ' '.join(words) + dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name) \ No newline at end of file diff --git a/reproduction/pos_tag_model/pos_io/pos_reader.py b/reproduction/pos_tag_model/pos_io/pos_reader.py index 2ff07815..c0a8c4cd 100644 --- a/reproduction/pos_tag_model/pos_io/pos_reader.py +++ b/reproduction/pos_tag_model/pos_io/pos_reader.py @@ -24,8 +24,8 @@ def cut_long_sentence(sent, max_sample_length=200): return cutted_sentence -class ConlluPOSReader(object): - # 返回的Dataset包含words(list of list, 里层的list是character), tag两个field(list of str, str是标有BMES的tag)。 +class ConllPOSReader(object): + # 返回的Dataset包含words(list of list, 里层的list是character), tag两个field(list of str, str是标有BIO的tag)。 def __init__(self): pass @@ -70,6 +70,70 @@ class ConlluPOSReader(object): return ds + + +class ZhConllPOSReader(object): + # 中文colln格式reader + def __init__(self): + pass + + def load(self, path): + """ + 
返回的DataSet, 包含以下的field + words:list of str, + tag: list of str, 被加入了BMES tag, 比如原来的序列为['VP', 'NN', 'NN', ..],会被认为是["S-VP", "B-NN", "M-NN",..] + 假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 + 1 编者按 编者按 NN O 11 nmod:topic + 2 : : PU O 11 punct + 3 7月 7月 NT DATE 4 compound:nn + 4 12日 12日 NT DATE 11 nmod:tmod + 5 , , PU O 11 punct + + 1 这 这 DT O 3 det + 2 款 款 M O 1 mark:clf + 3 飞行 飞行 NN O 8 nsubj + 4 从 从 P O 5 case + 5 外型 外型 NN O 8 nmod:prep + """ + datalist = [] + with open(path, 'r', encoding='utf-8') as f: + sample = [] + for line in f: + if line.startswith('\n'): + datalist.append(sample) + sample = [] + elif line.startswith('#'): + continue + else: + sample.append(line.split('\t')) + if len(sample) > 0: + datalist.append(sample) + + ds = DataSet() + for sample in datalist: + # print(sample) + res = self.get_one(sample) + if res is None: + continue + char_seq = [] + pos_seq = [] + for word, tag in zip(res[0], res[1]): + char_seq.extend(list(word)) + if len(word)==1: + pos_seq.append('S-{}'.format(tag)) + elif len(word)>1: + pos_seq.append('B-{}'.format(tag)) + for _ in range(len(word)-2): + pos_seq.append('M-{}'.format(tag)) + pos_seq.append('E-{}'.format(tag)) + else: + raise ValueError("Zero length of word detected.") + + ds.append(Instance(words=char_seq, + tag=pos_seq)) + + return ds + def get_one(self, sample): if len(sample)==0: return None @@ -84,6 +148,6 @@ class ConlluPOSReader(object): return text, pos_tags if __name__ == '__main__': - reader = ConlluPOSReader() + reader = ZhConllPOSReader() d = reader.load('/home/hyan/train.conllx') - print('reader') \ No newline at end of file + print(d) \ No newline at end of file diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py index 125b9156..1dbab314 100644 --- a/test/core/test_metrics.py +++ b/test/core/test_metrics.py @@ -4,6 +4,7 @@ import numpy as np import torch from fastNLP.core.metrics import AccuracyMetric +from fastNLP.core.metrics import BMESF1PreRecMetric from fastNLP.core.metrics import pred_topk, accuracy_topk @@ -132,6 +133,234 @@ class TestAccuracyMetric(unittest.TestCase): return self.assertTrue(True, False), "No exception catches." 
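(Before the new tests, a minimal driving sketch of the span metric added above; the tag set and tensor shapes here are illustrative only. pred may be (batch_size, max_len) tag indices or (batch_size, max_len, num_tags) scores, which get argmax-ed inside evaluate().)

import torch
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.metrics import SpanFPreRecMetric

tag_vocab = Vocabulary(unknown=None, padding=None)
tag_vocab.update(['B-PER', 'I-PER', 'O'])

metric = SpanFPreRecMetric(tag_vocab=tag_vocab, encoding_type='bio')
pred = torch.randn(2, 6, len(tag_vocab))             # (batch_size, max_len, num_tags) scores
target = torch.randint(len(tag_vocab), size=(2, 6))  # gold tag indices
metric({'pred': pred, 'seq_lens': torch.LongTensor([6, 4])}, {'target': target})
print(metric.get_metric())  # {'f': ..., 'pre': ..., 'rec': ...}
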
+class SpanF1PreRecMetric(unittest.TestCase): + def test_case1(self): + from fastNLP.core.metrics import bmes_tag_to_spans + from fastNLP.core.metrics import bio_tag_to_spans + + bmes_lst = ['M-8', 'S-2', 'S-0', 'B-9', 'B-6', 'E-5', 'B-7', 'S-2', 'E-7', 'S-8'] + bio_lst = ['O-8', 'O-2', 'B-0', 'O-9', 'I-6', 'I-5', 'I-7', 'I-2', 'I-7', 'O-8'] + expect_bmes_res = set() + expect_bmes_res.update([('8', (0, 0)), ('2', (1, 1)), ('0', (2, 2)), ('9', (3, 3)), ('6', (4, 4)), + ('5', (5, 5)), ('7', (6, 6)), ('2', (7, 7)), ('7', (8, 8)), ('8', (9, 9))]) + expect_bio_res = set() + expect_bio_res.update([('7', (8, 8)), ('0', (2, 2)), ('2', (7, 7)), ('5', (5, 5)), + ('6', (4, 4)), ('7', (6, 6))]) + self.assertSetEqual(expect_bmes_res,set(bmes_tag_to_spans(bmes_lst))) + self.assertSetEqual(expect_bio_res, set(bio_tag_to_spans(bio_lst))) + # 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试 + # from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans + # from allennlp.data.dataset_readers.dataset_utils import bmes_tags_to_spans as allen_bmes_tags_to_spans + # for i in range(1000): + # strs = list(map(str, np.random.randint(100, size=1000))) + # bmes = list('bmes'.upper()) + # bmes_strs = [str_ + '-' + tag for tag, str_ in zip(strs, np.random.choice(bmes, size=len(strs)))] + # bio = list('bio'.upper()) + # bio_strs = [str_ + '-' + tag for tag, str_ in zip(strs, np.random.choice(bio, size=len(strs)))] + # self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs))) + # self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs))) + + def test_case2(self): + # 测试不带label的 + from fastNLP.core.metrics import bmes_tag_to_spans + from fastNLP.core.metrics import bio_tag_to_spans + + bmes_lst = ['B', 'E', 'B', 'S', 'B', 'M', 'E', 'M', 'B', 'E'] + bio_lst = ['I', 'B', 'O', 'O', 'I', 'O', 'I', 'B', 'O', 'O'] + expect_bmes_res = set() + expect_bmes_res.update([('', (0, 1)), ('', (2, 2)), ('', (3, 3)), ('', (4, 6)), ('', (7, 7)), ('', (8, 9))]) + expect_bio_res = set() + expect_bio_res.update([('', (7, 7)), ('', (6, 6)), ('', (4, 4)), ('', (0, 0)), ('', (1, 1))]) + self.assertSetEqual(expect_bmes_res,set(bmes_tag_to_spans(bmes_lst))) + self.assertSetEqual(expect_bio_res, set(bio_tag_to_spans(bio_lst))) + # 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试 + # from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans + # from allennlp.data.dataset_readers.dataset_utils import bmes_tags_to_spans as allen_bmes_tags_to_spans + # for i in range(1000): + # bmes = list('bmes'.upper()) + # bmes_strs = np.random.choice(bmes, size=1000) + # bio = list('bio'.upper()) + # bio_strs = np.random.choice(bio, size=100) + # self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs))) + # self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs))) + + def tese_case3(self): + from fastNLP.core.vocabulary import Vocabulary + from collections import Counter + from fastNLP.core.metrics import SpanFPreRecMetric + # 与allennlp测试能否正确计算f metric + # + def generate_allen_tags(encoding_type, number_labels=4): + vocab = {} + for i in range(number_labels): + label = str(i) + for tag in encoding_type: + if tag == 'O': + if tag not in vocab: + vocab['O'] = len(vocab) + 1 + continue + vocab['{}-{}'.format(tag, label)] = len(vocab) + 1 # 其实表达的是这个的count + return vocab + + number_labels = 4 + # bio tag + fastnlp_bio_vocab = 
Vocabulary(unknown=None, padding=None) + fastnlp_bio_vocab.word_count = Counter(generate_allen_tags('BIO', number_labels)) + fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) + bio_sequence = torch.FloatTensor( + [[[-0.9543, -1.4357, -0.2365, 0.2438, 1.0312, -1.4302, 0.3011, + 0.0470, 0.0971], + [-0.6638, -0.7116, -1.9804, 0.2787, -0.2732, -0.9501, -1.4523, + 0.7987, -0.3970], + [0.2939, 0.8132, -0.0903, -2.8296, 0.2080, -0.9823, -0.1898, + 0.6880, 1.4348], + [-0.1886, 0.0067, -0.6862, -0.4635, 2.2776, 0.0710, -1.6793, + -1.6876, -0.8917], + [-0.7663, 0.6377, 0.8669, 0.1237, 1.7628, 0.0313, -1.0824, + 1.4217, 0.2622]], + + [[0.1529, 0.7474, -0.9037, 1.5287, 0.2771, 0.2223, 0.8136, + 1.3592, -0.8973], + [0.4515, -0.5235, 0.3265, -1.1947, 0.8308, 1.8754, -0.4887, + -0.4025, -0.3417], + [-0.7855, 0.1615, -0.1272, -1.9289, -0.5181, 1.9742, -0.9698, + 0.2861, -0.3966], + [-0.8291, -0.8823, -1.1496, 0.2164, 1.3390, -0.3964, -0.5275, + 0.0213, 1.4777], + [-1.1299, 0.0627, -0.1358, -1.5951, 0.4484, -0.6081, -1.9566, + 1.3024, 0.2001]]] + ) + bio_target = torch.LongTensor([[5., 0., 3., 3., 3.], + [5., 6., 8., 6., 0.]]) + fastnlp_bio_metric({'pred': bio_sequence, 'seq_lens': torch.LongTensor([5, 5])}, {'target': bio_target}) + expect_bio_res = {'pre-1': 0.24999999999999373, 'rec-1': 0.499999999999975, 'f-1': 0.33333333333327775, + 'pre-2': 0.0, 'rec-2': 0.0, 'f-2': 0.0, 'pre-3': 0.0, 'rec-3': 0.0, 'f-3': 0.0, 'pre-0': 0.0, + 'rec-0': 0.0, 'f-0': 0.0, 'pre': 0.12499999999999845, 'rec': 0.12499999999999845, + 'f': 0.12499999999994846} + self.assertDictEqual(expect_bio_res, fastnlp_bio_metric.get_metric()) + + #bmes tag + bmes_sequence = torch.FloatTensor( + [[[0.6536, -0.7179, 0.6579, 1.2503, 0.4176, 0.6696, 0.2352, + -0.4085, 0.4084, -0.4185, 1.4172, -0.9162, -0.2679, 0.3332, + -0.3505, -0.6002], + [0.3238, -1.2378, -1.3304, -0.4903, 1.4518, -0.1868, -0.7641, + 1.6199, -0.8877, 0.1449, 0.8995, -0.5810, 0.1041, 0.1002, + 0.4439, 0.2514], + [-0.8362, 2.9526, 0.8008, 0.1193, 1.0488, 0.6670, 1.1696, + -1.1006, -0.8540, -0.1600, -0.9519, -0.2749, -0.4948, -1.4753, + 0.5802, -0.0516], + [-0.8383, -1.7292, -1.4079, -1.5023, 0.5383, 0.6653, 0.3121, + 4.1249, -0.4173, -0.2043, 1.7755, 1.1110, -1.7069, -0.0390, + -0.9242, -0.0333], + [0.9088, -0.4955, -0.5076, 0.3732, 0.0283, -0.0263, -1.0393, + 0.7734, 1.0968, 0.4132, -1.3647, -0.5762, 0.6678, 0.8809, + -0.3779, -0.3195]], + + [[-0.4638, -0.5939, -0.1052, -0.5573, 0.4600, -1.3484, 0.1753, + 0.0685, 0.3663, -0.6789, 0.0097, 1.0327, -0.0212, -0.9957, + -0.1103, 0.4417], + [-0.2903, 0.9205, -1.5758, -1.0421, 0.2921, -0.2142, -0.3049, + -0.0879, -0.4412, -1.3195, -0.0657, -0.2986, 0.7214, 0.0631, + -0.6386, 0.2797], + [0.6440, -0.3748, 1.2912, -0.0170, 0.7447, 1.4075, -0.4947, + 0.4123, -0.8447, -0.5502, 0.3520, -0.2832, 0.5019, -0.1522, + 1.1237, -1.5385], + [0.2839, -0.7649, 0.9067, -0.1163, -1.3789, 0.2571, -1.3977, + -0.3680, -0.8902, -0.6983, -1.1583, 1.2779, 0.2197, 0.1376, + -0.0591, -0.2461], + [-0.2977, -1.8564, -0.5347, 1.0011, -1.1260, 0.4252, -2.0097, + 2.6973, -0.8308, -1.4939, 0.9865, -0.3935, 0.2743, 0.1142, + -0.7344, -1.2046]]] + ) + bmes_target = torch.LongTensor([[ 9., 6., 1., 9., 15.], + [ 6., 15., 6., 15., 5.]]) + + fastnlp_bmes_vocab = Vocabulary(unknown=None, padding=None) + fastnlp_bmes_vocab.word_count = Counter(generate_allen_tags('BMES', number_labels)) + fastnlp_bmes_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bmes_vocab, only_gross=False, encoding_type='bmes') + fastnlp_bmes_metric({'pred': 
+        fastnlp_bmes_metric({'pred': bmes_sequence, 'seq_lens': torch.LongTensor([5, 5])}, {'target': bmes_target})
+
+        expect_bmes_res = {'f-3': 0.6666666666665778, 'pre-3': 0.499999999999975, 'rec-3': 0.9999999999999001,
+                           'f-0': 0.0, 'pre-0': 0.0, 'rec-0': 0.0, 'f-1': 0.33333333333327775,
+                           'pre-1': 0.24999999999999373, 'rec-1': 0.499999999999975, 'f-2': 0.7499999999999314,
+                           'pre-2': 0.7499999999999812, 'rec-2': 0.7499999999999812, 'f': 0.49999999999994504,
+                           'pre': 0.499999999999995, 'rec': 0.499999999999995}
+
+        self.assertDictEqual(fastnlp_bmes_metric.get_metric(), expect_bmes_res)
+
+        # Verified against allennlp; the code below is commented out because the tests cannot depend on allennlp.
+        # from allennlp.data.vocabulary import Vocabulary as allen_Vocabulary
+        # from allennlp.training.metrics import SpanBasedF1Measure
+        # allen_bio_vocab = allen_Vocabulary({"tags": generate_allen_tags('BIO', number_labels)},
+        #                                    non_padded_namespaces=['tags'])
+        # allen_bio_metric = SpanBasedF1Measure(allen_bio_vocab, 'tags')
+        # bio_sequence = torch.randn(size=(2, 20, 2 * number_labels + 1))
+        # bio_target = torch.randint(2 * number_labels + 1, size=(2, 20))
+        # allen_bio_metric(bio_sequence, bio_target, torch.ones(2, 20))
+        # fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None)
+        # fastnlp_bio_vocab.word_count = Counter(generate_allen_tags('BIO', number_labels))
+        # fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False)
+        #
+        # def convert_allen_res_to_fastnlp_res(metric_result):
+        #     allen_result = {}
+        #     key_map = {'f1-measure-overall': "f", "recall-overall": "rec", "precision-overall": "pre"}
+        #     for key, value in metric_result.items():
+        #         if key in key_map:
+        #             key = key_map[key]
+        #         else:
+        #             label = key.split('-')[-1]
+        #             if key.startswith('f1'):
+        #                 key = 'f-{}'.format(label)
+        #             else:
+        #                 key = '{}-{}'.format(key[:3], label)
+        #         allen_result[key] = value
+        #     return allen_result
+        #
+        # # print(convert_allen_res_to_fastnlp_res(allen_bio_metric.get_metric()))
+        # # print(fastnlp_bio_metric.get_metric())
+        # self.assertDictEqual(convert_allen_res_to_fastnlp_res(allen_bio_metric.get_metric()),
+        #                      fastnlp_bio_metric.get_metric())
+        #
+        # allen_bmes_vocab = allen_Vocabulary({"tags": generate_allen_tags('BMES', number_labels)})
+        # allen_bmes_metric = SpanBasedF1Measure(allen_bmes_vocab, 'tags', label_encoding='BMES')
+        # fastnlp_bmes_vocab = Vocabulary(unknown=None, padding=None)
+        # fastnlp_bmes_vocab.word_count = Counter(generate_allen_tags('BMES', number_labels))
+        # fastnlp_bmes_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bmes_vocab, only_gross=False, encoding_type='bmes')
+        # bmes_sequence = torch.randn(size=(2, 20, 4 * number_labels))
+        # bmes_target = torch.randint(4 * number_labels, size=(2, 20))
+        # allen_bmes_metric(bmes_sequence, bmes_target, torch.ones(2, 20))
+        # fastnlp_bmes_metric({'pred': bmes_sequence, 'seq_lens': torch.LongTensor([20, 20])}, {'target': bmes_target})
+        #
+        # # print(convert_allen_res_to_fastnlp_res(allen_bmes_metric.get_metric()))
+        # # print(fastnlp_bmes_metric.get_metric())
+        # self.assertDictEqual(convert_allen_res_to_fastnlp_res(allen_bmes_metric.get_metric()),
+        #                      fastnlp_bmes_metric.get_metric())
+
+
+class TestBMESF1PreRecMetric(unittest.TestCase):
+    def test_case1(self):
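+        # smoke test with random predictions: only checks that __call__ and get_metric() run without error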
+        seq_lens = torch.LongTensor([4, 2])
+        pred = torch.randn(2, 4, 4)
+        target = torch.LongTensor([[0, 1, 2, 3],
+                                   [3, 3, 0, 0]])
+        pred_dict = {'pred': pred}
+        target_dict = {'target': target, 'seq_lens': seq_lens}
+
+        metric = BMESF1PreRecMetric()
+        metric(pred_dict, target_dict)
+        metric.get_metric()
+
+    def test_case2(self):
+        # identical pred and target sequences should give {f1: 1, precision: 1, recall: 1}
+        seq_lens = torch.LongTensor([4, 2])
+        target = torch.LongTensor([[0, 1, 2, 3],
+                                   [3, 3, 0, 0]])
+        pred_dict = {'pred': target}
+        target_dict = {'target': target, 'seq_lens': seq_lens}
+
+        metric = BMESF1PreRecMetric()
+        metric(pred_dict, target_dict)
+        self.assertDictEqual(metric.get_metric(), {'f1': 0.999999, 'precision': 1.0, 'recall': 1.0})
+
+
 class TestUsefulFunctions(unittest.TestCase):
     # test a few helper functions in metrics.py that look broadly useful
diff --git a/test/modules/decoder/test_CRF.py b/test/modules/decoder/test_CRF.py
new file mode 100644
index 00000000..0fc331dc
--- /dev/null
+++ b/test/modules/decoder/test_CRF.py
@@ -0,0 +1,104 @@
+
+import unittest
+
+
+class TestCRF(unittest.TestCase):
+    def test_case1(self):
+        # check that allowed_transitions() produces the correct transition constraints
+        from fastNLP.modules.decoder.CRF import allowed_transitions
+
+        id2label = {0: 'B', 1: 'I', 2: 'O'}
+        expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2),
+                        (2, 4), (3, 0), (3, 2)}
+        self.assertSetEqual(expected_res, set(allowed_transitions(id2label)))
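+        # len(id2label) and len(id2label) + 1 index the implicit start and end tags: (3, 0) above allows
+        # start -> B and (0, 4) allows B -> end, while start -> I, i.e. (3, 1), is correctly absent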
+
+        id2label = {0: 'B', 1: 'M', 2: 'E', 3: 'S'}
+        expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)}
+        self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES')))
+
+        # entries with empty labels (e.g. padding/unknown placeholders) should not raise
+        id2label = {0: 'B', 1: 'I', 2: 'O', 3: '', 4: ""}
+        allowed_transitions(id2label)
+
+        labels = ['O']
+        for label in ['X', 'Y']:
+            for tag in 'BI':
+                labels.append('{}-{}'.format(tag, label))
+        id2label = {idx: label for idx, label in enumerate(labels)}
+        expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1),
+                        (2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3),
+                        (4, 4), (4, 6), (5, 0), (5, 1), (5, 3)}
+        self.assertSetEqual(expected_res, set(allowed_transitions(id2label)))
+
+        labels = []
+        for label in ['X', 'Y']:
+            for tag in 'BMES':
+                labels.append('{}-{}'.format(tag, label))
+        id2label = {idx: label for idx, label in enumerate(labels)}
+        expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4),
+                        (3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0),
+                        (7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)}
+        self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES')))
+
+    def test_case2(self):
+        # check that the CRF avoids decoding illegal transitions; verified against allennlp.
+        pass
+        # import torch
+        # from fastNLP.modules.decoder.CRF import seq_len_to_byte_mask
+        #
+        # labels = ['O']
+        # for label in ['X', 'Y']:
+        #     for tag in 'BI':
+        #         labels.append('{}-{}'.format(tag, label))
+        # id2label = {idx: label for idx, label in enumerate(labels)}
+        # num_tags = len(id2label)
+        #
+        # from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions
+        # allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BIO', id2label),
+        #                                    include_start_end_transitions=False)
+        # batch_size = 3
+        # logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, 20, num_tags))).log()
+        # trans_m = allen_CRF.transitions
+        # seq_lens = torch.randint(1, 20, size=(batch_size,))
+        # seq_lens[-1] = 20
+        # mask = seq_len_to_byte_mask(seq_lens)
+        # allen_res = allen_CRF.viterbi_tags(logits, mask)
+        #
+        # from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions
+        # fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label))
+        # fast_CRF.trans_m = trans_m
+        # fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True)
+        # # scores equal
+        # self.assertListEqual([score for _, score in allen_res], fast_res[1])
+        # # sequences equal
+        # self.assertListEqual([_ for _, score in allen_res], fast_res[0])
+        #
+        # labels = []
+        # for label in ['X', 'Y']:
+        #     for tag in 'BMES':
+        #         labels.append('{}-{}'.format(tag, label))
+        # id2label = {idx: label for idx, label in enumerate(labels)}
+        # num_tags = len(id2label)
+        #
+        # from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions
+        # allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BMES', id2label),
+        #                                    include_start_end_transitions=False)
+        # batch_size = 3
+        # logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, 20, num_tags))).log()
+        # trans_m = allen_CRF.transitions
+        # seq_lens = torch.randint(1, 20, size=(batch_size,))
+        # seq_lens[-1] = 20
+        # mask = seq_len_to_byte_mask(seq_lens)
+        # allen_res = allen_CRF.viterbi_tags(logits, mask)
+        #
+        # from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions
+        # fast_CRF = ConditionalRandomField(num_tags=num_tags,
+        #                                   allowed_transitions=allowed_transitions(id2label, encoding_type='BMES'))
+        # fast_CRF.trans_m = trans_m
+        # fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True)
+        # # scores equal
+        # self.assertListEqual([score for _, score in allen_res], fast_res[1])
+        # # sequences equal
+        # self.assertListEqual([_ for _, score in allen_res], fast_res[0])
+