2. Metrics: added SpanFPreRecMetric, which can be used to compute sequence labelling performance. 3. The Chinese word segmentation reproduction task was partially adjusted to the new API. (tags/v0.3.0^2)
@@ -9,8 +9,8 @@ from fastNLP.core.dataset import DataSet
 from fastNLP.api.model_zoo import load_url
 from fastNLP.api.processor import ModelProcessor
-from reproduction.chinese_word_segment.cws_io.cws_reader import ConlluCWSReader
-from reproduction.pos_tag_model.pos_io.pos_reader import ConlluPOSReader
+from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader
+from reproduction.pos_tag_model.pos_io.pos_reader import ConllPOSReader
 from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag
 from fastNLP.core.instance import Instance
 from fastNLP.core.sampler import SequentialSampler
@@ -95,7 +95,7 @@ class POS(API):
         pipeline.append(tag_proc)
         pp = Pipeline(pipeline)

-        reader = ConlluPOSReader()
+        reader = ConllPOSReader()
         te_dataset = reader.load(filepath)

         evaluator = SeqLabelEvaluator2('word_seq_origin_len')
@@ -168,7 +168,7 @@ class CWS(API):
         pipeline.insert(1, tag_proc)
         pp = Pipeline(pipeline)

-        reader = ConlluCWSReader()
+        reader = ConllCWSReader()

         # te_filename = '/home/hyan/ctb3/test.conllx'
         te_dataset = reader.load(filepath)
@@ -11,6 +11,11 @@ from fastNLP.core.vocabulary import Vocabulary

 class Processor(object):
     def __init__(self, field_name, new_added_field_name):
+        """
+        :param field_name: which field to process
+        :param new_added_field_name: if None, field_name is used, i.e. the original field is overwritten
+        """
         self.field_name = field_name
         if new_added_field_name is None:
             self.new_added_field_name = field_name
@@ -92,6 +97,11 @@ class FullSpaceToHalfSpaceProcessor(Processor):

 class PreAppendProcessor(Processor):
+    """
+    Prepend data (should be of type str) to the start of a field. The field must be a list, i.e. the new field is
+        [data] + instance[field_name]
+    """
     def __init__(self, data, field_name, new_added_field_name=None):
         super(PreAppendProcessor, self).__init__(field_name, new_added_field_name)
         self.data = data
@@ -102,6 +112,10 @@ class PreAppendProcessor(Processor):

 class SliceProcessor(Processor):
+    """
+    Keep only part of a field. Equivalent to instance[field_name][start:end:step]
+    """
     def __init__(self, start, end, step, field_name, new_added_field_name=None):
         super(SliceProcessor, self).__init__(field_name, new_added_field_name)
         for o in (start, end, step):
@@ -114,7 +128,17 @@ class SliceProcessor(Processor):

 class Num2TagProcessor(Processor):
+    """
+    Convert the numbers in a sentence to a given tag.
+    """
     def __init__(self, tag, field_name, new_added_field_name=None):
+        """
+        :param tag: str, the tag to convert numbers to
+        :param field_name:
+        :param new_added_field_name:
+        """
         super(Num2TagProcessor, self).__init__(field_name, new_added_field_name)
         self.tag = tag
         self.pattern = r'[-+]?([0-9]+[.]?[0-9]*)+[/eE]?[-+]?([0-9]+[.]?[0-9]*)'
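For intuition, a quick check of what that pattern matches (hypothetical snippet, not part of the diff; '<NUM>' is an arbitrary tag chosen for the demo):

```python
import re

pattern = r'[-+]?([0-9]+[.]?[0-9]*)+[/eE]?[-+]?([0-9]+[.]?[0-9]*)'
# Number-like tokens, including scientific notation, are replaced wholesale.
print(re.sub(pattern, '<NUM>', 'loss fell by 3.5e-2 between 1994 and 2018'))
# -> 'loss fell by <NUM> between <NUM> and <NUM>'
```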
@@ -135,6 +159,10 @@ class Num2TagProcessor(Processor):

 class IndexerProcessor(Processor):
+    """
+    Given a vocabulary, convert the specified field to indices. The field should be a one-dimensional list, e.g.
+        ['我', '是', xxx]
+    """
     def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False, is_input=True):
         assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))
@@ -163,19 +191,19 @@ class IndexerProcessor(Processor):

 class VocabProcessor(Processor):
-    """Build vocabulary with a field in the data set.
+    """
+    Pass in one or more DataSets to build a vocabulary.
     """
-    def __init__(self, field_name):
+    def __init__(self, field_name, min_freq=1, max_size=None):
         super(VocabProcessor, self).__init__(field_name, None)
-        self.vocab = Vocabulary()
+        self.vocab = Vocabulary(min_freq=min_freq, max_size=max_size)

     def process(self, *datasets):
         for dataset in datasets:
             assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
-            for ins in dataset:
-                self.vocab.update(ins[self.field_name])
+            dataset.apply(lambda ins: self.vocab.update(ins[self.field_name]))

     def get_vocab(self):
         self.vocab.build_vocab()
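A minimal usage sketch of the updated signature (hypothetical names; train_data and dev_data stand for any DataSets with a 'words' field):

```python
# Build one shared vocabulary over several datasets, with the new frequency/size caps.
proc = VocabProcessor('words', min_freq=2, max_size=10000)
proc.process(train_data, dev_data)  # each call updates the same internal Vocabulary
vocab = proc.get_vocab()            # triggers build_vocab() and returns the Vocabulary
```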
@@ -183,6 +211,10 @@ class VocabProcessor(Processor):

 class SeqLenProcessor(Processor):
+    """
+    Add a sequence-length field derived from another field, taking the size of that field's first dimension.
+    """
     def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True):
         super(SeqLenProcessor, self).__init__(field_name, new_added_field_name)
         self.is_input = is_input
@@ -195,10 +227,15 @@ class SeqLenProcessor(Processor):
         return dataset

+from fastNLP.core.utils import _build_args
+
 class ModelProcessor(Processor):
     def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32):
         """
-        Run the model over batches and drop the padding from its outputs.
+        Pass in a model; when process() is called with a DataSet, this processor feeds the DataSet's content to
+        model.predict or model.forward through Batch. The model's outputs are added to the DataSet; the field names
+        are determined by the model's output. If an output's shape is not (batch_size,) or (batch_size, 1), the
+        sequence-length field is used to unpad it.
+
+        TODO this class should drop its dependency on seq_lens.

         :param seq_len_field_name:
         :param batch_size:
@@ -211,13 +248,18 @@ class ModelProcessor(Processor):
     def process(self, dataset):
         self.model.eval()
         assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
-        data_iterator = Batch(dataset, batch_size=self.batch_size, sampler=SequentialSampler(), use_cuda=False)
+        data_iterator = Batch(dataset, batch_size=self.batch_size, sampler=SequentialSampler())

         batch_output = defaultdict(list)
+        if hasattr(self.model, "predict"):
+            predict_func = self.model.predict
+        else:
+            predict_func = self.model.forward
         with torch.no_grad():
             for batch_x, _ in data_iterator:
-                prediction = self.model.predict(**batch_x)
-                seq_lens = batch_x[self.seq_len_field_name].cpu().numpy().tolist()
+                refined_batch_x = _build_args(predict_func, **batch_x)
+                prediction = predict_func(**refined_batch_x)
+                seq_lens = batch_x[self.seq_len_field_name].tolist()

                 for key, value in prediction.items():
                     tmp_batch = []
@@ -246,6 +288,10 @@ class ModelProcessor(Processor):

 class Index2WordProcessor(Processor):
+    """
+    Convert a field of indices in a DataSet back to str according to the vocab.
+    """
     def __init__(self, vocab, field_name, new_added_field_name):
         super(Index2WordProcessor, self).__init__(field_name, new_added_field_name)
         self.vocab = vocab
@@ -266,5 +312,5 @@ class SetIsTargetProcessor(Processor):
     def process(self, dataset):
         set_dict = {name: self.default for name in dataset.get_all_fields().keys()}
         set_dict.update(self.field_dict)
-        dataset.set_target(**set_dict)
+        dataset.set_target(*set_dict.keys())
         return dataset
@@ -254,7 +254,7 @@ class DataSet(object):
         :return results: if new_field_name is not passed, returned values of the function over all instances.
         """
         results = [func(ins) for ins in self._inner_iter()]
-        if len(list(filter(lambda x: x is not None, results))) == 0:  # all None
+        if len(list(filter(lambda x: x is not None, results))) == 0 and not (new_field_name is None):  # all None
             raise ValueError("{} always return None.".format(get_func_signature(func=func)))

         extra_param = {}
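This relaxation is what lets the apply-based rewrites elsewhere in this diff use apply() purely for side effects. A sketch of the now-legal pattern (hypothetical dataset with a 'words' field):

```python
# Every call returns None, but since new_field_name is None this no longer raises ValueError.
vocab = Vocabulary()
dataset.apply(lambda ins: vocab.update(ins['words']))
```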
@@ -10,6 +10,7 @@ from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import _check_arg_dict_list
 from fastNLP.core.utils import get_func_signature
 from fastNLP.core.utils import seq_lens_to_masks
+from fastNLP.core.vocabulary import Vocabulary

 class MetricBase(object):
@@ -62,11 +63,6 @@ class MetricBase(object):
                     f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the "
                     f"initialization parameters, or change its signature.")

-        # evaluate should not have varargs.
-        # if func_spect.varargs:
-        #     raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.evaluate)}(Do not use "
-        #                     f"positional argument.).")

     def get_metric(self, reset=True):
         raise NotImplemented
@@ -91,10 +87,9 @@ class MetricBase(object):
         This method will call self.evaluate method.

         Before calling self.evaluate, it will first check the validity of output_dict, target_dict
-            (1) whether self.evaluate has varargs, which is not supported.
-            (2) whether params needed by self.evaluate is not included in output_dict,target_dict.
-            (3) whether params needed by self.evaluate duplicate in pred_dict, target_dict
-            (4) whether params in output_dict, target_dict are not used by evaluate.(Might cause warning)
+            (1) whether params needed by self.evaluate are missing from output_dict, target_dict
+            (2) whether params needed by self.evaluate are duplicated in pred_dict, target_dict
+            (3) whether params in output_dict, target_dict are left unused by evaluate (might cause a warning)
         Besides, before passing params into self.evaluate, this function will filter out params from output_dict and
         target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering
         will be conducted.)
@@ -275,6 +270,369 @@ class AccuracyMetric(MetricBase):
         self.total = 0
         return evaluate_result

+def bmes_tag_to_spans(tags, ignore_labels=None):
+    """
+    :param tags: List[str],
+    :param ignore_labels: List[str], labels in this list will be ignored
+    :return: List[Tuple[str, List[int, int]]]. [(label, [start, end])]
+    """
+    ignore_labels = set(ignore_labels) if ignore_labels else set()
+
+    spans = []
+    prev_bmes_tag = None
+    for idx, tag in enumerate(tags):
+        tag = tag.lower()
+        bmes_tag, label = tag[:1], tag[2:]
+        if bmes_tag in ('b', 's'):
+            spans.append((label, [idx, idx]))
+        elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]:
+            spans[-1][1][1] = idx
+        else:
+            spans.append((label, [idx, idx]))
+        prev_bmes_tag = bmes_tag
+    return [(span[0], (span[1][0], span[1][1]))
+            for span in spans
+            if span[0] not in ignore_labels
+            ]
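For intuition, a hypothetical call (not part of the diff): adjacent B/M/E tags sharing a label merge into one span, reported as (label, (start, end)) with an inclusive end index.

```python
tags = ['B-NN', 'E-NN', 'S-VV']
print(bmes_tag_to_spans(tags))
# -> [('nn', (0, 1)), ('vv', (2, 2))]  (labels are lower-cased by the function)
```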
+def bio_tag_to_spans(tags, ignore_labels=None):
+    """
+    :param tags: List[str],
+    :param ignore_labels: List[str], labels in this list will be ignored
+    :return: List[Tuple[str, List[int, int]]]. [(label, [start, end])]
+    """
+    ignore_labels = set(ignore_labels) if ignore_labels else set()
+
+    spans = []
+    prev_bio_tag = None
+    for idx, tag in enumerate(tags):
+        tag = tag.lower()
+        bio_tag, label = tag[:1], tag[2:]
+        if bio_tag == 'b':
+            spans.append((label, [idx, idx]))
+        elif bio_tag == 'i' and prev_bio_tag in ('b', 'i') and label == spans[-1][0]:
+            spans[-1][1][1] = idx
+        elif bio_tag == 'o':  # o tag does not count
+            pass
+        else:
+            spans.append((label, [idx, idx]))
+        prev_bio_tag = bio_tag
+    return [(span[0], (span[1][0], span[1][1]))
+            for span in spans
+            if span[0] not in ignore_labels
+            ]
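The BIO variant behaves analogously; 'O' positions are simply skipped (hypothetical call):

```python
tags = ['B-PER', 'I-PER', 'O', 'B-LOC']
print(bio_tag_to_spans(tags))
# -> [('per', (0, 1)), ('loc', (3, 3))]
```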
+class SpanFPreRecMetric(MetricBase):
+    """
+    Compute span-based F, precision and recall for sequence labelling tasks.
+    The resulting metric is
+        {
+            'f': xxx,  # 'f' is used here so that f_beta values can be supported later
+            'pre': xxx,
+            'rec': xxx
+        }
+    If only_gross=False, per-label statistics are returned as well:
+        {
+            'f': xxx,
+            'pre': xxx,
+            'rec': xxx,
+            'f-label': xxx,
+            'pre-label': xxx,
+            'rec-label': xxx,
+            ...
+        }
+    """
+    def __init__(self, tag_vocab, pred=None, target=None, seq_lens=None, encoding_type='bio', ignore_labels=None,
+                 only_gross=True, f_type='micro', beta=1):
+        """
+        :param tag_vocab: Vocabulary, the vocabulary of the tags. Supported tags are "B" (without a label) or "B-xxx"
+            (xxx being some label, e.g. NN in POS tagging). During decoding, tags sharing the same xxx are treated as
+            one label, e.g. ['B-NN', 'E-NN'] is merged into a single 'NN' span.
+        :param pred: str, the key used to fetch the prediction from the dict passed to evaluate(). If None, 'pred' is used.
+        :param target: str, the key used to fetch the target from the dict passed to evaluate(). If None, 'target' is used.
+        :param seq_lens: str, the key used to fetch the sequence lengths from the dict passed to evaluate(). If None,
+            'seq_lens' is used.
+        :param encoding_type: str, currently supports bio and bmes.
+        :param ignore_labels: List[str]. Classes in this list are excluded from the computation; e.g. passing ['NN']
+            in POS tagging means the 'NN' label is not evaluated.
+        :param only_gross: bool, whether to compute only the overall f1, precision and recall; if False, per-label
+            f1, pre and rec are returned in addition to the overall values.
+        :param f_type: str, 'micro' or 'macro'. 'micro': compute the overall TP, FN and FP counts first, then f,
+            precision and recall; 'macro': compute f, precision and recall per class, then average them (each class's
+            f carries the same weight).
+        :param beta: float, the beta of the f_beta score, f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). Common
+            values are beta=0.5, 1, 2. With 0.5, precision is weighted more than recall; with 1 they are weighted
+            equally; with 2, recall is weighted more than precision.
+        """
+        encoding_type = encoding_type.lower()
+        if encoding_type not in ('bio', 'bmes'):
+            raise ValueError("Only support 'bio' or 'bmes' type.")
+        if not isinstance(tag_vocab, Vocabulary):
+            raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab)))
+        if f_type not in ('micro', 'macro'):
+            raise ValueError("f_type only supports `micro` or `macro`, got {}.".format(f_type))
+
+        self.encoding_type = encoding_type
+        if self.encoding_type == 'bmes':
+            self.tag_to_span_func = bmes_tag_to_spans
+        elif self.encoding_type == 'bio':
+            self.tag_to_span_func = bio_tag_to_spans
+        self.ignore_labels = ignore_labels
+        self.f_type = f_type
+        self.beta = beta
+        self.beta_square = self.beta ** 2
+        self.only_gross = only_gross
+
+        super().__init__()
+        self._init_param_map(pred=pred, target=target, seq_lens=seq_lens)
+
+        self.tag_vocab = tag_vocab
+
+        self._true_positives = defaultdict(int)
+        self._false_positives = defaultdict(int)
+        self._false_negatives = defaultdict(int)
+    def evaluate(self, pred, target, seq_lens):
+        """
+        Much of the design comes from allennlp's span-based measures.
+        :param pred:
+        :param target:
+        :param seq_lens:
+        :return:
+        """
+        if not isinstance(pred, torch.Tensor):
+            raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+                            f"got {type(pred)}.")
+        if not isinstance(target, torch.Tensor):
+            raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+                            f"got {type(target)}.")
+        if not isinstance(seq_lens, torch.Tensor):
+            raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+                            f"got {type(seq_lens)}.")
+
+        num_classes = pred.size(-1)
+        if (target >= num_classes).any():
+            raise ValueError("A gold label passed to SpanFPreRecMetric contains an "
+                             "id >= {}, the number of classes.".format(num_classes))
+
+        if pred.size() == target.size() and len(target.size()) == 2:
+            pass
+        elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2:
+            pred = pred.argmax(dim=-1)
+        else:
+            raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have "
+                               f"size:{pred.size()}, target should have size: {pred.size()} or "
+                               f"{pred.size()[:-1]}, got {target.size()}.")
+
+        batch_size = pred.size(0)
+        for i in range(batch_size):
+            pred_tags = pred[i, :seq_lens[i]].tolist()
+            gold_tags = target[i, :seq_lens[i]].tolist()
+
+            pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags]
+            gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags]
+
+            pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels)
+            gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels)
+
+            for span in pred_spans:
+                if span in gold_spans:
+                    self._true_positives[span[0]] += 1
+                    gold_spans.remove(span)
+                else:
+                    self._false_positives[span[0]] += 1
+            for span in gold_spans:
+                self._false_negatives[span[0]] += 1
+    def get_metric(self, reset=True):
+        evaluate_result = {}
+        if not self.only_gross or self.f_type == 'macro':
+            tags = set(self._false_negatives.keys())
+            tags.update(set(self._false_positives.keys()))
+            tags.update(set(self._true_positives.keys()))
+            f_sum = 0
+            pre_sum = 0
+            rec_sum = 0
+            for tag in tags:
+                tp = self._true_positives[tag]
+                fn = self._false_negatives[tag]
+                fp = self._false_positives[tag]
+                f, pre, rec = self._compute_f_pre_rec(tp, fn, fp)
+                f_sum += f
+                pre_sum += pre
+                rec_sum += rec
+                if not self.only_gross and tag != '':  # tag != '' guards against the unlabelled-tag case
+                    f_key = 'f-{}'.format(tag)
+                    pre_key = 'pre-{}'.format(tag)
+                    rec_key = 'rec-{}'.format(tag)
+                    evaluate_result[f_key] = f
+                    evaluate_result[pre_key] = pre
+                    evaluate_result[rec_key] = rec
+
+            if self.f_type == 'macro':
+                evaluate_result['f'] = f_sum / len(tags)
+                evaluate_result['pre'] = pre_sum / len(tags)
+                evaluate_result['rec'] = rec_sum / len(tags)
+
+        if self.f_type == 'micro':
+            f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()),
+                                                  sum(self._false_negatives.values()),
+                                                  sum(self._false_positives.values()))
+            evaluate_result['f'] = f
+            evaluate_result['pre'] = pre
+            evaluate_result['rec'] = rec
+
+        if reset:
+            self._true_positives = defaultdict(int)
+            self._false_positives = defaultdict(int)
+            self._false_negatives = defaultdict(int)
+
+        return evaluate_result
+    def _compute_f_pre_rec(self, tp, fn, fp):
+        """
+        :param tp: int, true positive
+        :param fn: int, false negative
+        :param fp: int, false positive
+        :return: (f, pre, rec)
+        """
+        pre = tp / (fp + tp + 1e-13)
+        rec = tp / (fn + tp + 1e-13)
+        f = (1 + self.beta_square) * pre * rec / (self.beta_square * pre + rec + 1e-13)
+
+        return f, pre, rec
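A worked example of the formula with beta=1 (so beta_square=1); the 1e-13 terms only guard against division by zero:

```python
tp, fn, fp = 2, 1, 1
pre = tp / (fp + tp + 1e-13)                        # ~0.6667
rec = tp / (fn + tp + 1e-13)                        # ~0.6667
f = (1 + 1) * pre * rec / (1 * pre + rec + 1e-13)   # ~0.6667, the familiar F1
```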
+class BMESF1PreRecMetric(MetricBase):
+    """
+    Compute f1, precision and recall under the BMES tagging scheme. Since illegal tag sequences such as "BS" may be
+    predicted, the table below is used to repair them; cur_B means the current tag is B, next_B means the next tag is
+    B. So cur_B=S means the current tag, predicted as B, is rewritten to S; next_M=B means the next tag, predicted as
+    M, is rewritten to B.
+
+    |       | next_B  | next_M   | next_E   | next_S  | end     |
+    |:-----:|:-------:|:--------:|:--------:|:-------:|:-------:|
+    | start | legal   | next_M=B | next_E=S | legal   | -       |
+    | cur_B | cur_B=S | legal    | legal    | cur_B=S | cur_B=S |
+    | cur_M | cur_M=E | legal    | legal    | cur_M=E | cur_M=E |
+    | cur_E | legal   | next_M=B | next_E=S | legal   | legal   |
+    | cur_S | legal   | next_M=B | next_E=S | legal   | legal   |
+
+    For example, the prediction BSEMS is treated as SSSSS.
+
+    This metric does not check the validity of the target; make sure the target is legal.
+
+    pred should have shape (batch_size, max_len) or (batch_size, max_len, 4).
+    target should have shape (batch_size, max_len).
+    seq_lens should have shape (batch_size, ).
+    """
+    def __init__(self, b_idx=0, m_idx=1, e_idx=2, s_idx=3, pred=None, target=None, seq_lens=None):
+        """
+        The tag indices of the four BMES tags must be declared. Any index that is not b_idx, m_idx, e_idx or s_idx is
+        treated as s_idx.
+        :param b_idx: int, the tag idx of the Begin tag.
+        :param m_idx: int, the tag idx of the Middle tag.
+        :param e_idx: int, the tag idx of the End tag.
+        :param s_idx: int, the tag idx of the Single tag.
+        :param pred: str, the key used to fetch the prediction from the dict passed to evaluate(). If None, 'pred' is used.
+        :param target: str, the key used to fetch the target from the dict passed to evaluate(). If None, 'target' is used.
+        :param seq_lens: str, the key used to fetch the sequence lengths from the dict passed to evaluate(). If None,
+            'seq_lens' is used.
+        """
+        super().__init__()
+        self._init_param_map(pred=pred, target=target, seq_lens=seq_lens)
+
+        self.yt_wordnum = 0
+        self.yp_wordnum = 0
+        self.corr_num = 0
+
+        self.b_idx = b_idx
+        self.m_idx = m_idx
+        self.e_idx = e_idx
+        self.s_idx = s_idx
+        # the repair matrix described in the class docstring
+        self._valida_matrix = {
+            -1: [(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1)],  # magic start idx
+            self.b_idx: [(0, self.s_idx), (-1, -1), (-1, -1), (0, self.s_idx), (0, self.s_idx)],
+            self.m_idx: [(0, self.e_idx), (-1, -1), (-1, -1), (0, self.e_idx), (0, self.e_idx)],
+            self.e_idx: [(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)],
+            self.s_idx: [(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)],
+        }
+    def _validate_tags(self, tags):
+        """
+        Given a Tensor of tags, return the legal (repaired) tags.
+        :param tags: Tensor, shape: (seq_len, )
+        :return: a list of tags repaired to be legal
+        """
+        assert len(tags) != 0
+        assert isinstance(tags, torch.Tensor) and len(tags.size()) == 1
+        padded_tags = [-1, *tags.tolist(), -1]
+        for idx in range(len(padded_tags) - 1):
+            cur_tag = padded_tags[idx]
+            if cur_tag not in self._valida_matrix:
+                cur_tag = self.s_idx
+            if padded_tags[idx + 1] not in self._valida_matrix:
+                padded_tags[idx + 1] = self.s_idx
+            next_tag = padded_tags[idx + 1]
+            shift_tag = self._valida_matrix[cur_tag][next_tag]
+            if shift_tag[0] != -1:
+                padded_tags[idx + shift_tag[0]] = shift_tag[1]
+
+        return padded_tags[1:-1]
+    def evaluate(self, pred, target, seq_lens):
+        if not isinstance(pred, torch.Tensor):
+            raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+                            f"got {type(pred)}.")
+        if not isinstance(target, torch.Tensor):
+            raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+                            f"got {type(target)}.")
+        if not isinstance(seq_lens, torch.Tensor):
+            raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+                            f"got {type(seq_lens)}.")
+
+        if pred.size() == target.size() and len(target.size()) == 2:
+            pass
+        elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2:
+            pred = pred.argmax(dim=-1)
+        else:
+            raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have "
+                               f"size:{pred.size()}, target should have size: {pred.size()} or "
+                               f"{pred.size()[:-1]}, got {target.size()}.")
+
+        for idx in range(len(pred)):
+            seq_len = seq_lens[idx]
+            target_tags = target[idx][:seq_len].tolist()
+            pred_tags = pred[idx][:seq_len]
+            pred_tags = self._validate_tags(pred_tags)
+            start_idx = 0
+            for t_idx, (t_tag, p_tag) in enumerate(zip(target_tags, pred_tags)):
+                if t_tag in (self.s_idx, self.e_idx):
+                    self.yt_wordnum += 1
+                    corr_flag = True
+                    for i in range(start_idx, t_idx + 1):
+                        if target_tags[i] != pred_tags[i]:
+                            corr_flag = False
+                    if corr_flag:
+                        self.corr_num += 1
+                    start_idx = t_idx + 1
+                if p_tag in (self.s_idx, self.e_idx):
+                    self.yp_wordnum += 1
+
+    def get_metric(self, reset=True):
+        P = self.corr_num / (self.yp_wordnum + 1e-12)
+        R = self.corr_num / (self.yt_wordnum + 1e-12)
+        F = 2 * P * R / (P + R + 1e-12)
+        evaluate_result = {'f': round(F, 6), 'pre': round(P, 6), 'rec': round(R, 6)}
+        if reset:
+            self.yp_wordnum = 0
+            self.yt_wordnum = 0
+            self.corr_num = 0
+        return evaluate_result
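A hypothetical spot-check of the repair table (not part of the diff): with the default mapping B=0, M=1, E=2, S=3, the illegal prediction "B S" is repaired to "S S" (cur_B followed by next_S triggers cur_B=S):

```python
import torch

metric = BMESF1PreRecMetric()                        # defaults: b_idx=0, m_idx=1, e_idx=2, s_idx=3
print(metric._validate_tags(torch.tensor([0, 3])))   # -> [3, 3], i.e. "SS"
```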
 def _prepare_metrics(metrics):
     """
@@ -31,9 +31,8 @@ class Trainer(object):
     """

     def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
-                 validate_every=-1, dev_data=None, use_cuda=False, save_path=None,
-                 optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0,
-                 metric_key=None, sampler=RandomSampler(), use_tqdm=True):
+                 validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0),
+                 check_code_level=0, metric_key=None, sampler=RandomSampler(), use_tqdm=True, use_cuda=False):
         """
         :param DataSet train_data: the training data
@@ -19,26 +19,149 @@ def seq_len_to_byte_mask(seq_lens):
     mask = broadcast_arange.lt(seq_lens.float().view(-1, 1))
     return mask

+def allowed_transitions(id2label, encoding_type='bio'):
+    """
+    :param id2label: dict. Keys are label indices, values are str tags or tag-labels. A value may be a bare tag such
+        as "B" or "M", or a tag-label such as "B-NN" or "M-NN", where the tag and the label must be separated by "-".
+        id2label can usually be obtained via Vocabulary.get_id2word().
+    :param encoding_type: str, supports "bio" and "bmes".
+    :return: List[Tuple(int, int)], where each inner tuple is (from_tag_id, to_tag_id). The result takes start and
+        end into account: in "BIO", for instance, B and O may start a sequence but I may not, so the result contains
+        (start_idx, B_idx) and (start_idx, O_idx) but not (start_idx, I_idx).
+        start_idx=len(id2label), end_idx=len(id2label)+1.
+    """
+    num_tags = len(id2label)
+    start_idx = num_tags
+    end_idx = num_tags + 1
+    encoding_type = encoding_type.lower()
+    allowed_trans = []
+    id_label_lst = list(id2label.items()) + [(start_idx, 'start'), (end_idx, 'end')]
+
+    def split_tag_label(from_label):
+        from_label = from_label.lower()
+        if from_label in ['start', 'end']:
+            from_tag = from_label
+            from_label = ''
+        else:
+            from_tag = from_label[:1]
+            from_label = from_label[2:]
+        return from_tag, from_label
+
+    for from_id, from_label in id_label_lst:
+        if from_label in ['<pad>', '<unk>']:
+            continue
+        from_tag, from_label = split_tag_label(from_label)
+        for to_id, to_label in id_label_lst:
+            if to_label in ['<pad>', '<unk>']:
+                continue
+            to_tag, to_label = split_tag_label(to_label)
+            if is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label):
+                allowed_trans.append((from_id, to_id))
+    return allowed_trans
+def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label):
+    """
+    :param encoding_type: str, supports "BIO" and "BMES".
+    :param from_tag: str, a tag such as "B" or "M", plus the two special tags start and end.
+    :param from_label: str, a label such as "PER" or "LOC".
+    :param to_tag: str, a tag such as "B" or "M", plus the two special tags start and end.
+    :param to_label: str, a label such as "PER" or "LOC".
+    :return: bool, whether the transition is allowed.
+    """
+    if to_tag == 'start' or from_tag == 'end':
+        return False
+    encoding_type = encoding_type.lower()
+    if encoding_type == 'bio':
+        """
+        The first row is to_tag, the first column is from_tag. y: always allowed; -: allowed only when the labels
+        match; n: not allowed.
+        +-------+---+---+---+-------+-----+
+        |       | B | I | O | start | end |
+        +-------+---+---+---+-------+-----+
+        |   B   | y | - | y |   n   |  y  |
+        +-------+---+---+---+-------+-----+
+        |   I   | y | - | y |   n   |  y  |
+        +-------+---+---+---+-------+-----+
+        |   O   | y | n | y |   n   |  y  |
+        +-------+---+---+---+-------+-----+
+        | start | y | n | y |   n   |  n  |
+        +-------+---+---+---+-------+-----+
+        |  end  | n | n | n |   n   |  n  |
+        +-------+---+---+---+-------+-----+
+        """
+        if from_tag == 'start':
+            return to_tag in ('b', 'o')
+        elif from_tag in ['b', 'i']:
+            return any([to_tag in ['end', 'b', 'o'], to_tag == 'i' and from_label == to_label])
+        elif from_tag == 'o':
+            return to_tag in ['end', 'b', 'o']
+        else:
+            raise ValueError("Unexpected tag {}. Expect only 'B', 'I', 'O'.".format(from_tag))
+
+    elif encoding_type == 'bmes':
+        """
+        The first row is to_tag, the first column is from_tag. y: always allowed; -: allowed only when the labels
+        match; n: not allowed.
+        +-------+---+---+---+---+-------+-----+
+        |       | B | M | E | S | start | end |
+        +-------+---+---+---+---+-------+-----+
+        |   B   | n | - | - | n |   n   |  n  |
+        +-------+---+---+---+---+-------+-----+
+        |   M   | n | - | - | n |   n   |  n  |
+        +-------+---+---+---+---+-------+-----+
+        |   E   | y | n | n | y |   n   |  y  |
+        +-------+---+---+---+---+-------+-----+
+        |   S   | y | n | n | y |   n   |  y  |
+        +-------+---+---+---+---+-------+-----+
+        | start | y | n | n | y |   n   |  n  |
+        +-------+---+---+---+---+-------+-----+
+        |  end  | n | n | n | n |   n   |  n  |
+        +-------+---+---+---+---+-------+-----+
+        """
+        if from_tag == 'start':
+            return to_tag in ['b', 's']
+        elif from_tag == 'b':
+            return to_tag in ['m', 'e'] and from_label == to_label
+        elif from_tag == 'm':
+            return to_tag in ['m', 'e'] and from_label == to_label
+        elif from_tag in ['e', 's']:
+            return to_tag in ['b', 's', 'end']
+        else:
+            raise ValueError("Unexpected tag type {}. Expect only 'B', 'M', 'E', 'S'.".format(from_tag))
+    else:
+        raise ValueError("Only support BIO, BMES encoding type, got {}.".format(encoding_type))
 class ConditionalRandomField(nn.Module):
-    def __init__(self, tag_size, include_start_end_trans=False, initial_method=None):
+    def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None, initial_method=None):
         """
-        :param tag_size: int, num of tags
-        :param include_start_end_trans: bool, whether to include start/end tag
+        :param num_tags: int, the number of tags.
+        :param include_start_end_trans: bool, whether to include start/end transition scores.
+        :param allowed_transitions: List[Tuple[from_tag_id(int), to_tag_id(int)]], the allowed transitions, which
+            can be obtained via allowed_transitions(). If None, all transitions are legal.
+        :param initial_method:
         """
         super(ConditionalRandomField, self).__init__()

         self.include_start_end_trans = include_start_end_trans
-        self.tag_size = tag_size
+        self.num_tags = num_tags

         # the meaning of entry in this matrix is (from_tag_id, to_tag_id) score
-        self.trans_m = nn.Parameter(torch.randn(tag_size, tag_size))
+        self.trans_m = nn.Parameter(torch.randn(num_tags, num_tags))
         if self.include_start_end_trans:
-            self.start_scores = nn.Parameter(torch.randn(tag_size))
-            self.end_scores = nn.Parameter(torch.randn(tag_size))
+            self.start_scores = nn.Parameter(torch.randn(num_tags))
+            self.end_scores = nn.Parameter(torch.randn(num_tags))
+
+        if allowed_transitions is None:
+            constrain = torch.zeros(num_tags + 2, num_tags + 2)
+        else:
+            constrain = torch.ones(num_tags + 2, num_tags + 2) * -1000
+            for from_tag_id, to_tag_id in allowed_transitions:
+                constrain[from_tag_id, to_tag_id] = 0
+        self._constrain = nn.Parameter(constrain, requires_grad=False)

         # self.reset_parameter()
         initial_parameter(self, initial_method)
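This is exactly how the CWS model later in this diff wires the two pieces together; as a standalone sketch:

```python
# A BMES-constrained CRF over 4 tags; illegal transitions get a -1000 score in the
# (num_tags+2) x (num_tags+2) constraint matrix used by viterbi_decode.
trans = allowed_transitions({0: 'b', 1: 'm', 2: 'e', 3: 's'}, encoding_type='bmes')
crf = ConditionalRandomField(num_tags=4, include_start_end_trans=False, allowed_transitions=trans)
```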
     def reset_parameter(self):
         nn.init.xavier_normal_(self.trans_m)
         if self.include_start_end_trans:
@@ -49,7 +172,7 @@ class ConditionalRandomField(nn.Module):
         """
         Computes the (batch_size,) denominator term for the log-likelihood, which is the
         sum of the likelihoods across all possible state sequences.
-        :param logits: FloatTensor, max_len x batch_size x tag_size
+        :param logits: FloatTensor, max_len x batch_size x num_tags
         :param mask: ByteTensor, max_len x batch_size
         :return: FloatTensor, batch_size
         """
@@ -72,7 +195,7 @@ class ConditionalRandomField(nn.Module):
     def _glod_score(self, logits, tags, mask):
         """
         Compute the score for the gold path.
-        :param logits: FloatTensor, max_len x batch_size x tag_size
+        :param logits: FloatTensor, max_len x batch_size x num_tags
         :param tags: LongTensor, max_len x batch_size
         :param mask: ByteTensor, max_len x batch_size
         :return: FloatTensor, batch_size
@@ -99,7 +222,7 @@ class ConditionalRandomField(nn.Module):
     def forward(self, feats, tags, mask):
         """
         Calculate the neg log likelihood
-        :param feats: FloatTensor, batch_size x max_len x tag_size
+        :param feats: FloatTensor, batch_size x max_len x num_tags
         :param tags: LongTensor, batch_size x max_len
         :param mask: ByteTensor, batch_size x max_len
         :return: FloatTensor, batch_size
@@ -112,13 +235,20 @@ class ConditionalRandomField(nn.Module):
         return all_path_score - gold_path_score

-    def viterbi_decode(self, data, mask, get_score=False):
+    def viterbi_decode(self, data, mask, get_score=False, unpad=False):
         """
         Given a feats matrix, return best decode path and best score.
-        :param data: FloatTensor, batch_size x max_len x tag_size
+        :param data: FloatTensor, batch_size x max_len x num_tags
         :param mask: ByteTensor, batch_size x max_len
         :param get_score: bool, whether to output the decode score.
-        :return: scores, paths
+        :param unpad: bool, whether to unpad the result.
+            If False, a batch_size x max_len tensor is returned.
+            If True, a List[List[int]] is returned, each List[int] holding one sequence's labels, already unpadded,
+            i.e. each List[int] has that sample's true length.
+        :return: if get_score is False, the result follows unpad as described above.
+            If get_score is True, (paths, List[float]) is returned: the first element is still the decoded paths
+            (depending on unpad), the second is the decoding score of each sequence.
         """
         batch_size, seq_len, n_tags = data.size()
         data = data.transpose(0, 1).data  # L, B, H
@@ -127,19 +257,23 @@ class ConditionalRandomField(nn.Module):
         # dp
         vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
         vscore = data[0]
+        transitions = self._constrain.data.clone()
+        transitions[:n_tags, :n_tags] += self.trans_m.data
         if self.include_start_end_trans:
-            vscore += self.start_scores.view(1, -1)
+            transitions[n_tags, :n_tags] += self.start_scores.data
+            transitions[:n_tags, n_tags + 1] += self.end_scores.data
+
+        vscore += transitions[n_tags, :n_tags]
+        trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data

         for i in range(1, seq_len):
             prev_score = vscore.view(batch_size, n_tags, 1)
             cur_score = data[i].view(batch_size, 1, n_tags)
-            trans_score = self.trans_m.view(1, n_tags, n_tags).data
             score = prev_score + trans_score + cur_score
             best_score, best_dst = score.max(1)
             vpath[i] = best_dst
             vscore = best_score * mask[i].view(batch_size, 1) + vscore * (1 - mask[i]).view(batch_size, 1)

-        if self.include_start_end_trans:
-            vscore += self.end_scores.view(1, -1)
+        vscore += transitions[:n_tags, n_tags + 1].view(1, -1)

         # backtrace
         batch_idx = torch.arange(batch_size, dtype=torch.long, device=data.device)
@@ -154,7 +288,13 @@ class ConditionalRandomField(nn.Module):
         for i in range(seq_len - 1):
             last_tags = vpath[idxes[i], batch_idx, last_tags]
             ans[idxes[i + 1], batch_idx] = last_tags
+        ans = ans.transpose(0, 1)
+        if unpad:
+            paths = []
+            for idx, seq_len in enumerate(lens):
+                paths.append(ans[idx, :seq_len + 1].tolist())
+        else:
+            paths = ans
         if get_score:
-            return ans_score, ans.transpose(0, 1)
-        return ans.transpose(0, 1)
+            return paths, ans_score.tolist()
+        return paths
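Continuing the earlier sketch, a hypothetical decode call showing the new unpad behaviour (shapes follow the docstring):

```python
import torch

feats = torch.randn(2, 5, 4)                     # batch_size x max_len x num_tags
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]], dtype=torch.uint8)
paths = crf.viterbi_decode(feats, mask, unpad=True)
# paths is a List[List[int]]: len(paths[0]) == 3 and len(paths[1]) == 5
```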
@@ -6,6 +6,13 @@ from fastNLP.io.dataset_loader import DataSetLoader

 def cut_long_sentence(sent, max_sample_length=200):
+    """
+    Cut a sentence longer than max_sample_length into several pieces. Cuts only happen at spaces, so the resulting
+    pieces may be longer or shorter than max_sample_length.
+    :param sent: str.
+    :param max_sample_length: int.
+    :return: list of str.
+    """
     sent_no_space = sent.replace(' ', '')
     cutted_sentence = []
     if len(sent_no_space) > max_sample_length:
@@ -127,12 +134,26 @@ class POSCWSReader(DataSetLoader):
         return dataset

-class ConlluCWSReader(object):
-    # The returned Dataset contains two fields: words (a list of lists whose inner lists are characters) and tag (a list of str, the str being BMES tags).
+class ConllCWSReader(object):
     def __init__(self):
         pass

     def load(self, path, cut_long_sent=False):
+        """
+        The returned DataSet contains only the raw_sentence field, whose content is str.
+        The input is assumed to be in conll format: sentences are separated by blank lines and each line has 7
+        columns, e.g.
+            1   编者按  编者按  NN   O     11  nmod:topic
+            2   :      :      PU   O     11  punct
+            3   7月    7月    NT   DATE  4   compound:nn
+            4   12日   12日   NT   DATE  11  nmod:tmod
+            5   ,      ,      PU   O     11  punct
+
+            1   这     这     DT   O     3   det
+            2   款     款     M    O     1   mark:clf
+            3   飞行   飞行   NN   O     8   nsubj
+            4   从     从     P    O     5   case
+            5   外型   外型   NN   O     8   nmod:prep
+        """
         datalist = []
         with open(path, 'r', encoding='utf-8') as f:
             sample = []
@@ -150,10 +171,10 @@ class ConlluCWSReader(object):
         ds = DataSet()
         for sample in datalist:
             # print(sample)
-            res = self.get_one(sample)
+            res = self.get_char_lst(sample)
             if res is None:
                 continue
             line = ' '.join(res)
             if cut_long_sent:
                 sents = cut_long_sentence(line)
             else:
@@ -163,7 +184,7 @@ class ConlluCWSReader(object):
         return ds

-    def get_one(self, sample):
+    def get_char_lst(self, sample):
         if len(sample) == 0:
             return None
         text = []
@@ -9,7 +9,7 @@ from reproduction.chinese_word_segment.utils import seq_lens_to_mask

 class CWSBiLSTMEncoder(BaseModel):
     def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None,
-                 hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1):
+                 hidden_size=200, bidirectional=True, embed_drop_p=0.2, num_layers=1):
         super().__init__()

         self.input_size = 0
@@ -68,6 +68,7 @@ class CWSBiLSTMEncoder(BaseModel):
         if not bigrams is None:
             bigram_tensor = self.bigram_embedding(bigrams).view(batch_size, max_len, -1)
             x_tensor = torch.cat([x_tensor, bigram_tensor], dim=2)
+        x_tensor = self.embedding_drop(x_tensor)

         sorted_lens, sorted_indices = torch.sort(seq_lens, descending=True)
         packed_x = nn.utils.rnn.pack_padded_sequence(x_tensor[sorted_indices], sorted_lens, batch_first=True)
@@ -120,10 +121,24 @@ class CWSBiLSTMSegApp(BaseModel):

 from fastNLP.modules.decoder.CRF import ConditionalRandomField
+from fastNLP.modules.decoder.CRF import allowed_transitions

 class CWSBiLSTMCRF(BaseModel):
     def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None,
-                 hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1, tag_size=4):
+                 hidden_size=200, bidirectional=True, embed_drop_p=0.2, num_layers=1, tag_size=4):
+        """
+        BMES tagging is used by default.
+        :param vocab_num:
+        :param embed_dim:
+        :param bigram_vocab_num:
+        :param bigram_embed_dim:
+        :param num_bigram_per_char:
+        :param hidden_size:
+        :param bidirectional:
+        :param embed_drop_p:
+        :param num_layers:
+        :param tag_size:
+        """
         super(CWSBiLSTMCRF, self).__init__()

         self.tag_size = tag_size
@@ -133,10 +148,12 @@ class CWSBiLSTMCRF(BaseModel):
         size_layer = [hidden_size, 200, tag_size]
         self.decoder_model = MLP(size_layer)
-        self.crf = ConditionalRandomField(tag_size=tag_size, include_start_end_trans=False)
+        allowed_trans = allowed_transitions({0: 'b', 1: 'm', 2: 'e', 3: 's'}, encoding_type='bmes')
+        self.crf = ConditionalRandomField(num_tags=tag_size, include_start_end_trans=False,
+                                          allowed_transitions=allowed_trans)

-    def forward(self, chars, tags, seq_lens, bigrams=None):
+    def forward(self, chars, target, seq_lens, bigrams=None):
         device = self.parameters().__next__().device
         chars = chars.to(device).long()
         if not bigrams is None:
@@ -147,7 +164,7 @@ class CWSBiLSTMCRF(BaseModel):
         masks = seq_lens_to_mask(seq_lens)
         feats = self.encoder_model(chars, bigrams, seq_lens)
         feats = self.decoder_model(feats)
-        losses = self.crf(feats, tags, masks)
+        losses = self.crf(feats, target, masks)

         pred_dict = {}
         pred_dict['seq_lens'] = seq_lens
@@ -168,5 +185,5 @@ class CWSBiLSTMCRF(BaseModel):
         feats = self.decoder_model(feats)
         probs = self.crf.viterbi_decode(feats, masks, get_score=False)

-        return {'pred_tags': probs}
+        return {'pred': probs}
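These key renames are deliberate: forward() now reads target and predict() now emits 'pred', matching the default keys of the metrics added above (pred='pred', target='target', seq_lens='seq_lens'). A hedged sketch of the intended wiring (hypothetical call; most Trainer arguments omitted):

```python
# With the keys aligned, the metric needs no pred/target remapping.
trainer = Trainer(train_data, model, metrics=[BMESF1PreRecMetric()], dev_data=dev_data)
```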
@@ -2,7 +2,6 @@

 import re

-from fastNLP.core.field import SeqLabelField
 from fastNLP.core.vocabulary import Vocabulary
 from fastNLP.core.dataset import DataSet
 from fastNLP.api.processor import Processor
@@ -11,7 +10,10 @@ from reproduction.chinese_word_segment.process.span_converter import SpanConverter

 _SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>'

 class SpeicalSpanProcessor(Processor):
-    # This class converts the special spans in a sentence into their corresponding content.
+    """
+    Replace the field_name field of the DataSet using the span_converters.
+    """

     def __init__(self, field_name, new_added_field_name=None):
         super(SpeicalSpanProcessor, self).__init__(field_name, new_added_field_name)
@@ -20,11 +22,12 @@ class SpeicalSpanProcessor(Processor):

     def process(self, dataset):
         assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
-        for ins in dataset:
+        def inner_proc(ins):
             sentence = ins[self.field_name]
             for span_converter in self.span_converters:
                 sentence = span_converter.find_certain_span_and_replace(sentence)
-            ins[self.new_added_field_name] = sentence
+            return sentence
+        dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name)

         return dataset
@@ -34,17 +37,22 @@ class SpeicalSpanProcessor(Processor):
         self.span_converters.append(converter)

 class CWSCharSegProcessor(Processor):
+    """
+    Split the field_name field of the DataSet into individual characters, e.g. "复旦大学 fudan" becomes
+    ['复', '旦', '大', '学', ' ', 'f', 'u', ...]
+    """
     def __init__(self, field_name, new_added_field_name):
         super(CWSCharSegProcessor, self).__init__(field_name, new_added_field_name)

     def process(self, dataset):
         assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
-        for ins in dataset:
+        def inner_proc(ins):
             sentence = ins[self.field_name]
             chars = self._split_sent_into_chars(sentence)
-            ins[self.new_added_field_name] = chars
+            return chars
+        dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name)

         return dataset
@@ -73,6 +81,10 @@ class CWSCharSegProcessor(Processor):

 class CWSTagProcessor(Processor):
+    """
+    Generate tags for word segmentation. This class is a base class.
+    """
     def __init__(self, field_name, new_added_field_name=None):
         super(CWSTagProcessor, self).__init__(field_name, new_added_field_name)

@@ -107,18 +119,22 @@ class CWSTagProcessor(Processor):
     def process(self, dataset):
         assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
-        for ins in dataset:
+        def inner_proc(ins):
             sentence = ins[self.field_name]
             tag_list = self._generate_tag(sentence)
-            ins[self.new_added_field_name] = tag_list
-        dataset.set_target(**{self.new_added_field_name: True})
-        dataset._set_need_tensor(**{self.new_added_field_name: True})
+            return tag_list
+        dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name)
+        dataset.set_target(self.new_added_field_name)
         return dataset

     def _tags_from_word_len(self, word_len):
         raise NotImplementedError

 class CWSBMESTagProcessor(CWSTagProcessor):
+    """
+    Generate the corresponding BMES tags from the field_name field of the DataSet.
+    """
     def __init__(self, field_name, new_added_field_name=None):
         super(CWSBMESTagProcessor, self).__init__(field_name, new_added_field_name)

@@ -137,6 +153,10 @@ class CWSBMESTagProcessor(CWSTagProcessor):
         return tag_list

 class CWSSegAppTagProcessor(CWSTagProcessor):
+    """
+    Generate the corresponding SegApp tags from the field_name field of the DataSet.
+    """
     def __init__(self, field_name, new_added_field_name=None):
         super(CWSSegAppTagProcessor, self).__init__(field_name, new_added_field_name)
@@ -151,6 +171,10 @@ class CWSSegAppTagProcessor(CWSTagProcessor): | |||||
class BigramProcessor(Processor): | class BigramProcessor(Processor): | ||||
""" | |||||
这是生成bigram的基类。 | |||||
""" | |||||
def __init__(self, field_name, new_added_fielf_name=None): | def __init__(self, field_name, new_added_fielf_name=None): | ||||
super(BigramProcessor, self).__init__(field_name, new_added_fielf_name) | super(BigramProcessor, self).__init__(field_name, new_added_fielf_name) | ||||
@@ -158,22 +182,31 @@ class BigramProcessor(Processor): | |||||
def process(self, dataset): | def process(self, dataset): | ||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | ||||
for ins in dataset: | |||||
def inner_proc(ins): | |||||
characters = ins[self.field_name] | characters = ins[self.field_name] | ||||
bigrams = self._generate_bigram(characters) | bigrams = self._generate_bigram(characters) | ||||
ins[self.new_added_field_name] = bigrams | |||||
return bigrams | |||||
dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name) | |||||
return dataset | return dataset | ||||
def _generate_bigram(self, characters): | def _generate_bigram(self, characters): | ||||
pass | pass | ||||
class Pre2Post2BigramProcessor(BigramProcessor): | class Pre2Post2BigramProcessor(BigramProcessor): | ||||
def __init__(self, field_name, new_added_fielf_name=None): | |||||
""" | |||||
该bigram processor生成bigram的方式如下 | |||||
原汉字list为l = ['a', 'b', 'c'],会被padding为L=['SOS', 'SOS', 'a', 'b', 'c', 'EOS', 'EOS'],生成bigram list为 | |||||
[L[idx-2], L[idx-1], L[idx+1], L[idx+2], L[idx-2]L[idx], L[idx-1]L[idx], L[idx]L[idx+1], L[idx]L[idx+2], ....] | |||||
即每个汉字,会有八个bigram, 对于上例中'a'的bigram为 | |||||
['SOS', 'SOS', 'b', 'c', 'SOSa', 'SOSa', 'ab', 'ac'] | |||||
返回的bigram是一个list,但其实每8个元素是一个汉字的bigram信息。 | |||||
""" | |||||
def __init__(self, field_name, new_added_field_name=None): | |||||
super(BigramProcessor, self).__init__(field_name, new_added_fielf_name) | |||||
super(BigramProcessor, self).__init__(field_name, new_added_field_name) | |||||
def _generate_bigram(self, characters): | def _generate_bigram(self, characters): | ||||
bigrams = [] | bigrams = [] | ||||
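The rest of the body is cut off by the hunk boundary; a sketch consistent with the docstring above, producing eight features per character with the 'SOS'/'EOS' padding it describes:

    def _generate_bigram(self, characters):
        padded = ['SOS', 'SOS'] + list(characters) + ['EOS', 'EOS']
        bigrams = []
        for idx in range(2, len(padded) - 2):   # idx walks the real characters
            cur = padded[idx]
            bigrams.extend([
                padded[idx - 2], padded[idx - 1],              # two preceding symbols
                padded[idx + 1], padded[idx + 2],              # two following symbols
                padded[idx - 2] + cur, padded[idx - 1] + cur,  # left bigrams
                cur + padded[idx + 1], cur + padded[idx + 2],  # right bigrams
            ])
        return bigrams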
@@ -197,20 +230,102 @@ class Pre2Post2BigramProcessor(BigramProcessor): | |||||
# A vocabulary needs to be built at this point, but there is a problem: | # A vocabulary needs to be built at this point, but there is a problem: | ||||
# (1) with the Processor approach, what gets returned here is not a dataset, so vocabulary building is implemented | # (1) with the Processor approach, what gets returned here is not a dataset, so vocabulary building is implemented | ||||
# differently instead of reusing Processor | # differently instead of reusing Processor | ||||
# TODO how can the vocab-building and indexing steps be unified? | |||||
class VocabIndexerProcessor(Processor): | |||||
""" | |||||
Build a Vocabulary from a DataSet and use it to index the field numerically. The generated index field is stored under | |||||
new_added_field_name; if new_added_field_name is not provided, the original field_name is overwritten. | |||||
""" | |||||
def __init__(self, field_name, new_added_field_name=None, min_freq=1, max_size=None, | |||||
verbose=1): | |||||
""" | |||||
:param field_name: the field to build the vocabulary from, and the field to index. | |||||
:param new_added_field_name: name of the generated index field; if not given, field_name is overwritten. | |||||
:param min_freq: minimum number of occurrences a token needs to enter the Vocabulary. | |||||
:param max_size: maximum number of tokens the Vocabulary may hold. | |||||
:param verbose: 0, print nothing; 1, print progress information. | |||||
""" | |||||
super(VocabIndexerProcessor, self).__init__(field_name, new_added_field_name) | |||||
self.min_freq = min_freq | |||||
self.max_size = max_size | |||||
self.verbose = verbose | |||||
def construct_vocab(self, *datasets): | |||||
""" | |||||
Build the vocabulary from the DataSets passed in. | |||||
:param datasets: DataSet instances used to build the vocabulary. | |||||
:return: | |||||
""" | |||||
self.vocab = Vocabulary(min_freq=self.min_freq, max_size=self.max_size) | |||||
for dataset in datasets: | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
dataset.apply(lambda ins: self.vocab.update(ins[self.field_name])) | |||||
self.vocab.build_vocab() | |||||
if self.verbose: | |||||
print("Vocabulary Constructed, has {} items.".format(len(self.vocab))) | |||||
def process(self, *datasets, only_index_dataset=None): | |||||
""" | |||||
If no Vocabulary has been built yet, build one from the DataSets in `datasets`; otherwise reuse the existing vocabulary. | |||||
Once the vocabulary is available, both `datasets` and `only_index_dataset` are indexed. | |||||
:param datasets: DataSet instances. | |||||
:param only_index_dataset: DataSet, or list of DataSet. These are only indexed; they are not used to build the vocabulary. | |||||
:return: | |||||
""" | |||||
if len(datasets) == 0 and not hasattr(self, 'vocab'): | |||||
raise RuntimeError("No vocabulary has been built yet; pass in datasets so one can be built first.") | |||||
if not hasattr(self, 'vocab'): | |||||
self.construct_vocab(*datasets) | |||||
else: | |||||
if self.verbose: | |||||
print("Using constructed vocabulary with {} items.".format(len(self.vocab))) | |||||
to_index_datasets = [] | |||||
if len(datasets)!=0: | |||||
for dataset in datasets: | |||||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||||
to_index_datasets.append(dataset) | |||||
if only_index_dataset is not None: | |||||
if isinstance(only_index_dataset, list): | |||||
for dataset in only_index_dataset: | |||||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||||
to_index_datasets.append(dataset) | |||||
elif isinstance(only_index_dataset, DataSet): | |||||
to_index_datasets.append(only_index_dataset) | |||||
else: | |||||
raise TypeError('Only DataSet or list of DataSet is allowed, not {}.'.format(type(only_index_dataset))) | |||||
for dataset in to_index_datasets: | |||||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||||
dataset.apply(lambda ins: [self.vocab.to_index(token) for token in ins[self.field_name]], | |||||
new_field_name=self.new_added_field_name) | |||||
def set_vocab(self, vocab): | |||||
assert isinstance(vocab, Vocabulary), "Only fastNLP.core.Vocabulary is allowed, not {}.".format(type(vocab)) | |||||
self.vocab = vocab | |||||
def delete_vocab(self): | |||||
del self.vocab | |||||
def get_vocab_size(self): | |||||
return len(self.vocab) | |||||
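A minimal usage sketch (the datasets and field values are invented for illustration): the vocabulary is built from the training set only, while the test set is merely indexed:

    from fastNLP.core.dataset import DataSet
    from fastNLP.core.instance import Instance

    train = DataSet()
    train.append(Instance(words=['我', '爱', '自然', '语言']))
    test = DataSet()
    test.append(Instance(words=['语言', '爱', '我']))

    proc = VocabIndexerProcessor('words', 'word_ids')
    proc.process(train, only_index_dataset=test)  # builds the vocab from train, then indexes both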
class VocabProcessor(Processor): | class VocabProcessor(Processor): | ||||
def __init__(self, field_name, min_count=1, max_vocab_size=None): | |||||
def __init__(self, field_name, min_freq=1, max_size=None): | |||||
super(VocabProcessor, self).__init__(field_name, None) | super(VocabProcessor, self).__init__(field_name, None) | ||||
self.vocab = Vocabulary(min_freq=min_count, max_size=max_vocab_size) | |||||
self.vocab = Vocabulary(min_freq=min_freq, max_size=max_size) | |||||
def process(self, *datasets): | def process(self, *datasets): | ||||
for dataset in datasets: | for dataset in datasets: | ||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | ||||
for ins in dataset: | |||||
tokens = ins[self.field_name] | |||||
self.vocab.update(tokens) | |||||
dataset.apply(lambda ins: self.vocab.update(ins[self.field_name])) | |||||
def get_vocab(self): | def get_vocab(self): | ||||
self.vocab.build_vocab() | self.vocab.build_vocab() | ||||
@@ -220,19 +335,6 @@ class VocabProcessor(Processor): | |||||
return len(self.vocab) | return len(self.vocab) | ||||
class SeqLenProcessor(Processor): | |||||
def __init__(self, field_name, new_added_field_name='seq_lens'): | |||||
super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
length = len(ins[self.field_name]) | |||||
ins[self.new_added_field_name] = length | |||||
dataset._set_need_tensor(**{self.new_added_field_name:True}) | |||||
return dataset | |||||
class SegApp2OutputProcessor(Processor): | class SegApp2OutputProcessor(Processor): | ||||
def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output'): | def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output'): | ||||
super(SegApp2OutputProcessor, self).__init__(None, None) | super(SegApp2OutputProcessor, self).__init__(None, None) | ||||
@@ -258,7 +360,32 @@ class SegApp2OutputProcessor(Processor): | |||||
class BMES2OutputProcessor(Processor): | class BMES2OutputProcessor(Processor): | ||||
def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output'): | |||||
""" | |||||
Recover the output from tags predicted under the BMES scheme. Since illegal tag pairs such as "BS" may occur, the | |||||
table below is used to repair them. cur_B means the current tag is B, and next_B means the following tag is B; | |||||
cur_B=S means a current tag predicted as B is rewritten to S, and next_M=B means a following tag predicted as M is rewritten to B. | |||||
|       | next_B  | next_M   | next_E   | next_S  | end     | | |||||
|:-----:|:-------:|:--------:|:--------:|:-------:|:-------:| | |||||
| start | legal   | next_M=B | next_E=S | legal   | -       | | |||||
| cur_B | cur_B=S | legal    | legal    | cur_B=S | cur_B=S | | |||||
| cur_M | cur_M=E | legal    | legal    | cur_M=E | cur_M=E | | |||||
| cur_E | legal   | next_M=B | next_E=S | legal   | legal   | | |||||
| cur_S | legal   | next_M=B | next_E=S | legal   | legal   | | |||||
For example: | |||||
the prediction BSEMS is treated as SSSSS. | |||||
""" | |||||
def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output', | |||||
b_idx = 0, m_idx = 1, e_idx = 2, s_idx = 3): | |||||
""" | |||||
:param chars_field_name: field holding the characters. | |||||
:param tag_field_name: field holding the predicted tags. | |||||
:param new_added_field_name: field in which the converted output is stored. | |||||
:param b_idx: int, tag index of the Begin label. | |||||
:param m_idx: int, tag index of the Middle label. | |||||
:param e_idx: int, tag index of the End label. | |||||
:param s_idx: int, tag index of the Single label. | |||||
""" | |||||
super(BMES2OutputProcessor, self).__init__(None, None) | super(BMES2OutputProcessor, self).__init__(None, None) | ||||
self.chars_field_name = chars_field_name | self.chars_field_name = chars_field_name | ||||
@@ -266,19 +393,55 @@ class BMES2OutputProcessor(Processor): | |||||
self.new_added_field_name = new_added_field_name | self.new_added_field_name = new_added_field_name | ||||
self.b_idx = b_idx | |||||
self.m_idx = m_idx | |||||
self.e_idx = e_idx | |||||
self.s_idx = s_idx | |||||
# materialize the repair table described in the docstring above | |||||
self._valid_matrix = { | |||||
-1: [(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1)],  # magic start idx | |||||
self.b_idx:[(0, self.s_idx), (-1, -1), (-1, -1), (0, self.s_idx), (0, self.s_idx)], | |||||
self.m_idx:[(0, self.e_idx), (-1, -1), (-1, -1), (0, self.e_idx), (0, self.e_idx)], | |||||
self.e_idx:[(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)], | |||||
self.s_idx:[(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)], | |||||
} | |||||
def _validate_tags(self, tags): | |||||
""" | |||||
Given a list of tags, return a legal tag list. | |||||
:param tags: Tensor, shape: (seq_len, ) | |||||
:return: the list with illegal tags rewritten to legal ones | |||||
""" | |||||
assert len(tags)!=0 | |||||
padded_tags = [-1, *tags, -1] | |||||
for idx in range(len(padded_tags)-1): | |||||
cur_tag = padded_tags[idx] | |||||
if cur_tag not in self._valid_matrix: | |||||
cur_tag = self.s_idx | |||||
if padded_tags[idx+1] not in self._valid_matrix: | |||||
padded_tags[idx+1] = self.s_idx | |||||
next_tag = padded_tags[idx+1] | |||||
shift_tag = self._valid_matrix[cur_tag][next_tag] | |||||
if shift_tag[0]!=-1: | |||||
padded_tags[idx+shift_tag[0]] = shift_tag[1] | |||||
return padded_tags[1:-1] | |||||
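To make the repair table concrete, a quick check of the docstring's own example, with tags encoded under the default b/m/e/s = 0/1/2/3:

    proc = BMES2OutputProcessor()
    # 'BSEMS' encodes to [0, 3, 2, 1, 3]; the repairs cascade left to right:
    assert proc._validate_tags([0, 3, 2, 1, 3]) == [3, 3, 3, 3, 3]  # i.e. 'SSSSS'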
def process(self, dataset): | def process(self, dataset): | ||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | ||||
for ins in dataset: | |||||
def inner_proc(ins): | |||||
pred_tags = ins[self.tag_field_name] | pred_tags = ins[self.tag_field_name] | ||||
pred_tags = self._validate_tags(pred_tags) | |||||
chars = ins[self.chars_field_name] | chars = ins[self.chars_field_name] | ||||
words = [] | words = [] | ||||
start_idx = 0 | start_idx = 0 | ||||
for idx, tag in enumerate(pred_tags): | for idx, tag in enumerate(pred_tags): | ||||
if tag==3: | |||||
# substituting the original text back in is currently not handled | |||||
if tag==self.s_idx: | |||||
words.extend(chars[start_idx:idx+1]) | words.extend(chars[start_idx:idx+1]) | ||||
start_idx = idx + 1 | start_idx = idx + 1 | ||||
elif tag==2: | |||||
elif tag==self.e_idx: | |||||
words.append(''.join(chars[start_idx:idx+1])) | words.append(''.join(chars[start_idx:idx+1])) | ||||
start_idx = idx + 1 | start_idx = idx + 1 | ||||
ins[self.new_added_field_name] = ' '.join(words) | |||||
return ' '.join(words) | |||||
dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name) |
@@ -24,8 +24,8 @@ def cut_long_sentence(sent, max_sample_length=200): | |||||
return cutted_sentence | return cutted_sentence | ||||
class ConlluPOSReader(object): | |||||
# The returned Dataset has two fields: words (list of list, the inner lists hold characters) and tag (list of str, tags in BMES form). | |||||
class ConllPOSReader(object): | |||||
# The returned Dataset has two fields: words (list of list, the inner lists hold characters) and tag (list of str, tags in BIO form). | |||||
def __init__(self): | def __init__(self): | ||||
pass | pass | ||||
@@ -70,6 +70,70 @@ class ConlluPOSReader(object): | |||||
return ds | return ds | ||||
class ZhConllPOSReader(object): | |||||
# reader for Chinese data in conll format | |||||
def __init__(self): | |||||
pass | |||||
def load(self, path): | |||||
""" | |||||
The returned DataSet contains the following fields: | |||||
words: list of str, | |||||
tag: list of str, with BMES tags attached; e.g. an original sequence ['VP', 'NN', 'NN', ..] becomes ["S-VP", "B-NN", "M-NN", ..] | |||||
The input is assumed to be in conll format: sentences are separated by blank lines and each line has 7 tab-separated columns, e.g. | |||||
1 编者按 编者按 NN O 11 nmod:topic | |||||
2 : : PU O 11 punct | |||||
3 7月 7月 NT DATE 4 compound:nn | |||||
4 12日 12日 NT DATE 11 nmod:tmod | |||||
5 , , PU O 11 punct | |||||
1 这 这 DT O 3 det | |||||
2 款 款 M O 1 mark:clf | |||||
3 飞行 飞行 NN O 8 nsubj | |||||
4 从 从 P O 5 case | |||||
5 外型 外型 NN O 8 nmod:prep | |||||
""" | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
ds = DataSet() | |||||
for sample in datalist: | |||||
# print(sample) | |||||
res = self.get_one(sample) | |||||
if res is None: | |||||
continue | |||||
char_seq = [] | |||||
pos_seq = [] | |||||
for word, tag in zip(res[0], res[1]): | |||||
char_seq.extend(list(word)) | |||||
if len(word)==1: | |||||
pos_seq.append('S-{}'.format(tag)) | |||||
elif len(word)>1: | |||||
pos_seq.append('B-{}'.format(tag)) | |||||
for _ in range(len(word)-2): | |||||
pos_seq.append('M-{}'.format(tag)) | |||||
pos_seq.append('E-{}'.format(tag)) | |||||
else: | |||||
raise ValueError("Zero length of word detected.") | |||||
ds.append(Instance(words=char_seq, | |||||
tag=pos_seq)) | |||||
return ds | |||||
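The word-to-character expansion above is easiest to see on a single pair; with an invented (word, tag) of ('编者按', 'NN'):

    # len('编者按') == 3, so the branches above produce
    #   char_seq -> ['编', '者', '按']
    #   pos_seq  -> ['B-NN', 'M-NN', 'E-NN']
    # while a single-character word such as (',', 'PU') yields ['S-PU'].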
def get_one(self, sample): | def get_one(self, sample): | ||||
if len(sample)==0: | if len(sample)==0: | ||||
return None | return None | ||||
@@ -84,6 +148,6 @@ class ConlluPOSReader(object): | |||||
return text, pos_tags | return text, pos_tags | ||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
reader = ConlluPOSReader() | |||||
reader = ZhConllPOSReader() | |||||
d = reader.load('/home/hyan/train.conllx') | d = reader.load('/home/hyan/train.conllx') | ||||
print('reader') | |||||
print(d) |
@@ -4,6 +4,7 @@ import numpy as np | |||||
import torch | import torch | ||||
from fastNLP.core.metrics import AccuracyMetric | from fastNLP.core.metrics import AccuracyMetric | ||||
from fastNLP.core.metrics import BMESF1PreRecMetric | |||||
from fastNLP.core.metrics import pred_topk, accuracy_topk | from fastNLP.core.metrics import pred_topk, accuracy_topk | ||||
@@ -132,6 +133,234 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
return | return | ||||
self.fail("No exception caught.") | self.fail("No exception caught.") | ||||
class TestSpanFPreRecMetric(unittest.TestCase): | |||||
def test_case1(self): | |||||
from fastNLP.core.metrics import bmes_tag_to_spans | |||||
from fastNLP.core.metrics import bio_tag_to_spans | |||||
bmes_lst = ['M-8', 'S-2', 'S-0', 'B-9', 'B-6', 'E-5', 'B-7', 'S-2', 'E-7', 'S-8'] | |||||
bio_lst = ['O-8', 'O-2', 'B-0', 'O-9', 'I-6', 'I-5', 'I-7', 'I-2', 'I-7', 'O-8'] | |||||
expect_bmes_res = set() | |||||
expect_bmes_res.update([('8', (0, 0)), ('2', (1, 1)), ('0', (2, 2)), ('9', (3, 3)), ('6', (4, 4)), | |||||
('5', (5, 5)), ('7', (6, 6)), ('2', (7, 7)), ('7', (8, 8)), ('8', (9, 9))]) | |||||
expect_bio_res = set() | |||||
expect_bio_res.update([('7', (8, 8)), ('0', (2, 2)), ('2', (7, 7)), ('5', (5, 5)), | |||||
('6', (4, 4)), ('7', (6, 6))]) | |||||
self.assertSetEqual(expect_bmes_res,set(bmes_tag_to_spans(bmes_lst))) | |||||
self.assertSetEqual(expect_bio_res, set(bio_tag_to_spans(bio_lst))) | |||||
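# Reading the expectations: bmes_tag_to_spans is lenient about illegal
# prefixes, so even the dangling 'M-8' at position 0 yields the span
# ('8', (0, 0)); bio_tag_to_spans opens a new span whenever an I- tag does
# not continue a same-label span, and 'O-*' positions yield no spans.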
# Verified against the corresponding allennlp functions; since the tests must not depend on allennlp, only the fixed example above is kept. | |||||
# from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans | |||||
# from allennlp.data.dataset_readers.dataset_utils import bmes_tags_to_spans as allen_bmes_tags_to_spans | |||||
# for i in range(1000): | |||||
# strs = list(map(str, np.random.randint(100, size=1000))) | |||||
# bmes = list('bmes'.upper()) | |||||
# bmes_strs = [str_ + '-' + tag for tag, str_ in zip(strs, np.random.choice(bmes, size=len(strs)))] | |||||
# bio = list('bio'.upper()) | |||||
# bio_strs = [str_ + '-' + tag for tag, str_ in zip(strs, np.random.choice(bio, size=len(strs)))] | |||||
# self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs))) | |||||
# self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs))) | |||||
def test_case2(self): | |||||
# test tags that carry no label | |||||
from fastNLP.core.metrics import bmes_tag_to_spans | |||||
from fastNLP.core.metrics import bio_tag_to_spans | |||||
bmes_lst = ['B', 'E', 'B', 'S', 'B', 'M', 'E', 'M', 'B', 'E'] | |||||
bio_lst = ['I', 'B', 'O', 'O', 'I', 'O', 'I', 'B', 'O', 'O'] | |||||
expect_bmes_res = set() | |||||
expect_bmes_res.update([('', (0, 1)), ('', (2, 2)), ('', (3, 3)), ('', (4, 6)), ('', (7, 7)), ('', (8, 9))]) | |||||
expect_bio_res = set() | |||||
expect_bio_res.update([('', (7, 7)), ('', (6, 6)), ('', (4, 4)), ('', (0, 0)), ('', (1, 1))]) | |||||
self.assertSetEqual(expect_bmes_res,set(bmes_tag_to_spans(bmes_lst))) | |||||
self.assertSetEqual(expect_bio_res, set(bio_tag_to_spans(bio_lst))) | |||||
# Verified against the corresponding allennlp functions; since the tests must not depend on allennlp, only the fixed example above is kept. | |||||
# from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans | |||||
# from allennlp.data.dataset_readers.dataset_utils import bmes_tags_to_spans as allen_bmes_tags_to_spans | |||||
# for i in range(1000): | |||||
# bmes = list('bmes'.upper()) | |||||
# bmes_strs = np.random.choice(bmes, size=1000) | |||||
# bio = list('bio'.upper()) | |||||
# bio_strs = np.random.choice(bio, size=100) | |||||
# self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs))) | |||||
# self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs))) | |||||
def test_case3(self): | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from collections import Counter | |||||
from fastNLP.core.metrics import SpanFPreRecMetric | |||||
# check against allennlp that the f metric is computed correctly | |||||
# | |||||
def generate_allen_tags(encoding_type, number_labels=4): | |||||
vocab = {} | |||||
for i in range(number_labels): | |||||
label = str(i) | |||||
for tag in encoding_type: | |||||
if tag == 'O': | |||||
if tag not in vocab: | |||||
vocab['O'] = len(vocab) + 1 | |||||
continue | |||||
vocab['{}-{}'.format(tag, label)] = len(vocab) + 1  # the value actually serves as this tag's count | |||||
return vocab | |||||
number_labels = 4 | |||||
# bio tag | |||||
fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) | |||||
fastnlp_bio_vocab.word_count = Counter(generate_allen_tags('BIO', number_labels)) | |||||
fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) | |||||
bio_sequence = torch.FloatTensor( | |||||
[[[-0.9543, -1.4357, -0.2365, 0.2438, 1.0312, -1.4302, 0.3011, | |||||
0.0470, 0.0971], | |||||
[-0.6638, -0.7116, -1.9804, 0.2787, -0.2732, -0.9501, -1.4523, | |||||
0.7987, -0.3970], | |||||
[0.2939, 0.8132, -0.0903, -2.8296, 0.2080, -0.9823, -0.1898, | |||||
0.6880, 1.4348], | |||||
[-0.1886, 0.0067, -0.6862, -0.4635, 2.2776, 0.0710, -1.6793, | |||||
-1.6876, -0.8917], | |||||
[-0.7663, 0.6377, 0.8669, 0.1237, 1.7628, 0.0313, -1.0824, | |||||
1.4217, 0.2622]], | |||||
[[0.1529, 0.7474, -0.9037, 1.5287, 0.2771, 0.2223, 0.8136, | |||||
1.3592, -0.8973], | |||||
[0.4515, -0.5235, 0.3265, -1.1947, 0.8308, 1.8754, -0.4887, | |||||
-0.4025, -0.3417], | |||||
[-0.7855, 0.1615, -0.1272, -1.9289, -0.5181, 1.9742, -0.9698, | |||||
0.2861, -0.3966], | |||||
[-0.8291, -0.8823, -1.1496, 0.2164, 1.3390, -0.3964, -0.5275, | |||||
0.0213, 1.4777], | |||||
[-1.1299, 0.0627, -0.1358, -1.5951, 0.4484, -0.6081, -1.9566, | |||||
1.3024, 0.2001]]] | |||||
) | |||||
bio_target = torch.LongTensor([[5, 0, 3, 3, 3], | |||||
[5, 6, 8, 6, 0]]) | |||||
fastnlp_bio_metric({'pred': bio_sequence, 'seq_lens': torch.LongTensor([5, 5])}, {'target': bio_target}) | |||||
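# With only_gross=False the metric reports per-label scores keyed
# 'pre-<label>', 'rec-<label>', 'f-<label>', plus overall 'pre'/'rec'/'f'.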
expect_bio_res = {'pre-1': 0.24999999999999373, 'rec-1': 0.499999999999975, 'f-1': 0.33333333333327775, | |||||
'pre-2': 0.0, 'rec-2': 0.0, 'f-2': 0.0, 'pre-3': 0.0, 'rec-3': 0.0, 'f-3': 0.0, 'pre-0': 0.0, | |||||
'rec-0': 0.0, 'f-0': 0.0, 'pre': 0.12499999999999845, 'rec': 0.12499999999999845, | |||||
'f': 0.12499999999994846} | |||||
self.assertDictEqual(expect_bio_res, fastnlp_bio_metric.get_metric()) | |||||
# bmes tag | |||||
bmes_sequence = torch.FloatTensor( | |||||
[[[0.6536, -0.7179, 0.6579, 1.2503, 0.4176, 0.6696, 0.2352, | |||||
-0.4085, 0.4084, -0.4185, 1.4172, -0.9162, -0.2679, 0.3332, | |||||
-0.3505, -0.6002], | |||||
[0.3238, -1.2378, -1.3304, -0.4903, 1.4518, -0.1868, -0.7641, | |||||
1.6199, -0.8877, 0.1449, 0.8995, -0.5810, 0.1041, 0.1002, | |||||
0.4439, 0.2514], | |||||
[-0.8362, 2.9526, 0.8008, 0.1193, 1.0488, 0.6670, 1.1696, | |||||
-1.1006, -0.8540, -0.1600, -0.9519, -0.2749, -0.4948, -1.4753, | |||||
0.5802, -0.0516], | |||||
[-0.8383, -1.7292, -1.4079, -1.5023, 0.5383, 0.6653, 0.3121, | |||||
4.1249, -0.4173, -0.2043, 1.7755, 1.1110, -1.7069, -0.0390, | |||||
-0.9242, -0.0333], | |||||
[0.9088, -0.4955, -0.5076, 0.3732, 0.0283, -0.0263, -1.0393, | |||||
0.7734, 1.0968, 0.4132, -1.3647, -0.5762, 0.6678, 0.8809, | |||||
-0.3779, -0.3195]], | |||||
[[-0.4638, -0.5939, -0.1052, -0.5573, 0.4600, -1.3484, 0.1753, | |||||
0.0685, 0.3663, -0.6789, 0.0097, 1.0327, -0.0212, -0.9957, | |||||
-0.1103, 0.4417], | |||||
[-0.2903, 0.9205, -1.5758, -1.0421, 0.2921, -0.2142, -0.3049, | |||||
-0.0879, -0.4412, -1.3195, -0.0657, -0.2986, 0.7214, 0.0631, | |||||
-0.6386, 0.2797], | |||||
[0.6440, -0.3748, 1.2912, -0.0170, 0.7447, 1.4075, -0.4947, | |||||
0.4123, -0.8447, -0.5502, 0.3520, -0.2832, 0.5019, -0.1522, | |||||
1.1237, -1.5385], | |||||
[0.2839, -0.7649, 0.9067, -0.1163, -1.3789, 0.2571, -1.3977, | |||||
-0.3680, -0.8902, -0.6983, -1.1583, 1.2779, 0.2197, 0.1376, | |||||
-0.0591, -0.2461], | |||||
[-0.2977, -1.8564, -0.5347, 1.0011, -1.1260, 0.4252, -2.0097, | |||||
2.6973, -0.8308, -1.4939, 0.9865, -0.3935, 0.2743, 0.1142, | |||||
-0.7344, -1.2046]]] | |||||
) | |||||
bmes_target = torch.LongTensor([[9, 6, 1, 9, 15], | |||||
[6, 15, 6, 15, 5]]) | |||||
fastnlp_bmes_vocab = Vocabulary(unknown=None, padding=None) | |||||
fastnlp_bmes_vocab.word_count = Counter(generate_allen_tags('BMES', number_labels)) | |||||
fastnlp_bmes_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bmes_vocab, only_gross=False, encoding_type='bmes') | |||||
fastnlp_bmes_metric({'pred': bmes_sequence, 'seq_lens': torch.LongTensor([20, 20])}, {'target': bmes_target}) | |||||
expect_bmes_res = {'f-3': 0.6666666666665778, 'pre-3': 0.499999999999975, 'rec-3': 0.9999999999999001, | |||||
'f-0': 0.0, 'pre-0': 0.0, 'rec-0': 0.0, 'f-1': 0.33333333333327775, | |||||
'pre-1': 0.24999999999999373, 'rec-1': 0.499999999999975, 'f-2': 0.7499999999999314, | |||||
'pre-2': 0.7499999999999812, 'rec-2': 0.7499999999999812, 'f': 0.49999999999994504, | |||||
'pre': 0.499999999999995, 'rec': 0.499999999999995} | |||||
self.assertDictEqual(fastnlp_bmes_metric.get_metric(), expect_bmes_res) | |||||
# Verified against allennlp; since the test must not depend on allennlp, the code below is commented out. | |||||
# from allennlp.data.vocabulary import Vocabulary as allen_Vocabulary | |||||
# from allennlp.training.metrics import SpanBasedF1Measure | |||||
# allen_bio_vocab = allen_Vocabulary({"tags": generate_allen_tags('BIO', number_labels)}, | |||||
# non_padded_namespaces=['tags']) | |||||
# allen_bio_metric = SpanBasedF1Measure(allen_bio_vocab, 'tags') | |||||
# bio_sequence = torch.randn(size=(2, 20, 2 * number_labels + 1)) | |||||
# bio_target = torch.randint(2 * number_labels + 1, size=(2, 20)) | |||||
# allen_bio_metric(bio_sequence, bio_target, torch.ones(2, 20)) | |||||
# fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) | |||||
# fastnlp_bio_vocab.word_count = Counter(generate_allen_tags('BIO', number_labels)) | |||||
# fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) | |||||
# | |||||
# def convert_allen_res_to_fastnlp_res(metric_result): | |||||
# allen_result = {} | |||||
# key_map = {'f1-measure-overall': "f", "recall-overall": "rec", "precision-overall": "pre"} | |||||
# for key, value in metric_result.items(): | |||||
# if key in key_map: | |||||
# key = key_map[key] | |||||
# else: | |||||
# label = key.split('-')[-1] | |||||
# if key.startswith('f1'): | |||||
# key = 'f-{}'.format(label) | |||||
# else: | |||||
# key = '{}-{}'.format(key[:3], label) | |||||
# allen_result[key] = value | |||||
# return allen_result | |||||
# | |||||
# # print(convert_allen_res_to_fastnlp_res(allen_bio_metric.get_metric())) | |||||
# # print(fastnlp_bio_metric.get_metric()) | |||||
# self.assertDictEqual(convert_allen_res_to_fastnlp_res(allen_bio_metric.get_metric()), | |||||
# fastnlp_bio_metric.get_metric()) | |||||
# | |||||
# allen_bmes_vocab = allen_Vocabulary({"tags": generate_allen_tags('BMES', number_labels)}) | |||||
# allen_bmes_metric = SpanBasedF1Measure(allen_bmes_vocab, 'tags', label_encoding='BMES') | |||||
# fastnlp_bmes_vocab = Vocabulary(unknown=None, padding=None) | |||||
# fastnlp_bmes_vocab.word_count = Counter(generate_allen_tags('BMES', number_labels)) | |||||
# fastnlp_bmes_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bmes_vocab, only_gross=False, encoding_type='bmes') | |||||
# bmes_sequence = torch.randn(size=(2, 20, 4 * number_labels)) | |||||
# bmes_target = torch.randint(4 * number_labels, size=(2, 20)) | |||||
# allen_bmes_metric(bmes_sequence, bmes_target, torch.ones(2, 20)) | |||||
# fastnlp_bmes_metric({'pred': bmes_sequence, 'seq_lens': torch.LongTensor([20, 20])}, {'target': bmes_target}) | |||||
# | |||||
# # print(convert_allen_res_to_fastnlp_res(allen_bmes_metric.get_metric())) | |||||
# # print(fastnlp_bmes_metric.get_metric()) | |||||
# self.assertDictEqual(convert_allen_res_to_fastnlp_res(allen_bmes_metric.get_metric()), | |||||
# fastnlp_bmes_metric.get_metric()) | |||||
class TestBMESF1PreRecMetric(unittest.TestCase): | |||||
def test_case1(self): | |||||
seq_lens = torch.LongTensor([4, 2]) | |||||
pred = torch.randn(2, 4, 4) | |||||
target = torch.LongTensor([[0, 1, 2, 3], | |||||
[3, 3, 0, 0]]) | |||||
pred_dict = {'pred': pred} | |||||
target_dict = {'target': target, 'seq_lens': seq_lens} | |||||
metric = BMESF1PreRecMetric() | |||||
metric(pred_dict, target_dict) | |||||
metric.get_metric() | |||||
def test_case2(self): | |||||
# with two identical sequences the metric should report {f1: 1, precision: 1, recall: 1} | |||||
seq_lens = torch.LongTensor([4, 2]) | |||||
target = torch.LongTensor([[0, 1, 2, 3], | |||||
[3, 3, 0, 0]]) | |||||
pred_dict = {'pred': target} | |||||
target_dict = {'target': target, 'seq_lens': seq_lens} | |||||
metric = BMESF1PreRecMetric() | |||||
metric(pred_dict, target_dict) | |||||
self.assertDictEqual(metric.get_metric(), {'f1': 0.999999, 'precision': 1.0, 'recall': 1.0}) | |||||
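# f1 comes out as 0.999999 rather than exactly 1.0, presumably because the
# implementation adds a small epsilon to the f1 denominator to avoid
# division by zero.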
class TestUsefulFunctions(unittest.TestCase): | class TestUsefulFunctions(unittest.TestCase): | ||||
# test some useful-looking helper functions in metrics.py | # test some useful-looking helper functions in metrics.py | ||||
@@ -0,0 +1,104 @@ | |||||
import unittest | |||||
class TestCRF(unittest.TestCase): | |||||
def test_case1(self): | |||||
# check that allowed_transitions() works correctly | |||||
from fastNLP.modules.decoder.CRF import allowed_transitions | |||||
id2label = {0: 'B', 1: 'I', 2:'O'} | |||||
expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), | |||||
(2, 4), (3, 0), (3, 2)} | |||||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label))) | |||||
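# Indices beyond the three labels are the implicit boundary states: here 3
# is start and 4 is end, so (3, 0) reads "a sequence may begin with B" and
# (1, 4) reads "a sequence may end with I".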
id2label = {0: 'B', 1:'M', 2:'E', 3:'S'} | |||||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} | |||||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES'))) | |||||
id2label = {0: 'B', 1: 'I', 2:'O', 3: '<pad>', 4:"<unk>"} | |||||
allowed_transitions(id2label) | |||||
labels = ['O'] | |||||
for label in ['X', 'Y']: | |||||
for tag in 'BI': | |||||
labels.append('{}-{}'.format(tag, label)) | |||||
id2label = {idx:label for idx, label in enumerate(labels)} | |||||
expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1), | |||||
(2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3), | |||||
(4, 4), (4, 6), (5, 0), (5, 1), (5, 3)} | |||||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label))) | |||||
labels = [] | |||||
for label in ['X', 'Y']: | |||||
for tag in 'BMES': | |||||
labels.append('{}-{}'.format(tag, label)) | |||||
id2label = {idx: label for idx, label in enumerate(labels)} | |||||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), | |||||
(3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), | |||||
(7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} | |||||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES'))) | |||||
def test_case2(self): | |||||
# test that the CRF avoids decoding illegal transitions; verified against allennlp. | |||||
pass | |||||
# import torch | |||||
# from fastNLP.modules.decoder.CRF import seq_len_to_byte_mask | |||||
# | |||||
# labels = ['O'] | |||||
# for label in ['X', 'Y']: | |||||
# for tag in 'BI': | |||||
# labels.append('{}-{}'.format(tag, label)) | |||||
# id2label = {idx: label for idx, label in enumerate(labels)} | |||||
# num_tags = len(id2label) | |||||
# | |||||
# from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions | |||||
# allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BIO', id2label), | |||||
# include_start_end_transitions=False) | |||||
# batch_size = 3 | |||||
# logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, 20, num_tags))).log() | |||||
# trans_m = allen_CRF.transitions | |||||
# seq_lens = torch.randint(1, 20, size=(batch_size,)) | |||||
# seq_lens[-1] = 20 | |||||
# mask = seq_len_to_byte_mask(seq_lens) | |||||
# allen_res = allen_CRF.viterbi_tags(logits, mask) | |||||
# | |||||
# from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions | |||||
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label)) | |||||
# fast_CRF.trans_m = trans_m | |||||
# fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True) | |||||
# # score equal | |||||
# self.assertListEqual([score for _, score in allen_res], fast_res[1]) | |||||
# # seq equal | |||||
# self.assertListEqual([_ for _, score in allen_res], fast_res[0]) | |||||
# | |||||
# | |||||
# labels = [] | |||||
# for label in ['X', 'Y']: | |||||
# for tag in 'BMES': | |||||
# labels.append('{}-{}'.format(tag, label)) | |||||
# id2label = {idx: label for idx, label in enumerate(labels)} | |||||
# num_tags = len(id2label) | |||||
# | |||||
# from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions | |||||
# allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BMES', id2label), | |||||
# include_start_end_transitions=False) | |||||
# batch_size = 3 | |||||
# logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, 20, num_tags))).log() | |||||
# trans_m = allen_CRF.transitions | |||||
# seq_lens = torch.randint(1, 20, size=(batch_size,)) | |||||
# seq_lens[-1] = 20 | |||||
# mask = seq_len_to_byte_mask(seq_lens) | |||||
# allen_res = allen_CRF.viterbi_tags(logits, mask) | |||||
# | |||||
# from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions | |||||
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, | |||||
# encoding_type='BMES')) | |||||
# fast_CRF.trans_m = trans_m | |||||
# fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True) | |||||
# # score equal | |||||
# self.assertListEqual([score for _, score in allen_res], fast_res[1]) | |||||
# # seq equal | |||||
# self.assertListEqual([_ for _, score in allen_res], fast_res[0]) | |||||