# Conflicts: # fastNLP/core/dataset.pytags/v0.3.0^2
@@ -9,16 +9,15 @@ from fastNLP.core.dataset import DataSet | |||
from fastNLP.api.model_zoo import load_url | |||
from fastNLP.api.processor import ModelProcessor | |||
from reproduction.chinese_word_segment.cws_io.cws_reader import ConlluCWSReader | |||
from reproduction.pos_tag_model.pos_io.pos_reader import ConlluPOSReader | |||
from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader | |||
from reproduction.pos_tag_model.pos_reader import ConllPOSReader | |||
from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag | |||
from fastNLP.core.instance import Instance | |||
from fastNLP.core.sampler import SequentialSampler | |||
from fastNLP.core.batch import Batch | |||
from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1 | |||
from fastNLP.api.pipeline import Pipeline | |||
from fastNLP.core.metrics import SeqLabelEvaluator2 | |||
from fastNLP.core.tester import Tester | |||
# TODO add pretrain urls | |||
model_urls = { | |||
@@ -29,6 +28,7 @@ model_urls = { | |||
class API: | |||
def __init__(self): | |||
self.pipeline = None | |||
self._dict = None | |||
def predict(self, *args, **kwargs): | |||
raise NotImplementedError | |||
@@ -38,8 +38,8 @@ class API: | |||
_dict = torch.load(path, map_location='cpu') | |||
else: | |||
_dict = load_url(path, map_location='cpu') | |||
self.pipeline = _dict['pipeline'] | |||
self._dict = _dict | |||
self.pipeline = _dict['pipeline'] | |||
for processor in self.pipeline.pipeline: | |||
if isinstance(processor, ModelProcessor): | |||
processor.set_model_device(device) | |||
@@ -48,8 +48,10 @@ class API: | |||
class POS(API): | |||
"""FastNLP API for Part-Of-Speech tagging. | |||
""" | |||
:param str model_path: the path to the model. | |||
:param str device: device name such as "cpu" or "cuda:0". Use the same notation as PyTorch. | |||
""" | |||
def __init__(self, model_path=None, device='cpu'): | |||
super(POS, self).__init__() | |||
if model_path is None: | |||
@@ -75,12 +77,28 @@ class POS(API): | |||
# 2. 组建dataset | |||
dataset = DataSet() | |||
dataset.add_field('words', sentence_list) | |||
dataset.add_field("words", sentence_list) | |||
# 3. 使用pipeline | |||
self.pipeline(dataset) | |||
output = dataset['word_pos_output'].content | |||
def decode_tags(ins): | |||
pred_tags = ins["tag"] | |||
chars = ins["words"] | |||
words = [] | |||
start_idx = 0 | |||
for idx, tag in enumerate(pred_tags): | |||
if tag[0] == "S": | |||
words.append(chars[start_idx:idx + 1] + "/" + tag[2:]) | |||
start_idx = idx + 1 | |||
elif tag[0] == "E": | |||
words.append("".join(chars[start_idx:idx + 1]) + "/" + tag[2:]) | |||
start_idx = idx + 1 | |||
return words | |||
dataset.apply(decode_tags, new_field_name="tag_output") | |||
output = dataset.field_arrays["tag_output"].content | |||
if isinstance(content, str): | |||
return output[0] | |||
elif isinstance(content, list): | |||
@@ -95,9 +113,10 @@ class POS(API): | |||
pipeline.append(tag_proc) | |||
pp = Pipeline(pipeline) | |||
reader = ConlluPOSReader() | |||
reader = ConllPOSReader() | |||
te_dataset = reader.load(filepath) | |||
""" | |||
evaluator = SeqLabelEvaluator2('word_seq_origin_len') | |||
end_tagidx_set = set() | |||
tag_proc.vocab.build_vocab() | |||
@@ -108,15 +127,16 @@ class POS(API): | |||
end_tagidx_set.add(value) | |||
evaluator.end_tagidx_set = end_tagidx_set | |||
default_valid_args = {"batch_size": 64, | |||
"use_cuda": True, "evaluator": evaluator} | |||
pp(te_dataset) | |||
te_dataset.set_target(truth=True) | |||
default_valid_args = {"batch_size": 64, | |||
"use_cuda": True, "evaluator": evaluator, | |||
"model": model, "data": te_dataset} | |||
tester = Tester(**default_valid_args) | |||
test_result = tester.test(model, te_dataset) | |||
test_result = tester.test() | |||
f1 = round(test_result['F'] * 100, 2) | |||
pre = round(test_result['P'] * 100, 2) | |||
@@ -124,6 +144,7 @@ class POS(API): | |||
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) | |||
return f1, pre, rec | |||
""" | |||
class CWS(API): | |||
@@ -168,7 +189,7 @@ class CWS(API): | |||
pipeline.insert(1, tag_proc) | |||
pp = Pipeline(pipeline) | |||
reader = ConlluCWSReader() | |||
reader = ConllCWSReader() | |||
# te_filename = '/home/hyan/ctb3/test.conllx' | |||
te_dataset = reader.load(filepath) | |||
@@ -290,13 +311,13 @@ class Analyzer: | |||
if __name__ == "__main__": | |||
# pos_model_path = '../../reproduction/pos_tag_model/pos_crf.pkl' | |||
# pos = POS(device='cpu') | |||
# s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。' , | |||
# '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||
# '那么这款无人机到底有多厉害?'] | |||
pos_model_path = '/home/zyfeng/fastnlp/reproduction/pos_tag_model/model_pp.pkl' | |||
pos = POS(pos_model_path, device='cpu') | |||
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||
'那么这款无人机到底有多厉害?'] | |||
# print(pos.test('/Users/yh/Desktop/test_data/pos_test.conll')) | |||
# print(pos.predict(s)) | |||
print(pos.predict(s)) | |||
# cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl' | |||
# cws = CWS(device='cpu') | |||
@@ -306,9 +327,9 @@ if __name__ == "__main__": | |||
# print(cws.test('/Users/yh/Desktop/test_data/cws_test.conll')) | |||
# print(cws.predict(s)) | |||
parser = Parser(device='cpu') | |||
# parser = Parser(device='cpu') | |||
# print(parser.test('/Users/yh/Desktop/test_data/parser_test2.conll')) | |||
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||
'那么这款无人机到底有多厉害?'] | |||
print(parser.predict(s)) | |||
# print(parser.predict(s)) |
@@ -11,6 +11,11 @@ from fastNLP.core.vocabulary import Vocabulary | |||
class Processor(object): | |||
def __init__(self, field_name, new_added_field_name): | |||
""" | |||
:param field_name: 处理哪个field | |||
:param new_added_field_name: 如果为None,则认为是field_name,即覆盖原有的field | |||
""" | |||
self.field_name = field_name | |||
if new_added_field_name is None: | |||
self.new_added_field_name = field_name | |||
@@ -92,6 +97,11 @@ class FullSpaceToHalfSpaceProcessor(Processor): | |||
class PreAppendProcessor(Processor): | |||
""" | |||
向某个field的起始增加data(应该为str类型)。该field需要为list类型。即新增的field为 | |||
[data] + instance[field_name] | |||
""" | |||
def __init__(self, data, field_name, new_added_field_name=None): | |||
super(PreAppendProcessor, self).__init__(field_name, new_added_field_name) | |||
self.data = data | |||
@@ -102,6 +112,10 @@ class PreAppendProcessor(Processor): | |||
class SliceProcessor(Processor): | |||
""" | |||
从某个field中只取部分内容。等价于instance[field_name][start:end:step] | |||
""" | |||
def __init__(self, start, end, step, field_name, new_added_field_name=None): | |||
super(SliceProcessor, self).__init__(field_name, new_added_field_name) | |||
for o in (start, end, step): | |||
@@ -114,7 +128,17 @@ class SliceProcessor(Processor): | |||
class Num2TagProcessor(Processor): | |||
""" | |||
将一句话中的数字转换为某个tag。 | |||
""" | |||
def __init__(self, tag, field_name, new_added_field_name=None): | |||
""" | |||
:param tag: str, 将数字转换为该tag | |||
:param field_name: | |||
:param new_added_field_name: | |||
""" | |||
super(Num2TagProcessor, self).__init__(field_name, new_added_field_name) | |||
self.tag = tag | |||
self.pattern = r'[-+]?([0-9]+[.]?[0-9]*)+[/eE]?[-+]?([0-9]+[.]?[0-9]*)' | |||
@@ -135,6 +159,10 @@ class Num2TagProcessor(Processor): | |||
class IndexerProcessor(Processor): | |||
""" | |||
给定一个vocabulary , 将指定field转换为index形式。指定field应该是一维的list,比如 | |||
['我', '是', xxx] | |||
""" | |||
def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False, is_input=True): | |||
assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) | |||
@@ -163,19 +191,19 @@ class IndexerProcessor(Processor): | |||
class VocabProcessor(Processor): | |||
"""Build vocabulary with a field in the data set. | |||
""" | |||
传入若干个DataSet以建立vocabulary。 | |||
""" | |||
def __init__(self, field_name): | |||
def __init__(self, field_name, min_freq=1, max_size=None): | |||
super(VocabProcessor, self).__init__(field_name, None) | |||
self.vocab = Vocabulary() | |||
self.vocab = Vocabulary(min_freq=min_freq, max_size=max_size) | |||
def process(self, *datasets): | |||
for dataset in datasets: | |||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||
for ins in dataset: | |||
self.vocab.update(ins[self.field_name]) | |||
dataset.apply(lambda ins: self.vocab.update(ins[self.field_name])) | |||
def get_vocab(self): | |||
self.vocab.build_vocab() | |||
@@ -183,6 +211,10 @@ class VocabProcessor(Processor): | |||
class SeqLenProcessor(Processor): | |||
""" | |||
根据某个field新增一个sequence length的field。取该field的第一维 | |||
""" | |||
def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True): | |||
super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) | |||
self.is_input = is_input | |||
@@ -195,10 +227,15 @@ class SeqLenProcessor(Processor): | |||
return dataset | |||
from fastNLP.core.utils import _build_args | |||
class ModelProcessor(Processor): | |||
def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32): | |||
""" | |||
迭代模型并将结果的padding drop掉 | |||
传入一个model,在process()时传入一个dataset,该processor会通过Batch将DataSet的内容输出给model.predict或者model.forward. | |||
model输出的内容会被增加到dataset中,field_name由model输出决定。如果生成的内容维度不是(Batch_size, )与 | |||
(Batch_size, 1),则使用seqence length这个field进行unpad | |||
TODO 这个类需要删除对seq_lens的依赖。 | |||
:param seq_len_field_name: | |||
:param batch_size: | |||
@@ -211,13 +248,18 @@ class ModelProcessor(Processor): | |||
def process(self, dataset): | |||
self.model.eval() | |||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||
data_iterator = Batch(dataset, batch_size=self.batch_size, sampler=SequentialSampler(), use_cuda=False) | |||
data_iterator = Batch(dataset, batch_size=self.batch_size, sampler=SequentialSampler()) | |||
batch_output = defaultdict(list) | |||
if hasattr(self.model, "predict"): | |||
predict_func = self.model.predict | |||
else: | |||
predict_func = self.model.forward | |||
with torch.no_grad(): | |||
for batch_x, _ in data_iterator: | |||
prediction = self.model.predict(**batch_x) | |||
seq_lens = batch_x[self.seq_len_field_name].cpu().numpy().tolist() | |||
refined_batch_x = _build_args(predict_func, **batch_x) | |||
prediction = predict_func(**refined_batch_x) | |||
seq_lens = batch_x[self.seq_len_field_name].tolist() | |||
for key, value in prediction.items(): | |||
tmp_batch = [] | |||
@@ -246,6 +288,10 @@ class ModelProcessor(Processor): | |||
class Index2WordProcessor(Processor): | |||
""" | |||
将DataSet中某个为index的field根据vocab转换为str | |||
""" | |||
def __init__(self, vocab, field_name, new_added_field_name): | |||
super(Index2WordProcessor, self).__init__(field_name, new_added_field_name) | |||
self.vocab = vocab | |||
@@ -266,5 +312,5 @@ class SetIsTargetProcessor(Processor): | |||
def process(self, dataset): | |||
set_dict = {name: self.default for name in dataset.get_all_fields().keys()} | |||
set_dict.update(self.field_dict) | |||
dataset.set_target(**set_dict) | |||
dataset.set_target(*set_dict.keys()) | |||
return dataset |
@@ -10,10 +10,10 @@ class Batch(object): | |||
for batch_x, batch_y in Batch(data_set, batch_size=16, sampler=SequentialSampler()): | |||
# ... | |||
:param dataset: a DataSet object | |||
:param batch_size: int, the size of the batch | |||
:param sampler: a Sampler object | |||
:param as_numpy: bool. If True, return Numpy array. Otherwise, return torch tensors. | |||
:param DataSet dataset: a DataSet object | |||
:param int batch_size: the size of the batch | |||
:param Sampler sampler: a Sampler object | |||
:param bool as_numpy: If True, return Numpy array. Otherwise, return torch tensors. | |||
""" | |||
@@ -254,6 +254,8 @@ class DataSet(object): | |||
:return results: if new_field_name is not passed, returned values of the function over all instances. | |||
""" | |||
results = [func(ins) for ins in self._inner_iter()] | |||
if len(list(filter(lambda x: x is not None, results))) == 0 and not (new_field_name is None): # all None | |||
raise ValueError("{} always return None.".format(get_func_signature(func=func))) | |||
extra_param = {} | |||
if 'is_input' in kwargs: | |||
@@ -261,8 +263,6 @@ class DataSet(object): | |||
if 'is_target' in kwargs: | |||
extra_param['is_target'] = kwargs['is_target'] | |||
if new_field_name is not None: | |||
if len(list(filter(lambda x: x is not None, results))) == 0: # all None | |||
raise ValueError("{} always return None.".format(get_func_signature(func=func))) | |||
if new_field_name in self.field_arrays: | |||
# overwrite the field, keep same attributes | |||
old_field = self.field_arrays[new_field_name] | |||
@@ -250,7 +250,7 @@ class LossInForward(LossBase): | |||
if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): | |||
if not isinstance(loss, torch.Tensor): | |||
raise TypeError(f"loss excepts to be a torch.Tensor, got {type(loss)}") | |||
raise TypeError(f"Loss excepted to be a torch.Tensor, got {type(loss)}") | |||
raise RuntimeError(f"The size of loss excepts to be torch.Size([]), got {loss.size()}") | |||
return loss | |||
@@ -10,6 +10,7 @@ from fastNLP.core.utils import _build_args | |||
from fastNLP.core.utils import _check_arg_dict_list | |||
from fastNLP.core.utils import get_func_signature | |||
from fastNLP.core.utils import seq_lens_to_masks | |||
from fastNLP.core.vocabulary import Vocabulary | |||
class MetricBase(object): | |||
@@ -80,11 +81,6 @@ class MetricBase(object): | |||
f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the " | |||
f"initialization parameters, or change its signature.") | |||
# evaluate should not have varargs. | |||
# if func_spect.varargs: | |||
# raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.evaluate)}(Do not use " | |||
# f"positional argument.).") | |||
def get_metric(self, reset=True): | |||
raise NotImplemented | |||
@@ -108,10 +104,9 @@ class MetricBase(object): | |||
This method will call self.evaluate method. | |||
Before calling self.evaluate, it will first check the validity of output_dict, target_dict | |||
(1) whether self.evaluate has varargs, which is not supported. | |||
(2) whether params needed by self.evaluate is not included in output_dict,target_dict. | |||
(3) whether params needed by self.evaluate duplicate in pred_dict, target_dict | |||
(4) whether params in output_dict, target_dict are not used by evaluate.(Might cause warning) | |||
(1) whether params needed by self.evaluate is not included in output_dict,target_dict. | |||
(2) whether params needed by self.evaluate duplicate in pred_dict, target_dict | |||
(3) whether params in output_dict, target_dict are not used by evaluate.(Might cause warning) | |||
Besides, before passing params into self.evaluate, this function will filter out params from output_dict and | |||
target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering | |||
will be conducted.) | |||
@@ -299,6 +294,368 @@ class AccuracyMetric(MetricBase): | |||
self.total = 0 | |||
return evaluate_result | |||
def bmes_tag_to_spans(tags, ignore_labels=None): | |||
""" | |||
:param tags: List[str], | |||
:param ignore_labels: List[str], 在该list中的label将被忽略 | |||
:return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] | |||
""" | |||
ignore_labels = set(ignore_labels) if ignore_labels else set() | |||
spans = [] | |||
prev_bmes_tag = None | |||
for idx, tag in enumerate(tags): | |||
tag = tag.lower() | |||
bmes_tag, label = tag[:1], tag[2:] | |||
if bmes_tag in ('b', 's'): | |||
spans.append((label, [idx, idx])) | |||
elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label==spans[-1][0]: | |||
spans[-1][1][1] = idx | |||
else: | |||
spans.append((label, [idx, idx])) | |||
prev_bmes_tag = bmes_tag | |||
return [(span[0], (span[1][0], span[1][1])) | |||
for span in spans | |||
if span[0] not in ignore_labels | |||
] | |||
def bio_tag_to_spans(tags, ignore_labels=None): | |||
""" | |||
:param tags: List[str], | |||
:param ignore_labels: List[str], 在该list中的label将被忽略 | |||
:return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] | |||
""" | |||
ignore_labels = set(ignore_labels) if ignore_labels else set() | |||
spans = [] | |||
prev_bio_tag = None | |||
for idx, tag in enumerate(tags): | |||
tag = tag.lower() | |||
bio_tag, label = tag[:1], tag[2:] | |||
if bio_tag == 'b': | |||
spans.append((label, [idx, idx])) | |||
elif bio_tag == 'i' and prev_bio_tag in ('b', 'i') and label==spans[-1][0]: | |||
spans[-1][1][1] = idx | |||
elif bio_tag == 'o': # o tag does not count | |||
pass | |||
else: | |||
spans.append((label, [idx, idx])) | |||
prev_bio_tag = bio_tag | |||
return [(span[0], (span[1][0], span[1][1])) | |||
for span in spans | |||
if span[0] not in ignore_labels | |||
] | |||
class SpanFPreRecMetric(MetricBase): | |||
""" | |||
在序列标注问题中,以span的方式计算F, pre, rec. | |||
最后得到的metric结果为 | |||
{ | |||
'f': xxx, # 这里使用f考虑以后可以计算f_beta值 | |||
'pre': xxx, | |||
'rec':xxx | |||
} | |||
若only_gross=False, 即还会返回各个label的metric统计值 | |||
{ | |||
'f': xxx, | |||
'pre': xxx, | |||
'rec':xxx, | |||
'f-label': xxx, | |||
'pre-label': xxx, | |||
'rec-label':xxx, | |||
... | |||
} | |||
""" | |||
def __init__(self, tag_vocab, pred=None, target=None, seq_lens=None, encoding_type='bio', ignore_labels=None, | |||
only_gross=True, f_type='micro', beta=1): | |||
""" | |||
:param tag_vocab: Vocabulary, 标签的vocabulary。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN), | |||
在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'. | |||
:param pred: str, 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据 | |||
:param target: str, 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据 | |||
:param seq_lens: str, 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用'seq_lens'取数据。 | |||
:param encoding_type: str, 目前支持bio, bmes | |||
:param ignore_labels, List[str]. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'这 | |||
个label | |||
:param only_gross, bool. 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个 | |||
label的f1, pre, rec | |||
:param f_type, str. 'micro'或'macro'. 'micro':通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; 'macro': | |||
分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同) | |||
:param beta, float. f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5 | |||
则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||
""" | |||
encoding_type = encoding_type.lower() | |||
if encoding_type not in ('bio', 'bmes'): | |||
raise ValueError("Only support 'bio' or 'bmes' type.") | |||
if not isinstance(tag_vocab, Vocabulary): | |||
raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) | |||
if f_type not in ('micro', 'macro'): | |||
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | |||
self.encoding_type = encoding_type | |||
if self.encoding_type == 'bmes': | |||
self.tag_to_span_func = bmes_tag_to_spans | |||
elif self.encoding_type == 'bio': | |||
self.tag_to_span_func = bio_tag_to_spans | |||
self.ignore_labels = ignore_labels | |||
self.f_type = f_type | |||
self.beta = beta | |||
self.beta_square = self.beta**2 | |||
self.only_gross = only_gross | |||
super().__init__() | |||
self._init_param_map(pred=pred, target=target, seq_lens=seq_lens) | |||
self.tag_vocab = tag_vocab | |||
self._true_positives = defaultdict(int) | |||
self._false_positives = defaultdict(int) | |||
self._false_negatives = defaultdict(int) | |||
def evaluate(self, pred, target, seq_lens): | |||
""" | |||
A lot of design idea comes from allennlp's measure | |||
:param pred: | |||
:param target: | |||
:param seq_lens: | |||
:return: | |||
""" | |||
if not isinstance(pred, torch.Tensor): | |||
raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(pred)}.") | |||
if not isinstance(target, torch.Tensor): | |||
raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(target)}.") | |||
if not isinstance(seq_lens, torch.Tensor): | |||
raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(seq_lens)}.") | |||
if pred.size() == target.size() and len(target.size()) == 2: | |||
pass | |||
elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2: | |||
pred = pred.argmax(dim=-1) | |||
num_classes = pred.size(-1) | |||
if (target >= num_classes).any(): | |||
raise ValueError("A gold label passed to SpanBasedF1Metric contains an " | |||
"id >= {}, the number of classes.".format(num_classes)) | |||
else: | |||
raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have " | |||
f"size:{pred.size()}, target should have size: {pred.size()} or " | |||
f"{pred.size()[:-1]}, got {target.size()}.") | |||
batch_size = pred.size(0) | |||
for i in range(batch_size): | |||
pred_tags = pred[i, :seq_lens[i]].tolist() | |||
gold_tags = target[i, :seq_lens[i]].tolist() | |||
pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags] | |||
gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags] | |||
pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels) | |||
gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels) | |||
for span in pred_spans: | |||
if span in gold_spans: | |||
self._true_positives[span[0]] += 1 | |||
gold_spans.remove(span) | |||
else: | |||
self._false_positives[span[0]] += 1 | |||
for span in gold_spans: | |||
self._false_negatives[span[0]] += 1 | |||
def get_metric(self, reset=True): | |||
evaluate_result = {} | |||
if not self.only_gross or self.f_type=='macro': | |||
tags = set(self._false_negatives.keys()) | |||
tags.update(set(self._false_positives.keys())) | |||
tags.update(set(self._true_positives.keys())) | |||
f_sum = 0 | |||
pre_sum = 0 | |||
rec_sum = 0 | |||
for tag in tags: | |||
tp = self._true_positives[tag] | |||
fn = self._false_negatives[tag] | |||
fp = self._false_positives[tag] | |||
f, pre, rec = self._compute_f_pre_rec(tp, fn, fp) | |||
f_sum += f | |||
pre_sum += pre | |||
rec_sum + rec | |||
if not self.only_gross and tag!='': # tag!=''防止无tag的情况 | |||
f_key = 'f-{}'.format(tag) | |||
pre_key = 'pre-{}'.format(tag) | |||
rec_key = 'rec-{}'.format(tag) | |||
evaluate_result[f_key] = f | |||
evaluate_result[pre_key] = pre | |||
evaluate_result[rec_key] = rec | |||
if self.f_type == 'macro': | |||
evaluate_result['f'] = f_sum/len(tags) | |||
evaluate_result['pre'] = pre_sum/len(tags) | |||
evaluate_result['rec'] = rec_sum/len(tags) | |||
if self.f_type == 'micro': | |||
f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()), | |||
sum(self._false_negatives.values()), | |||
sum(self._false_positives.values())) | |||
evaluate_result['f'] = round(f, 6) | |||
evaluate_result['pre'] = round(pre, 6) | |||
evaluate_result['rec'] = round(rec, 6) | |||
if reset: | |||
self._true_positives = defaultdict(int) | |||
self._false_positives = defaultdict(int) | |||
self._false_negatives = defaultdict(int) | |||
return evaluate_result | |||
def _compute_f_pre_rec(self, tp, fn, fp): | |||
""" | |||
:param tp: int, true positive | |||
:param fn: int, false negative | |||
:param fp: int, false positive | |||
:return: (f, pre, rec) | |||
""" | |||
pre = tp / (fp + tp + 1e-13) | |||
rec = tp / (fn + tp + 1e-13) | |||
f = (1 + self.beta_square) * pre * rec / (self.beta_square * pre + rec + 1e-13) | |||
return f, pre, rec | |||
class BMESF1PreRecMetric(MetricBase): | |||
""" | |||
按照BMES标注方式计算f1, precision, recall。由于可能存在非法tag,比如"BS",所以需要用以下的表格做转换,cur_B意思是当前tag是B, | |||
next_B意思是后一个tag是B。则cur_B=S,即将当前被predict是B的tag标为S;next_M=B, 即将后一个被predict是M的tag标为B | |||
| | next_B | next_M | next_E | next_S | end | | |||
|:-----:|:-------:|:--------:|:--------:|:-------:|:-------:| | |||
| start | 合法 | next_M=B | next_E=S | 合法 | - | | |||
| cur_B | cur_B=S | 合法 | 合法 | cur_B=S | cur_B=S | | |||
| cur_M | cur_M=E | 合法 | 合法 | cur_M=E | cur_M=E | | |||
| cur_E | 合法 | next_M=B | next_E=S | 合法 | 合法 | | |||
| cur_S | 合法 | next_M=B | next_E=S | 合法 | 合法 | | |||
举例: | |||
prediction为BSEMS,会被认为是SSSSS. | |||
本Metric不检验target的合法性,请务必保证target的合法性。 | |||
pred的形状应该为(batch_size, max_len) 或 (batch_size, max_len, 4)。 | |||
target形状为 (batch_size, max_len) | |||
seq_lens形状为 (batch_size, ) | |||
""" | |||
def __init__(self, b_idx=0, m_idx=1, e_idx=2, s_idx=3, pred=None, target=None, seq_lens=None): | |||
""" | |||
需要申明BMES这四种tag中,各种tag对应的idx。所有不为b_idx, m_idx, e_idx, s_idx的数字都认为是s_idx。 | |||
:param b_idx: int, Begin标签所对应的tag idx. | |||
:param m_idx: int, Middle标签所对应的tag idx. | |||
:param e_idx: int, End标签所对应的tag idx. | |||
:param s_idx: int, Single标签所对应的tag idx | |||
:param pred: str, 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据 | |||
:param target: str, 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据 | |||
:param seq_lens: str, 用该key在evaluate()时从传入dict中取出seqence length数据。为None,则使用'seq_lens'取数据。 | |||
""" | |||
super().__init__() | |||
self._init_param_map(pred=pred, target=target, seq_lens=seq_lens) | |||
self.yt_wordnum = 0 | |||
self.yp_wordnum = 0 | |||
self.corr_num = 0 | |||
self.b_idx = b_idx | |||
self.m_idx = m_idx | |||
self.e_idx = e_idx | |||
self.s_idx = s_idx | |||
# 还原init处介绍的矩阵 | |||
self._valida_matrix = { | |||
-1: [(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1)], # magic start idx | |||
self.b_idx:[(0, self.s_idx), (-1, -1), (-1, -1), (0, self.s_idx), (0, self.s_idx)], | |||
self.m_idx:[(0, self.e_idx), (-1, -1), (-1, -1), (0, self.e_idx), (0, self.e_idx)], | |||
self.e_idx:[(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)], | |||
self.s_idx:[(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)], | |||
} | |||
def _validate_tags(self, tags): | |||
""" | |||
给定一个tag的Tensor,返回合法tag | |||
:param tags: Tensor, shape: (seq_len, ) | |||
:return: 返回修改为合法tag的list | |||
""" | |||
assert len(tags)!=0 | |||
assert isinstance(tags, torch.Tensor) and len(tags.size())==1 | |||
padded_tags = [-1, *tags.tolist(), -1] | |||
for idx in range(len(padded_tags)-1): | |||
cur_tag = padded_tags[idx] | |||
if cur_tag not in self._valida_matrix: | |||
cur_tag = self.s_idx | |||
if padded_tags[idx+1] not in self._valida_matrix: | |||
padded_tags[idx+1] = self.s_idx | |||
next_tag = padded_tags[idx+1] | |||
shift_tag = self._valida_matrix[cur_tag][next_tag] | |||
if shift_tag[0]!=-1: | |||
padded_tags[idx+shift_tag[0]] = shift_tag[1] | |||
return padded_tags[1:-1] | |||
def evaluate(self, pred, target, seq_lens): | |||
if not isinstance(pred, torch.Tensor): | |||
raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(pred)}.") | |||
if not isinstance(target, torch.Tensor): | |||
raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(target)}.") | |||
if not isinstance(seq_lens, torch.Tensor): | |||
raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(seq_lens)}.") | |||
if pred.size() == target.size() and len(target.size()) == 2: | |||
pass | |||
elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2: | |||
pred = pred.argmax(dim=-1) | |||
else: | |||
raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have " | |||
f"size:{pred.size()}, target should have size: {pred.size()} or " | |||
f"{pred.size()[:-1]}, got {target.size()}.") | |||
for idx in range(len(pred)): | |||
seq_len = seq_lens[idx] | |||
target_tags = target[idx][:seq_len].tolist() | |||
pred_tags = pred[idx][:seq_len] | |||
pred_tags = self._validate_tags(pred_tags) | |||
start_idx = 0 | |||
for t_idx, (t_tag, p_tag) in enumerate(zip(target_tags, pred_tags)): | |||
if t_tag in (self.s_idx, self.e_idx): | |||
self.yt_wordnum += 1 | |||
corr_flag = True | |||
for i in range(start_idx, t_idx+1): | |||
if target_tags[i]!=pred_tags[i]: | |||
corr_flag = False | |||
if corr_flag: | |||
self.corr_num += 1 | |||
start_idx = t_idx + 1 | |||
if p_tag in (self.s_idx, self.e_idx): | |||
self.yp_wordnum += 1 | |||
def get_metric(self, reset=True): | |||
P = self.corr_num / (self.yp_wordnum + 1e-12) | |||
R = self.corr_num / (self.yt_wordnum + 1e-12) | |||
F = 2 * P * R / (P + R + 1e-12) | |||
evaluate_result = {'f': round(F, 6), 'pre':round(P, 6), 'rec': round(R, 6)} | |||
if reset: | |||
self.yp_wordnum = 0 | |||
self.yt_wordnum = 0 | |||
self.corr_num = 0 | |||
return evaluate_result | |||
def _prepare_metrics(metrics): | |||
""" | |||
@@ -27,7 +27,10 @@ from fastNLP.core.utils import get_func_signature | |||
class Trainer(object): | |||
""" | |||
def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, | |||
validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0), | |||
check_code_level=0, metric_key=None, sampler=RandomSampler(), use_tqdm=True, use_cuda=False): | |||
""" | |||
:param DataSet train_data: the training data | |||
:param torch.nn.modules.module model: a PyTorch model | |||
:param LossBase loss: a loss object | |||
@@ -48,16 +51,10 @@ class Trainer(object): | |||
smaller, add "-" in front of the string. For example:: | |||
metric_key="-PPL" # language model gets better as perplexity gets smaller | |||
:param BaseSampler sampler: method used to generate batch data. | |||
:param bool use_tqdm: whether to use tqdm to show train progress. | |||
""" | |||
def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, | |||
validate_every=-1, dev_data=None, use_cuda=False, save_path=None, | |||
optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0, | |||
metric_key=None, sampler=RandomSampler(), use_tqdm=True): | |||
""" | |||
super(Trainer, self).__init__() | |||
if not isinstance(train_data, DataSet): | |||
@@ -3,7 +3,9 @@ import os | |||
class BaseLoader(object): | |||
"""Base loader for all loaders. | |||
""" | |||
def __init__(self): | |||
super(BaseLoader, self).__init__() | |||
@@ -32,7 +34,9 @@ class BaseLoader(object): | |||
class DataLoaderRegister: | |||
""""register for data sets""" | |||
"""Register for all data sets. | |||
""" | |||
_readers = {} | |||
@classmethod | |||
@@ -6,7 +6,11 @@ from fastNLP.io.base_loader import BaseLoader | |||
class ConfigLoader(BaseLoader): | |||
"""loader for configuration files""" | |||
"""Loader for configuration. | |||
:param str data_path: path to the config | |||
""" | |||
def __init__(self, data_path=None): | |||
super(ConfigLoader, self).__init__() | |||
@@ -19,13 +23,15 @@ class ConfigLoader(BaseLoader): | |||
@staticmethod | |||
def load_config(file_path, sections): | |||
""" | |||
:param file_path: the path of config file | |||
:param sections: the dict of {section_name(string): Section instance} | |||
Example: | |||
"""Load section(s) of configuration into the ``sections`` provided. No returns. | |||
:param str file_path: the path of config file | |||
:param dict sections: the dict of ``{section_name(string): ConfigSection object}`` | |||
Example:: | |||
test_args = ConfigSection() | |||
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||
:return: return nothing, but the value of attributes are saved in sessions | |||
""" | |||
assert isinstance(sections, dict) | |||
cfg = configparser.ConfigParser() | |||
@@ -60,9 +66,12 @@ class ConfigLoader(BaseLoader): | |||
class ConfigSection(object): | |||
"""ConfigSection is the data structure storing all key-value pairs in one section in a config file. | |||
""" | |||
def __init__(self): | |||
pass | |||
super(ConfigSection, self).__init__() | |||
def __getitem__(self, key): | |||
""" | |||
@@ -132,25 +141,12 @@ class ConfigSection(object): | |||
return self.__dict__ | |||
if __name__ == "__main__": | |||
config = ConfigLoader('there is no data') | |||
section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()} | |||
""" | |||
General and My can be found in config file, so the attr and | |||
value will be updated | |||
A cannot be found in config file, so nothing will be done | |||
""" | |||
config.load_config("../../test/data_for_tests/config", section) | |||
for s in section: | |||
print(s) | |||
for attr in section[s].__dict__.keys(): | |||
print(s, attr, getattr(section[s], attr), type(getattr(section[s], attr))) | |||
class ConfigSaver(object): | |||
"""ConfigSaver is used to save config file and solve related conflicts. | |||
:param str file_path: path to the config file | |||
""" | |||
def __init__(self, file_path): | |||
self.file_path = file_path | |||
if not os.path.exists(self.file_path): | |||
@@ -244,9 +240,8 @@ class ConfigSaver(object): | |||
def save_config_file(self, section_name, section): | |||
"""This is the function to be called to change the config file with a single section and its name. | |||
:param section_name: The name of section what needs to be changed and saved. | |||
:param section: The section with key and value what needs to be changed and saved. | |||
:return: | |||
:param str section_name: The name of section what needs to be changed and saved. | |||
:param ConfigSection section: The section with key and value what needs to be changed and saved. | |||
""" | |||
section_file = self._get_section(section_name) | |||
if len(section_file.__dict__.keys()) == 0: # the section not in the file before | |||
@@ -9,11 +9,12 @@ def convert_seq_dataset(data): | |||
"""Create an DataSet instance that contains no labels. | |||
:param data: list of list of strings, [num_examples, *]. | |||
:: | |||
[ | |||
[word_11, word_12, ...], | |||
... | |||
] | |||
Example:: | |||
[ | |||
[word_11, word_12, ...], | |||
... | |||
] | |||
:return: a DataSet. | |||
""" | |||
@@ -24,15 +25,16 @@ def convert_seq_dataset(data): | |||
def convert_seq2tag_dataset(data): | |||
"""Convert list of data into DataSet | |||
"""Convert list of data into DataSet. | |||
:param data: list of list of strings, [num_examples, *]. | |||
:: | |||
[ | |||
[ [word_11, word_12, ...], label_1 ], | |||
[ [word_21, word_22, ...], label_2 ], | |||
... | |||
] | |||
Example:: | |||
[ | |||
[ [word_11, word_12, ...], label_1 ], | |||
[ [word_21, word_22, ...], label_2 ], | |||
... | |||
] | |||
:return: a DataSet. | |||
""" | |||
@@ -43,15 +45,16 @@ def convert_seq2tag_dataset(data): | |||
def convert_seq2seq_dataset(data): | |||
"""Convert list of data into DataSet | |||
"""Convert list of data into DataSet. | |||
:param data: list of list of strings, [num_examples, *]. | |||
:: | |||
[ | |||
[ [word_11, word_12, ...], [label_1, label_1, ...] ], | |||
[ [word_21, word_22, ...], [label_2, label_1, ...] ], | |||
... | |||
] | |||
Example:: | |||
[ | |||
[ [word_11, word_12, ...], [label_1, label_1, ...] ], | |||
[ [word_21, word_22, ...], [label_2, label_1, ...] ], | |||
... | |||
] | |||
:return: a DataSet. | |||
""" | |||
@@ -62,20 +65,31 @@ def convert_seq2seq_dataset(data): | |||
class DataSetLoader: | |||
""""loader for data sets""" | |||
"""Interface for all DataSetLoaders. | |||
""" | |||
def load(self, path): | |||
""" load data in `path` into a dataset | |||
"""Load data from a given file. | |||
:param str path: file path | |||
:return: a DataSet object | |||
""" | |||
raise NotImplementedError | |||
def convert(self, data): | |||
"""convert list of data into dataset | |||
"""Optional operation to build a DataSet. | |||
:param data: inner data structure (user-defined) to represent the data. | |||
:return: a DataSet object | |||
""" | |||
raise NotImplementedError | |||
class NativeDataSetLoader(DataSetLoader): | |||
"""A simple example of DataSetLoader | |||
""" | |||
def __init__(self): | |||
super(NativeDataSetLoader, self).__init__() | |||
@@ -90,6 +104,9 @@ DataLoaderRegister.set_reader(NativeDataSetLoader, 'read_naive') | |||
class RawDataSetLoader(DataSetLoader): | |||
"""A simple example of raw data reader | |||
""" | |||
def __init__(self): | |||
super(RawDataSetLoader, self).__init__() | |||
@@ -108,37 +125,35 @@ DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata') | |||
class POSDataSetLoader(DataSetLoader): | |||
"""Dataset Loader for POS Tag datasets. | |||
In these datasets, each line are divided by '\t' | |||
while the first Col is the vocabulary and the second | |||
Col is the label. | |||
Different sentence are divided by an empty line. | |||
e.g: | |||
Tom label1 | |||
and label2 | |||
Jerry label1 | |||
. label3 | |||
(separated by an empty line) | |||
Hello label4 | |||
world label5 | |||
! label3 | |||
In this file, there are two sentence "Tom and Jerry ." | |||
and "Hello world !". Each word has its own label from label1 | |||
to label5. | |||
"""Dataset Loader for a POS Tag dataset. | |||
In these datasets, each line are divided by "\t". The first Col is the vocabulary and the second | |||
Col is the label. Different sentence are divided by an empty line. | |||
E.g:: | |||
Tom label1 | |||
and label2 | |||
Jerry label1 | |||
. label3 | |||
(separated by an empty line) | |||
Hello label4 | |||
world label5 | |||
! label3 | |||
In this example, there are two sentences "Tom and Jerry ." and "Hello world !". Each word has its own label. | |||
""" | |||
def __init__(self): | |||
super(POSDataSetLoader, self).__init__() | |||
def load(self, data_path): | |||
""" | |||
:return data: three-level list | |||
[ | |||
[ [word_11, word_12, ...], [label_1, label_1, ...] ], | |||
[ [word_21, word_22, ...], [label_2, label_1, ...] ], | |||
... | |||
] | |||
Example:: | |||
[ | |||
[ [word_11, word_12, ...], [label_1, label_1, ...] ], | |||
[ [word_21, word_22, ...], [label_2, label_1, ...] ], | |||
... | |||
] | |||
""" | |||
with open(data_path, "r", encoding="utf-8") as f: | |||
lines = f.readlines() | |||
@@ -188,17 +203,17 @@ class TokenizeDataSetLoader(DataSetLoader): | |||
super(TokenizeDataSetLoader, self).__init__() | |||
def load(self, data_path, max_seq_len=32): | |||
""" | |||
load pku dataset for Chinese word segmentation | |||
"""Load pku dataset for Chinese word segmentation. | |||
CWS (Chinese Word Segmentation) pku training dataset format: | |||
1. Each line is a sentence. | |||
2. Each word in a sentence is separated by space. | |||
1. Each line is a sentence. | |||
2. Each word in a sentence is separated by space. | |||
This function convert the pku dataset into three-level lists with labels <BMES>. | |||
B: beginning of a word | |||
M: middle of a word | |||
E: ending of a word | |||
S: single character | |||
B: beginning of a word | |||
M: middle of a word | |||
E: ending of a word | |||
S: single character | |||
:param str data_path: path to the data set. | |||
:param max_seq_len: int, the maximum length of a sequence. If a sequence is longer than it, split it into | |||
several sequences. | |||
:return: three-level lists | |||
@@ -254,11 +269,9 @@ class ClassDataSetLoader(DataSetLoader): | |||
@staticmethod | |||
def parse(lines): | |||
""" | |||
Params | |||
lines: lines from dataset | |||
Return | |||
list(list(list())): the three level of lists are | |||
words, sentence, and dataset | |||
:param lines: lines from dataset | |||
:return: list(list(list())): the three level of lists are words, sentence, and dataset | |||
""" | |||
dataset = list() | |||
for line in lines: | |||
@@ -280,15 +293,9 @@ class ConllLoader(DataSetLoader): | |||
"""loader for conll format files""" | |||
def __init__(self): | |||
""" | |||
:param str data_path: the path to the conll data set | |||
""" | |||
super(ConllLoader, self).__init__() | |||
def load(self, data_path): | |||
""" | |||
:return: list lines: all lines in a conll file | |||
""" | |||
with open(data_path, "r", encoding="utf-8") as f: | |||
lines = f.readlines() | |||
data = self.parse(lines) | |||
@@ -320,8 +327,8 @@ class ConllLoader(DataSetLoader): | |||
class LMDataSetLoader(DataSetLoader): | |||
"""Language Model Dataset Loader | |||
This loader produces data for language model training in a supervised way. | |||
That means it has X and Y. | |||
This loader produces data for language model training in a supervised way. | |||
That means it has X and Y. | |||
""" | |||
@@ -467,6 +474,7 @@ class Conll2003Loader(DataSetLoader): | |||
return dataset | |||
class SNLIDataSetLoader(DataSetLoader): | |||
"""A data set loader for SNLI data set. | |||
@@ -478,8 +486,8 @@ class SNLIDataSetLoader(DataSetLoader): | |||
def load(self, path_list): | |||
""" | |||
:param path_list: A list of file name, in the order of premise file, hypothesis file, and label file. | |||
:return: data_set: A DataSet object. | |||
:param list path_list: A list of file name, in the order of premise file, hypothesis file, and label file. | |||
:return: A DataSet object. | |||
""" | |||
assert len(path_list) == 3 | |||
line_set = [] | |||
@@ -507,12 +515,14 @@ class SNLIDataSetLoader(DataSetLoader): | |||
"""Convert a 3D list to a DataSet object. | |||
:param data: A 3D tensor. | |||
[ | |||
[ [premise_word_11, premise_word_12, ...], [hypothesis_word_11, hypothesis_word_12, ...], [label_1] ], | |||
[ [premise_word_21, premise_word_22, ...], [hypothesis_word_21, hypothesis_word_22, ...], [label_2] ], | |||
... | |||
] | |||
:return: data_set: A DataSet object. | |||
Example:: | |||
[ | |||
[ [premise_word_11, premise_word_12, ...], [hypothesis_word_11, hypothesis_word_12, ...], [label_1] ], | |||
[ [premise_word_21, premise_word_22, ...], [hypothesis_word_21, hypothesis_word_22, ...], [label_2] ], | |||
... | |||
] | |||
:return: A DataSet object. | |||
""" | |||
data_set = DataSet() | |||
@@ -38,7 +38,7 @@ class EmbedLoader(BaseLoader): | |||
:param str emb_file: the pre-trained embedding file path | |||
:param str emb_type: the pre-trained embedding data format | |||
:return dict embedding: `{str: np.array}` | |||
:return: a dict of ``{str: np.array}`` | |||
""" | |||
if emb_type == 'glove': | |||
return EmbedLoader._load_glove(emb_file) | |||
@@ -53,8 +53,9 @@ class EmbedLoader(BaseLoader): | |||
:param str emb_file: the pre-trained embedding file path. | |||
:param str emb_type: the pre-trained embedding format, support glove now | |||
:param Vocabulary vocab: a mapping from word to index, can be provided by user or built from pre-trained embedding | |||
:return embedding_tensor: Tensor of shape (len(word_dict), emb_dim) | |||
vocab: input vocab or vocab built by pre-train | |||
:return (embedding_tensor, vocab): | |||
embedding_tensor - Tensor of shape (len(word_dict), emb_dim); | |||
vocab - input vocab or vocab built by pre-train | |||
""" | |||
pretrain = EmbedLoader._load_pretrain(emb_file, emb_type) | |||
@@ -95,7 +96,7 @@ class EmbedLoader(BaseLoader): | |||
:param int emb_dim: the dimension of the embedding. Should be the same as pre-trained embedding. | |||
:param str emb_file: the pre-trained embedding file path. | |||
:param Vocabulary vocab: a mapping from word to index, can be provided by user or built from pre-trained embedding | |||
:return numpy.ndarray embedding_matrix: | |||
:return embedding_matrix: numpy.ndarray | |||
""" | |||
if vocab is None: | |||
@@ -3,15 +3,16 @@ import os | |||
def create_logger(logger_name, log_path, log_format=None, log_level=logging.INFO): | |||
"""Return a logger. | |||
"""Create a logger. | |||
:param logger_name: str | |||
:param log_path: str | |||
:param str logger_name: | |||
:param str log_path: | |||
:param log_format: | |||
:param log_level: | |||
:return: logger | |||
to use a logger: | |||
To use a logger:: | |||
logger.debug("this is a debug message") | |||
logger.info("this is a info message") | |||
logger.warning("this is a warning message") | |||
@@ -13,10 +13,10 @@ class ModelLoader(BaseLoader): | |||
@staticmethod | |||
def load_pytorch(empty_model, model_path): | |||
""" | |||
Load model parameters from .pkl files into the empty PyTorch model. | |||
"""Load model parameters from ".pkl" files into the empty PyTorch model. | |||
:param empty_model: a PyTorch model with initialized parameters. | |||
:param model_path: str, the path to the saved model. | |||
:param str model_path: the path to the saved model. | |||
""" | |||
empty_model.load_state_dict(torch.load(model_path)) | |||
@@ -24,30 +24,30 @@ class ModelLoader(BaseLoader): | |||
def load_pytorch_model(model_path): | |||
"""Load the entire model. | |||
:param str model_path: the path to the saved model. | |||
""" | |||
return torch.load(model_path) | |||
class ModelSaver(object): | |||
"""Save a model | |||
:param str save_path: the path to the saving directory. | |||
Example:: | |||
saver = ModelSaver("./save/model_ckpt_100.pkl") | |||
saver.save_pytorch(model) | |||
""" | |||
def __init__(self, save_path): | |||
""" | |||
:param save_path: str, the path to the saving directory. | |||
""" | |||
self.save_path = save_path | |||
def save_pytorch(self, model, param_only=True): | |||
"""Save a pytorch model into .pkl file. | |||
"""Save a pytorch model into ".pkl" file. | |||
:param model: a PyTorch model | |||
:param param_only: bool, whether only to save the model parameters or the entire model. | |||
:param bool param_only: whether only to save the model parameters or the entire model. | |||
""" | |||
if param_only is True: | |||
@@ -1,8 +1,8 @@ | |||
import torch | |||
import numpy as np | |||
from fastNLP.models.base_model import BaseModel | |||
from fastNLP.modules import decoder, encoder | |||
from fastNLP.modules.decoder.CRF import allowed_transitions | |||
from fastNLP.modules.utils import seq_mask | |||
@@ -93,7 +93,7 @@ class AdvSeqLabel(SeqLabeling): | |||
Advanced Sequence Labeling Model | |||
""" | |||
def __init__(self, args, emb=None): | |||
def __init__(self, args, emb=None, id2words=None): | |||
super(AdvSeqLabel, self).__init__(args) | |||
vocab_size = args["vocab_size"] | |||
@@ -105,7 +105,8 @@ class AdvSeqLabel(SeqLabeling): | |||
self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb) | |||
self.norm1 = torch.nn.LayerNorm(word_emb_dim) | |||
# self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=2, dropout=dropout, bidirectional=True) | |||
self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout, bidirectional=True, batch_first=True) | |||
self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout, | |||
bidirectional=True, batch_first=True) | |||
self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3) | |||
self.norm2 = torch.nn.LayerNorm(hidden_dim * 2 // 3) | |||
# self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3) | |||
@@ -113,7 +114,12 @@ class AdvSeqLabel(SeqLabeling): | |||
self.drop = torch.nn.Dropout(dropout) | |||
self.Linear2 = encoder.Linear(hidden_dim * 2 // 3, num_classes) | |||
self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False) | |||
if id2words is None: | |||
self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False) | |||
else: | |||
self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False, | |||
allowed_transitions=allowed_transitions(id2words, | |||
encoding_type="bmes")) | |||
def forward(self, word_seq, word_seq_origin_len, truth=None): | |||
""" | |||
@@ -178,6 +184,7 @@ class AdvSeqLabel(SeqLabeling): | |||
assert 'loss' in kwargs | |||
return kwargs['loss'] | |||
if __name__ == '__main__': | |||
args = { | |||
'vocab_size': 20, | |||
@@ -208,11 +215,11 @@ if __name__ == '__main__': | |||
res = model(word_seq, word_seq_len, truth) | |||
loss = res['loss'] | |||
pred = res['predict'] | |||
print('loss: {} acc {}'.format(loss.item(), ((pred.data == truth).long().sum().float() / word_seq_len.sum().float()))) | |||
print('loss: {} acc {}'.format(loss.item(), | |||
((pred.data == truth).long().sum().float() / word_seq_len.sum().float()))) | |||
optimizer.zero_grad() | |||
loss.backward() | |||
optimizer.step() | |||
curidx = endidx | |||
if curidx == len(data): | |||
curidx = 0 | |||
@@ -15,30 +15,153 @@ def seq_len_to_byte_mask(seq_lens): | |||
# return value: ByteTensor, batch_size x max_len | |||
batch_size = seq_lens.size(0) | |||
max_len = seq_lens.max() | |||
broadcast_arange = torch.arange(max_len).view(1, -1).repeat(batch_size, 1) | |||
mask = broadcast_arange.lt(seq_lens.float().view(-1, 1)) | |||
broadcast_arange = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device) | |||
mask = broadcast_arange.float().lt(seq_lens.float().view(-1, 1)) | |||
return mask | |||
def allowed_transitions(id2label, encoding_type='bio'): | |||
""" | |||
:param id2label: dict, key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | |||
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()id2label。 | |||
:param encoding_type: str, 支持"bio", "bmes"。 | |||
:return:List[Tuple(int, int)]], 内部的Tuple是(from_tag_id, to_tag_id)。 返回的结果考虑了start和end,比如"BIO"中,B、O可以 | |||
位于序列的开端,而I不行。所以返回的结果中会包含(start_idx, B_idx), (start_idx, O_idx), 但是不包含(start_idx, I_idx). | |||
start_idx=len(id2label), end_idx=len(id2label)+1。 | |||
""" | |||
num_tags = len(id2label) | |||
start_idx = num_tags | |||
end_idx = num_tags + 1 | |||
encoding_type = encoding_type.lower() | |||
allowed_trans = [] | |||
id_label_lst = list(id2label.items()) + [(start_idx, 'start'), (end_idx, 'end')] | |||
def split_tag_label(from_label): | |||
from_label = from_label.lower() | |||
if from_label in ['start', 'end']: | |||
from_tag = from_label | |||
from_label = '' | |||
else: | |||
from_tag = from_label[:1] | |||
from_label = from_label[2:] | |||
return from_tag, from_label | |||
for from_id, from_label in id_label_lst: | |||
if from_label in ['<pad>', '<unk>']: | |||
continue | |||
from_tag, from_label = split_tag_label(from_label) | |||
for to_id, to_label in id_label_lst: | |||
if to_label in ['<pad>', '<unk>']: | |||
continue | |||
to_tag, to_label = split_tag_label(to_label) | |||
if is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||
allowed_trans.append((from_id, to_id)) | |||
return allowed_trans | |||
def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||
""" | |||
:param encoding_type: str, 支持"BIO", "BMES"。 | |||
:param from_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | |||
:param from_label: str, 比如"PER", "LOC"等label | |||
:param to_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | |||
:param to_label: str, 比如"PER", "LOC"等label | |||
:return: bool,能否跃迁 | |||
""" | |||
if to_tag=='start' or from_tag=='end': | |||
return False | |||
encoding_type = encoding_type.lower() | |||
if encoding_type == 'bio': | |||
""" | |||
第一行是to_tag, 第一列是from_tag. y任意条件下可转,-只有在label相同时可转,n不可转 | |||
+-------+---+---+---+-------+-----+ | |||
| | B | I | O | start | end | | |||
+-------+---+---+---+-------+-----+ | |||
| B | y | - | y | n | y | | |||
+-------+---+---+---+-------+-----+ | |||
| I | y | - | y | n | y | | |||
+-------+---+---+---+-------+-----+ | |||
| O | y | n | y | n | y | | |||
+-------+---+---+---+-------+-----+ | |||
| start | y | n | y | n | n | | |||
+-------+---+---+---+-------+-----+ | |||
| end | n | n | n | n | n | | |||
+-------+---+---+---+-------+-----+ | |||
""" | |||
if from_tag == 'start': | |||
return to_tag in ('b', 'o') | |||
elif from_tag in ['b', 'i']: | |||
return any([to_tag in ['end', 'b', 'o'], to_tag=='i' and from_label==to_label]) | |||
elif from_tag == 'o': | |||
return to_tag in ['end', 'b', 'o'] | |||
else: | |||
raise ValueError("Unexpect tag {}. Expect only 'B', 'I', 'O'.".format(from_tag)) | |||
elif encoding_type == 'bmes': | |||
""" | |||
第一行是to_tag, 第一列是from_tag,y任意条件下可转,-只有在label相同时可转,n不可转 | |||
+-------+---+---+---+---+-------+-----+ | |||
| | B | M | E | S | start | end | | |||
+-------+---+---+---+---+-------+-----+ | |||
| B | n | - | - | n | n | n | | |||
+-------+---+---+---+---+-------+-----+ | |||
| M | n | - | - | n | n | n | | |||
+-------+---+---+---+---+-------+-----+ | |||
| E | y | n | n | y | n | y | | |||
+-------+---+---+---+---+-------+-----+ | |||
| S | y | n | n | y | n | y | | |||
+-------+---+---+---+---+-------+-----+ | |||
| start | y | n | n | y | n | n | | |||
+-------+---+---+---+---+-------+-----+ | |||
| end | n | n | n | n | n | n | | |||
+-------+---+---+---+---+-------+-----+ | |||
""" | |||
if from_tag == 'start': | |||
return to_tag in ['b', 's'] | |||
elif from_tag == 'b': | |||
return to_tag in ['m', 'e'] and from_label==to_label | |||
elif from_tag == 'm': | |||
return to_tag in ['m', 'e'] and from_label==to_label | |||
elif from_tag in ['e', 's']: | |||
return to_tag in ['b', 's', 'end'] | |||
else: | |||
raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S'.".format(from_tag)) | |||
else: | |||
raise ValueError("Only support BIO, BMES encoding type, got {}.".format(encoding_type)) | |||
class ConditionalRandomField(nn.Module): | |||
def __init__(self, tag_size, include_start_end_trans=False ,initial_method = None): | |||
def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None, initial_method=None): | |||
""" | |||
:param tag_size: int, num of tags | |||
:param include_start_end_trans: bool, whether to include start/end tag | |||
:param num_tags: int, 标签的数量。 | |||
:param include_start_end_trans: bool, 是否包含起始tag | |||
:param allowed_transitions: List[Tuple[from_tag_id(int), to_tag_id(int)]]. 允许的跃迁,可以通过allowed_transitions()得到。 | |||
如果为None,则所有跃迁均为合法 | |||
:param initial_method: | |||
""" | |||
super(ConditionalRandomField, self).__init__() | |||
self.include_start_end_trans = include_start_end_trans | |||
self.tag_size = tag_size | |||
self.num_tags = num_tags | |||
# the meaning of entry in this matrix is (from_tag_id, to_tag_id) score | |||
self.trans_m = nn.Parameter(torch.randn(tag_size, tag_size)) | |||
self.trans_m = nn.Parameter(torch.randn(num_tags, num_tags)) | |||
if self.include_start_end_trans: | |||
self.start_scores = nn.Parameter(torch.randn(tag_size)) | |||
self.end_scores = nn.Parameter(torch.randn(tag_size)) | |||
self.start_scores = nn.Parameter(torch.randn(num_tags)) | |||
self.end_scores = nn.Parameter(torch.randn(num_tags)) | |||
if allowed_transitions is None: | |||
constrain = torch.zeros(num_tags + 2, num_tags + 2) | |||
else: | |||
constrain = torch.ones(num_tags + 2, num_tags + 2) * -1000 | |||
for from_tag_id, to_tag_id in allowed_transitions: | |||
constrain[from_tag_id, to_tag_id] = 0 | |||
self._constrain = nn.Parameter(constrain, requires_grad=False) | |||
# self.reset_parameter() | |||
initial_parameter(self, initial_method) | |||
def reset_parameter(self): | |||
nn.init.xavier_normal_(self.trans_m) | |||
if self.include_start_end_trans: | |||
@@ -49,7 +172,7 @@ class ConditionalRandomField(nn.Module): | |||
""" | |||
Computes the (batch_size,) denominator term for the log-likelihood, which is the | |||
sum of the likelihoods across all possible state sequences. | |||
:param logits:FloatTensor, max_len x batch_size x tag_size | |||
:param logits:FloatTensor, max_len x batch_size x num_tags | |||
:param mask:ByteTensor, max_len x batch_size | |||
:return:FloatTensor, batch_size | |||
""" | |||
@@ -72,7 +195,7 @@ class ConditionalRandomField(nn.Module): | |||
def _glod_score(self, logits, tags, mask): | |||
""" | |||
Compute the score for the gold path. | |||
:param logits: FloatTensor, max_len x batch_size x tag_size | |||
:param logits: FloatTensor, max_len x batch_size x num_tags | |||
:param tags: LongTensor, max_len x batch_size | |||
:param mask: ByteTensor, max_len x batch_size | |||
:return:FloatTensor, batch_size | |||
@@ -99,7 +222,7 @@ class ConditionalRandomField(nn.Module): | |||
def forward(self, feats, tags, mask): | |||
""" | |||
Calculate the neg log likelihood | |||
:param feats:FloatTensor, batch_size x max_len x tag_size | |||
:param feats:FloatTensor, batch_size x max_len x num_tags | |||
:param tags:LongTensor, batch_size x max_len | |||
:param mask:ByteTensor batch_size x max_len | |||
:return:FloatTensor, batch_size | |||
@@ -112,13 +235,20 @@ class ConditionalRandomField(nn.Module): | |||
return all_path_score - gold_path_score | |||
def viterbi_decode(self, data, mask, get_score=False): | |||
def viterbi_decode(self, data, mask, get_score=False, unpad=False): | |||
""" | |||
Given a feats matrix, return best decode path and best score. | |||
:param data:FloatTensor, batch_size x max_len x tag_size | |||
:param data:FloatTensor, batch_size x max_len x num_tags | |||
:param mask:ByteTensor batch_size x max_len | |||
:param get_score: bool, whether to output the decode score. | |||
:return: scores, paths | |||
:param unpad: bool, 是否将结果unpad, | |||
如果False, 返回的是batch_size x max_len的tensor, | |||
如果True,返回的是List[List[int]], List[int]为每个sequence的label,已经unpadding了,即每个 | |||
List[int]的长度是这个sample的有效长度 | |||
:return: 如果get_score为False,返回结果根据unpadding变动 | |||
如果get_score为True, 返回 (paths, List[float], )。第一个仍然是解码后的路径(根据unpad变化),第二个List[Float] | |||
为每个seqence的解码分数。 | |||
""" | |||
batch_size, seq_len, n_tags = data.size() | |||
data = data.transpose(0, 1).data # L, B, H | |||
@@ -127,19 +257,23 @@ class ConditionalRandomField(nn.Module): | |||
# dp | |||
vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | |||
vscore = data[0] | |||
transitions = self._constrain.data.clone() | |||
transitions[:n_tags, :n_tags] += self.trans_m.data | |||
if self.include_start_end_trans: | |||
vscore += self.start_scores.view(1, -1) | |||
transitions[n_tags, :n_tags] += self.start_scores.data | |||
transitions[:n_tags, n_tags+1] += self.end_scores.data | |||
vscore += transitions[n_tags, :n_tags] | |||
trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data | |||
for i in range(1, seq_len): | |||
prev_score = vscore.view(batch_size, n_tags, 1) | |||
cur_score = data[i].view(batch_size, 1, n_tags) | |||
trans_score = self.trans_m.view(1, n_tags, n_tags).data | |||
score = prev_score + trans_score + cur_score | |||
best_score, best_dst = score.max(1) | |||
vpath[i] = best_dst | |||
vscore = best_score * mask[i].view(batch_size, 1) + vscore * (1 - mask[i]).view(batch_size, 1) | |||
if self.include_start_end_trans: | |||
vscore += self.end_scores.view(1, -1) | |||
vscore += transitions[:n_tags, n_tags+1].view(1, -1) | |||
# backtrace | |||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=data.device) | |||
@@ -154,7 +288,13 @@ class ConditionalRandomField(nn.Module): | |||
for i in range(seq_len - 1): | |||
last_tags = vpath[idxes[i], batch_idx, last_tags] | |||
ans[idxes[i+1], batch_idx] = last_tags | |||
ans = ans.transpose(0, 1) | |||
if unpad: | |||
paths = [] | |||
for idx, seq_len in enumerate(lens): | |||
paths.append(ans[idx, :seq_len+1].tolist()) | |||
else: | |||
paths = ans | |||
if get_score: | |||
return ans_score, ans.transpose(0, 1) | |||
return ans.transpose(0, 1) | |||
return paths, ans_score.tolist() | |||
return paths |
@@ -6,6 +6,13 @@ from fastNLP.io.dataset_loader import DataSetLoader | |||
def cut_long_sentence(sent, max_sample_length=200): | |||
""" | |||
将长于max_sample_length的sentence截成多段,只会在有空格的地方发生截断。所以截取的句子可能长于或者短于max_sample_length | |||
:param sent: str. | |||
:param max_sample_length: int. | |||
:return: list of str. | |||
""" | |||
sent_no_space = sent.replace(' ', '') | |||
cutted_sentence = [] | |||
if len(sent_no_space) > max_sample_length: | |||
@@ -127,12 +134,26 @@ class POSCWSReader(DataSetLoader): | |||
return dataset | |||
class ConlluCWSReader(object): | |||
# 返回的Dataset包含words(list of list, 里层的list是character), tag两个field(list of str, str是标有BMES的tag)。 | |||
class ConllCWSReader(object): | |||
def __init__(self): | |||
pass | |||
def load(self, path, cut_long_sent=False): | |||
""" | |||
返回的DataSet只包含raw_sentence这个field,内容为str。 | |||
假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 | |||
1 编者按 编者按 NN O 11 nmod:topic | |||
2 : : PU O 11 punct | |||
3 7月 7月 NT DATE 4 compound:nn | |||
4 12日 12日 NT DATE 11 nmod:tmod | |||
5 , , PU O 11 punct | |||
1 这 这 DT O 3 det | |||
2 款 款 M O 1 mark:clf | |||
3 飞行 飞行 NN O 8 nsubj | |||
4 从 从 P O 5 case | |||
5 外型 外型 NN O 8 nmod:prep | |||
""" | |||
datalist = [] | |||
with open(path, 'r', encoding='utf-8') as f: | |||
sample = [] | |||
@@ -150,10 +171,10 @@ class ConlluCWSReader(object): | |||
ds = DataSet() | |||
for sample in datalist: | |||
# print(sample) | |||
res = self.get_one(sample) | |||
res = self.get_char_lst(sample) | |||
if res is None: | |||
continue | |||
line = ' '.join(res) | |||
line = ' '.join(res) | |||
if cut_long_sent: | |||
sents = cut_long_sentence(line) | |||
else: | |||
@@ -163,7 +184,7 @@ class ConlluCWSReader(object): | |||
return ds | |||
def get_one(self, sample): | |||
def get_char_lst(self, sample): | |||
if len(sample)==0: | |||
return None | |||
text = [] | |||
@@ -9,7 +9,7 @@ from reproduction.chinese_word_segment.utils import seq_lens_to_mask | |||
class CWSBiLSTMEncoder(BaseModel): | |||
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, | |||
hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1): | |||
hidden_size=200, bidirectional=True, embed_drop_p=0.2, num_layers=1): | |||
super().__init__() | |||
self.input_size = 0 | |||
@@ -68,6 +68,7 @@ class CWSBiLSTMEncoder(BaseModel): | |||
if not bigrams is None: | |||
bigram_tensor = self.bigram_embedding(bigrams).view(batch_size, max_len, -1) | |||
x_tensor = torch.cat([x_tensor, bigram_tensor], dim=2) | |||
x_tensor = self.embedding_drop(x_tensor) | |||
sorted_lens, sorted_indices = torch.sort(seq_lens, descending=True) | |||
packed_x = nn.utils.rnn.pack_padded_sequence(x_tensor[sorted_indices], sorted_lens, batch_first=True) | |||
@@ -120,10 +121,24 @@ class CWSBiLSTMSegApp(BaseModel): | |||
from fastNLP.modules.decoder.CRF import ConditionalRandomField | |||
from fastNLP.modules.decoder.CRF import allowed_transitions | |||
class CWSBiLSTMCRF(BaseModel): | |||
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, | |||
hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1, tag_size=4): | |||
hidden_size=200, bidirectional=True, embed_drop_p=0.2, num_layers=1, tag_size=4): | |||
""" | |||
默认使用BMES的标注方式 | |||
:param vocab_num: | |||
:param embed_dim: | |||
:param bigram_vocab_num: | |||
:param bigram_embed_dim: | |||
:param num_bigram_per_char: | |||
:param hidden_size: | |||
:param bidirectional: | |||
:param embed_drop_p: | |||
:param num_layers: | |||
:param tag_size: | |||
""" | |||
super(CWSBiLSTMCRF, self).__init__() | |||
self.tag_size = tag_size | |||
@@ -133,10 +148,12 @@ class CWSBiLSTMCRF(BaseModel): | |||
size_layer = [hidden_size, 200, tag_size] | |||
self.decoder_model = MLP(size_layer) | |||
self.crf = ConditionalRandomField(tag_size=tag_size, include_start_end_trans=False) | |||
allowed_trans = allowed_transitions({0:'b', 1:'m', 2:'e', 3:'s'}, encoding_type='bmes') | |||
self.crf = ConditionalRandomField(num_tags=tag_size, include_start_end_trans=False, | |||
allowed_transitions=allowed_trans) | |||
def forward(self, chars, tags, seq_lens, bigrams=None): | |||
def forward(self, chars, target, seq_lens, bigrams=None): | |||
device = self.parameters().__next__().device | |||
chars = chars.to(device).long() | |||
if not bigrams is None: | |||
@@ -147,7 +164,7 @@ class CWSBiLSTMCRF(BaseModel): | |||
masks = seq_lens_to_mask(seq_lens) | |||
feats = self.encoder_model(chars, bigrams, seq_lens) | |||
feats = self.decoder_model(feats) | |||
losses = self.crf(feats, tags, masks) | |||
losses = self.crf(feats, target, masks) | |||
pred_dict = {} | |||
pred_dict['seq_lens'] = seq_lens | |||
@@ -168,5 +185,5 @@ class CWSBiLSTMCRF(BaseModel): | |||
feats = self.decoder_model(feats) | |||
probs = self.crf.viterbi_decode(feats, masks, get_score=False) | |||
return {'pred_tags': probs} | |||
return {'pred': probs} | |||
@@ -1,17 +1,18 @@ | |||
import re | |||
from fastNLP.core.field import SeqLabelField | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.api.processor import Processor | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from reproduction.chinese_word_segment.process.span_converter import SpanConverter | |||
_SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>' | |||
class SpeicalSpanProcessor(Processor): | |||
# 这个类会将句子中的special span转换为对应的内容。 | |||
""" | |||
将DataSet中field_name使用span_converter替换掉。 | |||
""" | |||
def __init__(self, field_name, new_added_field_name=None): | |||
super(SpeicalSpanProcessor, self).__init__(field_name, new_added_field_name) | |||
@@ -20,11 +21,12 @@ class SpeicalSpanProcessor(Processor): | |||
def process(self, dataset): | |||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||
for ins in dataset: | |||
def inner_proc(ins): | |||
sentence = ins[self.field_name] | |||
for span_converter in self.span_converters: | |||
sentence = span_converter.find_certain_span_and_replace(sentence) | |||
ins[self.new_added_field_name] = sentence | |||
return sentence | |||
dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name) | |||
return dataset | |||
@@ -34,17 +36,22 @@ class SpeicalSpanProcessor(Processor): | |||
self.span_converters.append(converter) | |||
class CWSCharSegProcessor(Processor): | |||
""" | |||
将DataSet中field_name这个field分成一个个的汉字,即原来可能为"复旦大学 fudan", 分成['复', '旦', '大', '学', | |||
' ', 'f', 'u', ...] | |||
""" | |||
def __init__(self, field_name, new_added_field_name): | |||
super(CWSCharSegProcessor, self).__init__(field_name, new_added_field_name) | |||
def process(self, dataset): | |||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||
for ins in dataset: | |||
def inner_proc(ins): | |||
sentence = ins[self.field_name] | |||
chars = self._split_sent_into_chars(sentence) | |||
ins[self.new_added_field_name] = chars | |||
return chars | |||
dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name) | |||
return dataset | |||
@@ -73,6 +80,10 @@ class CWSCharSegProcessor(Processor): | |||
class CWSTagProcessor(Processor): | |||
""" | |||
为分词生成tag。该class为Base class。 | |||
""" | |||
def __init__(self, field_name, new_added_field_name=None): | |||
super(CWSTagProcessor, self).__init__(field_name, new_added_field_name) | |||
@@ -107,18 +118,22 @@ class CWSTagProcessor(Processor): | |||
def process(self, dataset): | |||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||
for ins in dataset: | |||
def inner_proc(ins): | |||
sentence = ins[self.field_name] | |||
tag_list = self._generate_tag(sentence) | |||
ins[self.new_added_field_name] = tag_list | |||
dataset.set_target(**{self.new_added_field_name:True}) | |||
dataset._set_need_tensor(**{self.new_added_field_name:True}) | |||
return tag_list | |||
dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name) | |||
dataset.set_target(self.new_added_field_name) | |||
return dataset | |||
def _tags_from_word_len(self, word_len): | |||
raise NotImplementedError | |||
class CWSBMESTagProcessor(CWSTagProcessor): | |||
""" | |||
通过DataSet中的field_name这个field生成相应的BMES的tag。 | |||
""" | |||
def __init__(self, field_name, new_added_field_name=None): | |||
super(CWSBMESTagProcessor, self).__init__(field_name, new_added_field_name) | |||
@@ -137,6 +152,10 @@ class CWSBMESTagProcessor(CWSTagProcessor): | |||
return tag_list | |||
class CWSSegAppTagProcessor(CWSTagProcessor): | |||
""" | |||
通过DataSet中的field_name这个field生成相应的SegApp的tag。 | |||
""" | |||
def __init__(self, field_name, new_added_field_name=None): | |||
super(CWSSegAppTagProcessor, self).__init__(field_name, new_added_field_name) | |||
@@ -151,6 +170,10 @@ class CWSSegAppTagProcessor(CWSTagProcessor): | |||
class BigramProcessor(Processor): | |||
""" | |||
这是生成bigram的基类。 | |||
""" | |||
def __init__(self, field_name, new_added_fielf_name=None): | |||
super(BigramProcessor, self).__init__(field_name, new_added_fielf_name) | |||
@@ -158,22 +181,31 @@ class BigramProcessor(Processor): | |||
def process(self, dataset): | |||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||
for ins in dataset: | |||
def inner_proc(ins): | |||
characters = ins[self.field_name] | |||
bigrams = self._generate_bigram(characters) | |||
ins[self.new_added_field_name] = bigrams | |||
return bigrams | |||
dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name) | |||
return dataset | |||
def _generate_bigram(self, characters): | |||
pass | |||
class Pre2Post2BigramProcessor(BigramProcessor): | |||
def __init__(self, field_name, new_added_fielf_name=None): | |||
""" | |||
该bigram processor生成bigram的方式如下 | |||
原汉字list为l = ['a', 'b', 'c'],会被padding为L=['SOS', 'SOS', 'a', 'b', 'c', 'EOS', 'EOS'],生成bigram list为 | |||
[L[idx-2], L[idx-1], L[idx+1], L[idx+2], L[idx-2]L[idx], L[idx-1]L[idx], L[idx]L[idx+1], L[idx]L[idx+2], ....] | |||
即每个汉字,会有八个bigram, 对于上例中'a'的bigram为 | |||
['SOS', 'SOS', 'b', 'c', 'SOSa', 'SOSa', 'ab', 'ac'] | |||
返回的bigram是一个list,但其实每8个元素是一个汉字的bigram信息。 | |||
""" | |||
def __init__(self, field_name, new_added_field_name=None): | |||
super(BigramProcessor, self).__init__(field_name, new_added_fielf_name) | |||
super(BigramProcessor, self).__init__(field_name, new_added_field_name) | |||
def _generate_bigram(self, characters): | |||
bigrams = [] | |||
@@ -197,20 +229,107 @@ class Pre2Post2BigramProcessor(BigramProcessor): | |||
# 这里需要建立vocabulary了,但是遇到了以下的问题 | |||
# (1) 如果使用Processor的方式的话,但是在这种情况返回的不是dataset。所以建立vocabulary的工作用另外的方式实现,不借用 | |||
# Processor了 | |||
# TODO 如何将建立vocab和index这两步统一了? | |||
class VocabIndexerProcessor(Processor): | |||
""" | |||
根据DataSet创建Vocabulary,并将其用数字index。新生成的index的field会被放在new_added_filed_name, 如果没有提供 | |||
new_added_field_name, 则覆盖原有的field_name. | |||
""" | |||
def __init__(self, field_name, new_added_filed_name=None, min_freq=1, max_size=None, | |||
verbose=1, is_input=True): | |||
""" | |||
:param field_name: 从哪个field_name创建词表,以及对哪个field_name进行index操作 | |||
:param new_added_filed_name: index时,生成的index field的名称,如果不传入,则覆盖field_name. | |||
:param min_freq: 创建的Vocabulary允许的单词最少出现次数. | |||
:param max_size: 创建的Vocabulary允许的最大的单词数量 | |||
:param verbose: 0, 不输出任何信息;1,输出信息 | |||
:param bool is_input: | |||
""" | |||
super(VocabIndexerProcessor, self).__init__(field_name, new_added_filed_name) | |||
self.min_freq = min_freq | |||
self.max_size = max_size | |||
self.verbose =verbose | |||
self.is_input = is_input | |||
def construct_vocab(self, *datasets): | |||
""" | |||
使用传入的DataSet创建vocabulary | |||
:param datasets: DataSet类型的数据,用于构建vocabulary | |||
:return: | |||
""" | |||
self.vocab = Vocabulary(min_freq=self.min_freq, max_size=self.max_size) | |||
for dataset in datasets: | |||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||
dataset.apply(lambda ins: self.vocab.update(ins[self.field_name])) | |||
self.vocab.build_vocab() | |||
if self.verbose: | |||
print("Vocabulary Constructed, has {} items.".format(len(self.vocab))) | |||
def process(self, *datasets, only_index_dataset=None): | |||
""" | |||
若还未建立Vocabulary,则使用dataset中的DataSet建立vocabulary;若已经有了vocabulary则使用已有的vocabulary。得到vocabulary | |||
后,则会index datasets与only_index_dataset。 | |||
:param datasets: DataSet类型的数据 | |||
:param only_index_dataset: DataSet, or list of DataSet. 该参数中的内容只会被用于index,不会被用于生成vocabulary。 | |||
:return: | |||
""" | |||
if len(datasets)==0 and not hasattr(self,'vocab'): | |||
raise RuntimeError("You have to construct vocabulary first. Or you have to pass datasets to construct it.") | |||
if not hasattr(self, 'vocab'): | |||
self.construct_vocab(*datasets) | |||
else: | |||
if self.verbose: | |||
print("Using constructed vocabulary with {} items.".format(len(self.vocab))) | |||
to_index_datasets = [] | |||
if len(datasets)!=0: | |||
for dataset in datasets: | |||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||
to_index_datasets.append(dataset) | |||
if not (only_index_dataset is None): | |||
if isinstance(only_index_dataset, list): | |||
for dataset in only_index_dataset: | |||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||
to_index_datasets.append(dataset) | |||
elif isinstance(only_index_dataset, DataSet): | |||
to_index_datasets.append(only_index_dataset) | |||
else: | |||
raise TypeError('Only DataSet or list of DataSet is allowed, not {}.'.format(type(only_index_dataset))) | |||
for dataset in to_index_datasets: | |||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||
dataset.apply(lambda ins: [self.vocab.to_index(token) for token in ins[self.field_name]], | |||
new_field_name=self.new_added_field_name, is_input=self.is_input) | |||
# 只返回一个,infer时为了跟其他processor保持一致 | |||
if len(to_index_datasets) == 1: | |||
return to_index_datasets[0] | |||
def set_vocab(self, vocab): | |||
assert isinstance(vocab, Vocabulary), "Only fastNLP.core.Vocabulary is allowed, not {}.".format(type(vocab)) | |||
self.vocab = vocab | |||
def delete_vocab(self): | |||
del self.vocab | |||
def get_vocab_size(self): | |||
return len(self.vocab) | |||
class VocabProcessor(Processor): | |||
def __init__(self, field_name, min_count=1, max_vocab_size=None): | |||
def __init__(self, field_name, min_freq=1, max_size=None): | |||
super(VocabProcessor, self).__init__(field_name, None) | |||
self.vocab = Vocabulary(min_freq=min_count, max_size=max_vocab_size) | |||
self.vocab = Vocabulary(min_freq=min_freq, max_size=max_size) | |||
def process(self, *datasets): | |||
for dataset in datasets: | |||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||
for ins in dataset: | |||
tokens = ins[self.field_name] | |||
self.vocab.update(tokens) | |||
dataset.apply(lambda ins: self.vocab.update(ins[self.field_name])) | |||
def get_vocab(self): | |||
self.vocab.build_vocab() | |||
@@ -220,19 +339,6 @@ class VocabProcessor(Processor): | |||
return len(self.vocab) | |||
class SeqLenProcessor(Processor): | |||
def __init__(self, field_name, new_added_field_name='seq_lens'): | |||
super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) | |||
def process(self, dataset): | |||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||
for ins in dataset: | |||
length = len(ins[self.field_name]) | |||
ins[self.new_added_field_name] = length | |||
dataset._set_need_tensor(**{self.new_added_field_name:True}) | |||
return dataset | |||
class SegApp2OutputProcessor(Processor): | |||
def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output'): | |||
super(SegApp2OutputProcessor, self).__init__(None, None) | |||
@@ -258,7 +364,32 @@ class SegApp2OutputProcessor(Processor): | |||
class BMES2OutputProcessor(Processor): | |||
def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output'): | |||
""" | |||
按照BMES标注方式推测生成的tag。由于可能存在非法tag,比如"BS",所以需要用以下的表格做转换,cur_B意思是当前tag是B, | |||
next_B意思是后一个tag是B。则cur_B=S,即将当前被predict是B的tag标为S;next_M=B, 即将后一个被predict是M的tag标为B | |||
| | next_B | next_M | next_E | next_S | end | | |||
|:-----:|:-------:|:--------:|:--------:|:-------:|:-------:| | |||
| start | 合法 | next_M=B | next_E=S | 合法 | - | | |||
| cur_B | cur_B=S | 合法 | 合法 | cur_B=S | cur_B=S | | |||
| cur_M | cur_M=E | 合法 | 合法 | cur_M=E | cur_M=E | | |||
| cur_E | 合法 | next_M=B | next_E=S | 合法 | 合法 | | |||
| cur_S | 合法 | next_M=B | next_E=S | 合法 | 合法 | | |||
举例: | |||
prediction为BSEMS,会被认为是SSSSS. | |||
""" | |||
def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output', | |||
b_idx = 0, m_idx = 1, e_idx = 2, s_idx = 3): | |||
""" | |||
:param chars_field_name: character所对应的field | |||
:param tag_field_name: 预测对应的field | |||
:param new_added_field_name: 转换后的内容所在field | |||
:param b_idx: int, Begin标签所对应的tag idx. | |||
:param m_idx: int, Middle标签所对应的tag idx. | |||
:param e_idx: int, End标签所对应的tag idx. | |||
:param s_idx: int, Single标签所对应的tag idx | |||
""" | |||
super(BMES2OutputProcessor, self).__init__(None, None) | |||
self.chars_field_name = chars_field_name | |||
@@ -266,19 +397,55 @@ class BMES2OutputProcessor(Processor): | |||
self.new_added_field_name = new_added_field_name | |||
self.b_idx = b_idx | |||
self.m_idx = m_idx | |||
self.e_idx = e_idx | |||
self.s_idx = s_idx | |||
# 还原init处介绍的矩阵 | |||
self._valida_matrix = { | |||
-1: [(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1)], # magic start idx | |||
self.b_idx:[(0, self.s_idx), (-1, -1), (-1, -1), (0, self.s_idx), (0, self.s_idx)], | |||
self.m_idx:[(0, self.e_idx), (-1, -1), (-1, -1), (0, self.e_idx), (0, self.e_idx)], | |||
self.e_idx:[(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)], | |||
self.s_idx:[(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)], | |||
} | |||
def _validate_tags(self, tags): | |||
""" | |||
给定一个tag的List,返回合法tag | |||
:param tags: Tensor, shape: (seq_len, ) | |||
:return: 返回修改为合法tag的list | |||
""" | |||
assert len(tags)!=0 | |||
padded_tags = [-1, *tags, -1] | |||
for idx in range(len(padded_tags)-1): | |||
cur_tag = padded_tags[idx] | |||
if cur_tag not in self._valida_matrix: | |||
cur_tag = self.s_idx | |||
if padded_tags[idx+1] not in self._valida_matrix: | |||
padded_tags[idx+1] = self.s_idx | |||
next_tag = padded_tags[idx+1] | |||
shift_tag = self._valida_matrix[cur_tag][next_tag] | |||
if shift_tag[0]!=-1: | |||
padded_tags[idx+shift_tag[0]] = shift_tag[1] | |||
return padded_tags[1:-1] | |||
def process(self, dataset): | |||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||
for ins in dataset: | |||
def inner_proc(ins): | |||
pred_tags = ins[self.tag_field_name] | |||
pred_tags = self._validate_tags(pred_tags) | |||
chars = ins[self.chars_field_name] | |||
words = [] | |||
start_idx = 0 | |||
for idx, tag in enumerate(pred_tags): | |||
if tag==3: | |||
# 当前没有考虑将原文替换回去 | |||
if tag==self.s_idx: | |||
words.extend(chars[start_idx:idx+1]) | |||
start_idx = idx + 1 | |||
elif tag==2: | |||
elif tag==self.e_idx: | |||
words.append(''.join(chars[start_idx:idx+1])) | |||
start_idx = idx + 1 | |||
ins[self.new_added_field_name] = ' '.join(words) | |||
return ' '.join(words) | |||
dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name) |
@@ -0,0 +1,229 @@ | |||
from fastNLP.api.pipeline import Pipeline | |||
from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor | |||
from fastNLP.api.processor import SeqLenProcessor | |||
from reproduction.chinese_word_segment.process.cws_processor import CWSCharSegProcessor | |||
from reproduction.chinese_word_segment.process.cws_processor import CWSBMESTagProcessor | |||
from reproduction.chinese_word_segment.process.cws_processor import Pre2Post2BigramProcessor | |||
from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor | |||
from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader | |||
from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMCRF | |||
from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1 | |||
ds_name = 'msr' | |||
tr_filename = '/home/hyan/ctb3/train.conllx' | |||
dev_filename = '/home/hyan/ctb3/dev.conllx' | |||
reader = ConllCWSReader() | |||
tr_dataset = reader.load(tr_filename, cut_long_sent=True) | |||
dev_dataset = reader.load(dev_filename) | |||
print("Train {}. Dev: {}".format(len(tr_dataset), len(dev_dataset))) | |||
# 1. 准备processor | |||
fs2hs_proc = FullSpaceToHalfSpaceProcessor('raw_sentence') | |||
char_proc = CWSCharSegProcessor('raw_sentence', 'chars_lst') | |||
tag_proc = CWSBMESTagProcessor('raw_sentence', 'target') | |||
bigram_proc = Pre2Post2BigramProcessor('chars_lst', 'bigrams_lst') | |||
char_vocab_proc = VocabIndexerProcessor('chars_lst', new_added_filed_name='chars') | |||
bigram_vocab_proc = VocabIndexerProcessor('bigrams_lst', new_added_filed_name='bigrams', min_freq=4) | |||
seq_len_proc = SeqLenProcessor('chars') | |||
# 2. 使用processor | |||
fs2hs_proc(tr_dataset) | |||
char_proc(tr_dataset) | |||
tag_proc(tr_dataset) | |||
bigram_proc(tr_dataset) | |||
char_vocab_proc(tr_dataset) | |||
bigram_vocab_proc(tr_dataset) | |||
seq_len_proc(tr_dataset) | |||
# 2.1 处理dev_dataset | |||
fs2hs_proc(dev_dataset) | |||
char_proc(dev_dataset) | |||
tag_proc(dev_dataset) | |||
bigram_proc(dev_dataset) | |||
char_vocab_proc(dev_dataset) | |||
bigram_vocab_proc(dev_dataset) | |||
seq_len_proc(dev_dataset) | |||
dev_dataset.set_input('chars', 'bigrams', 'target') | |||
tr_dataset.set_input('chars', 'bigrams', 'target') | |||
dev_dataset.set_target('seq_lens') | |||
tr_dataset.set_target('seq_lens') | |||
print("Finish preparing data.") | |||
# 3. 得到数据集可以用于训练了 | |||
# TODO pretrain的embedding是怎么解决的? | |||
import torch | |||
from torch import optim | |||
tag_size = tag_proc.tag_size | |||
cws_model = CWSBiLSTMCRF(char_vocab_proc.get_vocab_size(), embed_dim=100, | |||
bigram_vocab_num=bigram_vocab_proc.get_vocab_size(), | |||
bigram_embed_dim=100, num_bigram_per_char=8, | |||
hidden_size=200, bidirectional=True, embed_drop_p=0.2, | |||
num_layers=1, tag_size=tag_size) | |||
cws_model.cuda() | |||
num_epochs = 5 | |||
optimizer = optim.Adagrad(cws_model.parameters(), lr=0.02) | |||
from fastNLP.core.trainer import Trainer | |||
from fastNLP.core.sampler import BucketSampler | |||
from fastNLP.core.metrics import BMESF1PreRecMetric | |||
metric = BMESF1PreRecMetric(target='tags') | |||
trainer = Trainer(train_data=tr_dataset, model=cws_model, loss=None, metrics=metric, n_epochs=3, | |||
batch_size=32, print_every=50, validate_every=-1, dev_data=dev_dataset, save_path=None, | |||
optimizer=optimizer, check_code_level=0, metric_key='f', sampler=BucketSampler(), use_tqdm=True) | |||
trainer.train() | |||
exit(0) | |||
# | |||
# print_every = 50 | |||
# batch_size = 32 | |||
# tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False) | |||
# dev_batcher = Batch(dev_dataset, batch_size, SequentialSampler(), use_cuda=False) | |||
# num_batch_per_epoch = len(tr_dataset) // batch_size | |||
# best_f1 = 0 | |||
# best_epoch = 0 | |||
# for num_epoch in range(num_epochs): | |||
# print('X' * 10 + ' Epoch: {}/{} '.format(num_epoch + 1, num_epochs) + 'X' * 10) | |||
# sys.stdout.flush() | |||
# avg_loss = 0 | |||
# with tqdm(total=num_batch_per_epoch, leave=True) as pbar: | |||
# pbar.set_description_str('Epoch:%d' % (num_epoch + 1)) | |||
# cws_model.train() | |||
# for batch_idx, (batch_x, batch_y) in enumerate(tr_batcher, 1): | |||
# optimizer.zero_grad() | |||
# | |||
# tags = batch_y['tags'].long() | |||
# pred_dict = cws_model(**batch_x, tags=tags) # B x L x tag_size | |||
# | |||
# seq_lens = pred_dict['seq_lens'] | |||
# masks = seq_lens_to_mask(seq_lens).float() | |||
# tags = tags.to(seq_lens.device) | |||
# | |||
# loss = pred_dict['loss'] | |||
# | |||
# # loss = torch.sum(loss_fn(pred_dict['pred_probs'].view(-1, tag_size), | |||
# # tags.view(-1)) * masks.view(-1)) / torch.sum(masks) | |||
# # loss = torch.mean(F.cross_entropy(probs.view(-1, 2), tags.view(-1)) * masks.float()) | |||
# | |||
# avg_loss += loss.item() | |||
# | |||
# loss.backward() | |||
# for group in optimizer.param_groups: | |||
# for param in group['params']: | |||
# param.grad.clamp_(-5, 5) | |||
# | |||
# optimizer.step() | |||
# | |||
# if batch_idx % print_every == 0: | |||
# pbar.set_postfix_str('batch=%d, avg_loss=%.5f' % (batch_idx, avg_loss / print_every)) | |||
# avg_loss = 0 | |||
# pbar.update(print_every) | |||
# tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False) | |||
# # 验证集 | |||
# pre, rec, f1 = calculate_pre_rec_f1(cws_model, dev_batcher, type='bmes') | |||
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1*100, | |||
# pre*100, | |||
# rec*100)) | |||
# if best_f1<f1: | |||
# best_f1 = f1 | |||
# # 缓存最佳的parameter,可能之后会用于保存 | |||
# best_state_dict = { | |||
# key:value.clone() for key, value in | |||
# cws_model.state_dict().items() | |||
# } | |||
# best_epoch = num_epoch | |||
# | |||
# cws_model.load_state_dict(best_state_dict) | |||
# 4. 组装需要存下的内容 | |||
pp = Pipeline() | |||
pp.add_processor(fs2hs_proc) | |||
# pp.add_processor(sp_proc) | |||
pp.add_processor(char_proc) | |||
pp.add_processor(tag_proc) | |||
pp.add_processor(bigram_proc) | |||
pp.add_processor(char_vocab_proc) | |||
pp.add_processor(bigram_vocab_proc) | |||
pp.add_processor(seq_len_proc) | |||
# te_filename = '/hdd/fudanNLP/CWS/CWS_semiCRF/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name) | |||
te_filename = '/home/hyan/ctb3/test.conllx' | |||
te_dataset = reader.load(te_filename) | |||
pp(te_dataset) | |||
from fastNLP.core.tester import Tester | |||
tester = Tester(data=te_dataset, model=cws_model, metrics=metric, batch_size=64, use_cuda=False, | |||
verbose=1) | |||
# | |||
# batch_size = 64 | |||
# te_batcher = Batch(te_dataset, batch_size, SequentialSampler(), use_cuda=False) | |||
# pre, rec, f1 = calculate_pre_rec_f1(cws_model, te_batcher, type='bmes') | |||
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100, | |||
# pre * 100, | |||
# rec * 100)) | |||
# TODO 这里貌似需要区分test pipeline与infer pipeline | |||
test_context_dict = {'pipeline': pp, | |||
'model': cws_model} | |||
torch.save(test_context_dict, 'models/test_context_crf.pkl') | |||
# 5. dev的pp | |||
# 4. 组装需要存下的内容 | |||
from fastNLP.api.processor import ModelProcessor | |||
from reproduction.chinese_word_segment.process.cws_processor import BMES2OutputProcessor | |||
model_proc = ModelProcessor(cws_model) | |||
output_proc = BMES2OutputProcessor() | |||
pp = Pipeline() | |||
pp.add_processor(fs2hs_proc) | |||
# pp.add_processor(sp_proc) | |||
pp.add_processor(char_proc) | |||
pp.add_processor(bigram_proc) | |||
pp.add_processor(char_vocab_proc) | |||
pp.add_processor(bigram_vocab_proc) | |||
pp.add_processor(seq_len_proc) | |||
pp.add_processor(model_proc) | |||
pp.add_processor(output_proc) | |||
# TODO 这里貌似需要区分test pipeline与infer pipeline | |||
infer_context_dict = {'pipeline': pp} | |||
# torch.save(infer_context_dict, 'models/cws_crf.pkl') | |||
# TODO 还需要考虑如何替换回原文的问题? | |||
# 1. 不需要将特殊tag替换 | |||
# 2. 需要将特殊tag替换回去 |
@@ -4,6 +4,7 @@ from collections import Counter | |||
from fastNLP.api.processor import Processor | |||
from fastNLP.core.dataset import DataSet | |||
class CombineWordAndPosProcessor(Processor): | |||
def __init__(self, word_field_name, pos_field_name): | |||
super(CombineWordAndPosProcessor, self).__init__(None, None) | |||
@@ -60,6 +61,7 @@ class CombineWordAndPosProcessor(Processor): | |||
return dataset | |||
class PosOutputStrProcessor(Processor): | |||
def __init__(self, word_field_name, pos_field_name): | |||
super(PosOutputStrProcessor, self).__init__(None, None) |
@@ -24,8 +24,8 @@ def cut_long_sentence(sent, max_sample_length=200): | |||
return cutted_sentence | |||
class ConlluPOSReader(object): | |||
# 返回的Dataset包含words(list of list, 里层的list是character), tag两个field(list of str, str是标有BMES的tag)。 | |||
class ConllPOSReader(object): | |||
# 返回的Dataset包含words(list of list, 里层的list是character), tag两个field(list of str, str是标有BIO的tag)。 | |||
def __init__(self): | |||
pass | |||
@@ -70,6 +70,70 @@ class ConlluPOSReader(object): | |||
return ds | |||
class ZhConllPOSReader(object): | |||
# 中文colln格式reader | |||
def __init__(self): | |||
pass | |||
def load(self, path): | |||
""" | |||
返回的DataSet, 包含以下的field | |||
words:list of str, | |||
tag: list of str, 被加入了BMES tag, 比如原来的序列为['VP', 'NN', 'NN', ..],会被认为是["S-VP", "B-NN", "M-NN",..] | |||
假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 | |||
1 编者按 编者按 NN O 11 nmod:topic | |||
2 : : PU O 11 punct | |||
3 7月 7月 NT DATE 4 compound:nn | |||
4 12日 12日 NT DATE 11 nmod:tmod | |||
5 , , PU O 11 punct | |||
1 这 这 DT O 3 det | |||
2 款 款 M O 1 mark:clf | |||
3 飞行 飞行 NN O 8 nsubj | |||
4 从 从 P O 5 case | |||
5 外型 外型 NN O 8 nmod:prep | |||
""" | |||
datalist = [] | |||
with open(path, 'r', encoding='utf-8') as f: | |||
sample = [] | |||
for line in f: | |||
if line.startswith('\n'): | |||
datalist.append(sample) | |||
sample = [] | |||
elif line.startswith('#'): | |||
continue | |||
else: | |||
sample.append(line.split('\t')) | |||
if len(sample) > 0: | |||
datalist.append(sample) | |||
ds = DataSet() | |||
for sample in datalist: | |||
# print(sample) | |||
res = self.get_one(sample) | |||
if res is None: | |||
continue | |||
char_seq = [] | |||
pos_seq = [] | |||
for word, tag in zip(res[0], res[1]): | |||
char_seq.extend(list(word)) | |||
if len(word)==1: | |||
pos_seq.append('S-{}'.format(tag)) | |||
elif len(word)>1: | |||
pos_seq.append('B-{}'.format(tag)) | |||
for _ in range(len(word)-2): | |||
pos_seq.append('M-{}'.format(tag)) | |||
pos_seq.append('E-{}'.format(tag)) | |||
else: | |||
raise ValueError("Zero length of word detected.") | |||
ds.append(Instance(words=char_seq, | |||
tag=pos_seq)) | |||
return ds | |||
def get_one(self, sample): | |||
if len(sample)==0: | |||
return None | |||
@@ -84,6 +148,6 @@ class ConlluPOSReader(object): | |||
return text, pos_tags | |||
if __name__ == '__main__': | |||
reader = ConlluPOSReader() | |||
reader = ZhConllPOSReader() | |||
d = reader.load('/home/hyan/train.conllx') | |||
print('reader') | |||
print(d) |
@@ -0,0 +1,82 @@ | |||
import os | |||
import sys | |||
import torch | |||
# in order to run fastNLP without installation | |||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||
from fastNLP.api.pipeline import Pipeline | |||
from fastNLP.api.processor import SeqLenProcessor | |||
from fastNLP.core.metrics import SpanFPreRecMetric | |||
from fastNLP.core.trainer import Trainer | |||
from fastNLP.io.config_io import ConfigLoader, ConfigSection | |||
from fastNLP.models.sequence_modeling import AdvSeqLabel | |||
from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor | |||
from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader | |||
from fastNLP.api.processor import ModelProcessor, Index2WordProcessor | |||
cfgfile = './pos_tag.cfg' | |||
pickle_path = "save" | |||
def train(): | |||
# load config | |||
train_param = ConfigSection() | |||
model_param = ConfigSection() | |||
ConfigLoader().load_config(cfgfile, {"train": train_param, "model": model_param}) | |||
print("config loaded") | |||
# Data Loader | |||
dataset = ZhConllPOSReader().load("/home/hyan/train.conllx") | |||
print(dataset) | |||
print("dataset transformed") | |||
dataset.rename_field("tag", "truth") | |||
vocab_proc = VocabIndexerProcessor("words", new_added_filed_name="word_seq") | |||
tag_proc = VocabIndexerProcessor("truth") | |||
seq_len_proc = SeqLenProcessor(field_name="word_seq", new_added_field_name="word_seq_origin_len", is_input=True) | |||
vocab_proc(dataset) | |||
tag_proc(dataset) | |||
seq_len_proc(dataset) | |||
dataset.set_input("word_seq", "word_seq_origin_len", "truth") | |||
dataset.set_target("truth", "word_seq_origin_len") | |||
print("processors defined") | |||
# dataset.set_is_target(tag_ids=True) | |||
model_param["vocab_size"] = vocab_proc.get_vocab_size() | |||
model_param["num_classes"] = tag_proc.get_vocab_size() | |||
print("vocab_size={} num_classes={}".format(model_param["vocab_size"], model_param["num_classes"])) | |||
# define a model | |||
model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word) | |||
# call trainer to train | |||
trainer = Trainer(dataset, model, loss=None, metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict", | |||
target="truth", | |||
seq_lens="word_seq_origin_len"), | |||
dev_data=dataset, metric_key="f", | |||
use_tqdm=False, use_cuda=True, print_every=20, n_epochs=1, save_path="./save") | |||
trainer.train() | |||
# save model & pipeline | |||
model_proc = ModelProcessor(model, seq_len_field_name="word_seq_origin_len") | |||
id2tag = Index2WordProcessor(tag_proc.vocab, "predict", "tag") | |||
pp = Pipeline([vocab_proc, seq_len_proc, model_proc, id2tag]) | |||
save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab} | |||
torch.save(save_dict, "model_pp.pkl") | |||
print("pipeline saved") | |||
def infer(): | |||
pass | |||
if __name__ == "__main__": | |||
train() |
@@ -4,6 +4,7 @@ import numpy as np | |||
import torch | |||
from fastNLP.core.metrics import AccuracyMetric | |||
from fastNLP.core.metrics import BMESF1PreRecMetric | |||
from fastNLP.core.metrics import pred_topk, accuracy_topk | |||
@@ -132,6 +133,234 @@ class TestAccuracyMetric(unittest.TestCase): | |||
return | |||
self.assertTrue(True, False), "No exception catches." | |||
class SpanF1PreRecMetric(unittest.TestCase): | |||
def test_case1(self): | |||
from fastNLP.core.metrics import bmes_tag_to_spans | |||
from fastNLP.core.metrics import bio_tag_to_spans | |||
bmes_lst = ['M-8', 'S-2', 'S-0', 'B-9', 'B-6', 'E-5', 'B-7', 'S-2', 'E-7', 'S-8'] | |||
bio_lst = ['O-8', 'O-2', 'B-0', 'O-9', 'I-6', 'I-5', 'I-7', 'I-2', 'I-7', 'O-8'] | |||
expect_bmes_res = set() | |||
expect_bmes_res.update([('8', (0, 0)), ('2', (1, 1)), ('0', (2, 2)), ('9', (3, 3)), ('6', (4, 4)), | |||
('5', (5, 5)), ('7', (6, 6)), ('2', (7, 7)), ('7', (8, 8)), ('8', (9, 9))]) | |||
expect_bio_res = set() | |||
expect_bio_res.update([('7', (8, 8)), ('0', (2, 2)), ('2', (7, 7)), ('5', (5, 5)), | |||
('6', (4, 4)), ('7', (6, 6))]) | |||
self.assertSetEqual(expect_bmes_res,set(bmes_tag_to_spans(bmes_lst))) | |||
self.assertSetEqual(expect_bio_res, set(bio_tag_to_spans(bio_lst))) | |||
# 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试 | |||
# from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans | |||
# from allennlp.data.dataset_readers.dataset_utils import bmes_tags_to_spans as allen_bmes_tags_to_spans | |||
# for i in range(1000): | |||
# strs = list(map(str, np.random.randint(100, size=1000))) | |||
# bmes = list('bmes'.upper()) | |||
# bmes_strs = [str_ + '-' + tag for tag, str_ in zip(strs, np.random.choice(bmes, size=len(strs)))] | |||
# bio = list('bio'.upper()) | |||
# bio_strs = [str_ + '-' + tag for tag, str_ in zip(strs, np.random.choice(bio, size=len(strs)))] | |||
# self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs))) | |||
# self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs))) | |||
def test_case2(self): | |||
# 测试不带label的 | |||
from fastNLP.core.metrics import bmes_tag_to_spans | |||
from fastNLP.core.metrics import bio_tag_to_spans | |||
bmes_lst = ['B', 'E', 'B', 'S', 'B', 'M', 'E', 'M', 'B', 'E'] | |||
bio_lst = ['I', 'B', 'O', 'O', 'I', 'O', 'I', 'B', 'O', 'O'] | |||
expect_bmes_res = set() | |||
expect_bmes_res.update([('', (0, 1)), ('', (2, 2)), ('', (3, 3)), ('', (4, 6)), ('', (7, 7)), ('', (8, 9))]) | |||
expect_bio_res = set() | |||
expect_bio_res.update([('', (7, 7)), ('', (6, 6)), ('', (4, 4)), ('', (0, 0)), ('', (1, 1))]) | |||
self.assertSetEqual(expect_bmes_res,set(bmes_tag_to_spans(bmes_lst))) | |||
self.assertSetEqual(expect_bio_res, set(bio_tag_to_spans(bio_lst))) | |||
# 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试 | |||
# from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans | |||
# from allennlp.data.dataset_readers.dataset_utils import bmes_tags_to_spans as allen_bmes_tags_to_spans | |||
# for i in range(1000): | |||
# bmes = list('bmes'.upper()) | |||
# bmes_strs = np.random.choice(bmes, size=1000) | |||
# bio = list('bio'.upper()) | |||
# bio_strs = np.random.choice(bio, size=100) | |||
# self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs))) | |||
# self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs))) | |||
def tese_case3(self): | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from collections import Counter | |||
from fastNLP.core.metrics import SpanFPreRecMetric | |||
# 与allennlp测试能否正确计算f metric | |||
# | |||
def generate_allen_tags(encoding_type, number_labels=4): | |||
vocab = {} | |||
for i in range(number_labels): | |||
label = str(i) | |||
for tag in encoding_type: | |||
if tag == 'O': | |||
if tag not in vocab: | |||
vocab['O'] = len(vocab) + 1 | |||
continue | |||
vocab['{}-{}'.format(tag, label)] = len(vocab) + 1 # 其实表达的是这个的count | |||
return vocab | |||
number_labels = 4 | |||
# bio tag | |||
fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) | |||
fastnlp_bio_vocab.word_count = Counter(generate_allen_tags('BIO', number_labels)) | |||
fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) | |||
bio_sequence = torch.FloatTensor( | |||
[[[-0.9543, -1.4357, -0.2365, 0.2438, 1.0312, -1.4302, 0.3011, | |||
0.0470, 0.0971], | |||
[-0.6638, -0.7116, -1.9804, 0.2787, -0.2732, -0.9501, -1.4523, | |||
0.7987, -0.3970], | |||
[0.2939, 0.8132, -0.0903, -2.8296, 0.2080, -0.9823, -0.1898, | |||
0.6880, 1.4348], | |||
[-0.1886, 0.0067, -0.6862, -0.4635, 2.2776, 0.0710, -1.6793, | |||
-1.6876, -0.8917], | |||
[-0.7663, 0.6377, 0.8669, 0.1237, 1.7628, 0.0313, -1.0824, | |||
1.4217, 0.2622]], | |||
[[0.1529, 0.7474, -0.9037, 1.5287, 0.2771, 0.2223, 0.8136, | |||
1.3592, -0.8973], | |||
[0.4515, -0.5235, 0.3265, -1.1947, 0.8308, 1.8754, -0.4887, | |||
-0.4025, -0.3417], | |||
[-0.7855, 0.1615, -0.1272, -1.9289, -0.5181, 1.9742, -0.9698, | |||
0.2861, -0.3966], | |||
[-0.8291, -0.8823, -1.1496, 0.2164, 1.3390, -0.3964, -0.5275, | |||
0.0213, 1.4777], | |||
[-1.1299, 0.0627, -0.1358, -1.5951, 0.4484, -0.6081, -1.9566, | |||
1.3024, 0.2001]]] | |||
) | |||
bio_target = torch.LongTensor([[5., 0., 3., 3., 3.], | |||
[5., 6., 8., 6., 0.]]) | |||
fastnlp_bio_metric({'pred': bio_sequence, 'seq_lens': torch.LongTensor([5, 5])}, {'target': bio_target}) | |||
expect_bio_res = {'pre-1': 0.24999999999999373, 'rec-1': 0.499999999999975, 'f-1': 0.33333333333327775, | |||
'pre-2': 0.0, 'rec-2': 0.0, 'f-2': 0.0, 'pre-3': 0.0, 'rec-3': 0.0, 'f-3': 0.0, 'pre-0': 0.0, | |||
'rec-0': 0.0, 'f-0': 0.0, 'pre': 0.12499999999999845, 'rec': 0.12499999999999845, | |||
'f': 0.12499999999994846} | |||
self.assertDictEqual(expect_bio_res, fastnlp_bio_metric.get_metric()) | |||
#bmes tag | |||
bmes_sequence = torch.FloatTensor( | |||
[[[0.6536, -0.7179, 0.6579, 1.2503, 0.4176, 0.6696, 0.2352, | |||
-0.4085, 0.4084, -0.4185, 1.4172, -0.9162, -0.2679, 0.3332, | |||
-0.3505, -0.6002], | |||
[0.3238, -1.2378, -1.3304, -0.4903, 1.4518, -0.1868, -0.7641, | |||
1.6199, -0.8877, 0.1449, 0.8995, -0.5810, 0.1041, 0.1002, | |||
0.4439, 0.2514], | |||
[-0.8362, 2.9526, 0.8008, 0.1193, 1.0488, 0.6670, 1.1696, | |||
-1.1006, -0.8540, -0.1600, -0.9519, -0.2749, -0.4948, -1.4753, | |||
0.5802, -0.0516], | |||
[-0.8383, -1.7292, -1.4079, -1.5023, 0.5383, 0.6653, 0.3121, | |||
4.1249, -0.4173, -0.2043, 1.7755, 1.1110, -1.7069, -0.0390, | |||
-0.9242, -0.0333], | |||
[0.9088, -0.4955, -0.5076, 0.3732, 0.0283, -0.0263, -1.0393, | |||
0.7734, 1.0968, 0.4132, -1.3647, -0.5762, 0.6678, 0.8809, | |||
-0.3779, -0.3195]], | |||
[[-0.4638, -0.5939, -0.1052, -0.5573, 0.4600, -1.3484, 0.1753, | |||
0.0685, 0.3663, -0.6789, 0.0097, 1.0327, -0.0212, -0.9957, | |||
-0.1103, 0.4417], | |||
[-0.2903, 0.9205, -1.5758, -1.0421, 0.2921, -0.2142, -0.3049, | |||
-0.0879, -0.4412, -1.3195, -0.0657, -0.2986, 0.7214, 0.0631, | |||
-0.6386, 0.2797], | |||
[0.6440, -0.3748, 1.2912, -0.0170, 0.7447, 1.4075, -0.4947, | |||
0.4123, -0.8447, -0.5502, 0.3520, -0.2832, 0.5019, -0.1522, | |||
1.1237, -1.5385], | |||
[0.2839, -0.7649, 0.9067, -0.1163, -1.3789, 0.2571, -1.3977, | |||
-0.3680, -0.8902, -0.6983, -1.1583, 1.2779, 0.2197, 0.1376, | |||
-0.0591, -0.2461], | |||
[-0.2977, -1.8564, -0.5347, 1.0011, -1.1260, 0.4252, -2.0097, | |||
2.6973, -0.8308, -1.4939, 0.9865, -0.3935, 0.2743, 0.1142, | |||
-0.7344, -1.2046]]] | |||
) | |||
bmes_target = torch.LongTensor([[ 9., 6., 1., 9., 15.], | |||
[ 6., 15., 6., 15., 5.]]) | |||
fastnlp_bmes_vocab = Vocabulary(unknown=None, padding=None) | |||
fastnlp_bmes_vocab.word_count = Counter(generate_allen_tags('BMES', number_labels)) | |||
fastnlp_bmes_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bmes_vocab, only_gross=False, encoding_type='bmes') | |||
fastnlp_bmes_metric({'pred': bmes_sequence, 'seq_lens': torch.LongTensor([20, 20])}, {'target': bmes_target}) | |||
expect_bmes_res = {'f-3': 0.6666666666665778, 'pre-3': 0.499999999999975, 'rec-3': 0.9999999999999001, | |||
'f-0': 0.0, 'pre-0': 0.0, 'rec-0': 0.0, 'f-1': 0.33333333333327775, | |||
'pre-1': 0.24999999999999373, 'rec-1': 0.499999999999975, 'f-2': 0.7499999999999314, | |||
'pre-2': 0.7499999999999812, 'rec-2': 0.7499999999999812, 'f': 0.49999999999994504, | |||
'pre': 0.499999999999995, 'rec': 0.499999999999995} | |||
self.assertDictEqual(fastnlp_bmes_metric.get_metric(), expect_bmes_res) | |||
# 已经和allennlp做过验证,但由于不能依赖allennlp,所以注释了以下代码 | |||
# from allennlp.data.vocabulary import Vocabulary as allen_Vocabulary | |||
# from allennlp.training.metrics import SpanBasedF1Measure | |||
# allen_bio_vocab = allen_Vocabulary({"tags": generate_allen_tags('BIO', number_labels)}, | |||
# non_padded_namespaces=['tags']) | |||
# allen_bio_metric = SpanBasedF1Measure(allen_bio_vocab, 'tags') | |||
# bio_sequence = torch.randn(size=(2, 20, 2 * number_labels + 1)) | |||
# bio_target = torch.randint(2 * number_labels + 1, size=(2, 20)) | |||
# allen_bio_metric(bio_sequence, bio_target, torch.ones(2, 20)) | |||
# fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) | |||
# fastnlp_bio_vocab.word_count = Counter(generate_allen_tags('BIO', number_labels)) | |||
# fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) | |||
# | |||
# def convert_allen_res_to_fastnlp_res(metric_result): | |||
# allen_result = {} | |||
# key_map = {'f1-measure-overall': "f", "recall-overall": "rec", "precision-overall": "pre"} | |||
# for key, value in metric_result.items(): | |||
# if key in key_map: | |||
# key = key_map[key] | |||
# else: | |||
# label = key.split('-')[-1] | |||
# if key.startswith('f1'): | |||
# key = 'f-{}'.format(label) | |||
# else: | |||
# key = '{}-{}'.format(key[:3], label) | |||
# allen_result[key] = value | |||
# return allen_result | |||
# | |||
# # print(convert_allen_res_to_fastnlp_res(allen_bio_metric.get_metric())) | |||
# # print(fastnlp_bio_metric.get_metric()) | |||
# self.assertDictEqual(convert_allen_res_to_fastnlp_res(allen_bio_metric.get_metric()), | |||
# fastnlp_bio_metric.get_metric()) | |||
# | |||
# allen_bmes_vocab = allen_Vocabulary({"tags": generate_allen_tags('BMES', number_labels)}) | |||
# allen_bmes_metric = SpanBasedF1Measure(allen_bmes_vocab, 'tags', label_encoding='BMES') | |||
# fastnlp_bmes_vocab = Vocabulary(unknown=None, padding=None) | |||
# fastnlp_bmes_vocab.word_count = Counter(generate_allen_tags('BMES', number_labels)) | |||
# fastnlp_bmes_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bmes_vocab, only_gross=False, encoding_type='bmes') | |||
# bmes_sequence = torch.randn(size=(2, 20, 4 * number_labels)) | |||
# bmes_target = torch.randint(4 * number_labels, size=(2, 20)) | |||
# allen_bmes_metric(bmes_sequence, bmes_target, torch.ones(2, 20)) | |||
# fastnlp_bmes_metric({'pred': bmes_sequence, 'seq_lens': torch.LongTensor([20, 20])}, {'target': bmes_target}) | |||
# | |||
# # print(convert_allen_res_to_fastnlp_res(allen_bmes_metric.get_metric())) | |||
# # print(fastnlp_bmes_metric.get_metric()) | |||
# self.assertDictEqual(convert_allen_res_to_fastnlp_res(allen_bmes_metric.get_metric()), | |||
# fastnlp_bmes_metric.get_metric()) | |||
class TestBMESF1PreRecMetric(unittest.TestCase): | |||
def test_case1(self): | |||
seq_lens = torch.LongTensor([4, 2]) | |||
pred = torch.randn(2, 4, 4) | |||
target = torch.LongTensor([[0, 1, 2, 3], | |||
[3, 3, 0, 0]]) | |||
pred_dict = {'pred': pred} | |||
target_dict = {'target': target, 'seq_lens': seq_lens} | |||
metric = BMESF1PreRecMetric() | |||
metric(pred_dict, target_dict) | |||
metric.get_metric() | |||
def test_case2(self): | |||
# 测试相同两个seqence,应该给出{f1: 1, precision:1, recall:1} | |||
seq_lens = torch.LongTensor([4, 2]) | |||
target = torch.LongTensor([[0, 1, 2, 3], | |||
[3, 3, 0, 0]]) | |||
pred_dict = {'pred': target} | |||
target_dict = {'target': target, 'seq_lens': seq_lens} | |||
metric = BMESF1PreRecMetric() | |||
metric(pred_dict, target_dict) | |||
self.assertDictEqual(metric.get_metric(), {'f1': 0.999999, 'precision': 1.0, 'recall': 1.0}) | |||
class TestUsefulFunctions(unittest.TestCase): | |||
# 测试metrics.py中一些看上去挺有用的函数 | |||
@@ -0,0 +1,104 @@ | |||
import unittest | |||
class TestCRF(unittest.TestCase): | |||
def test_case1(self): | |||
# 检查allowed_transitions()能否正确使用 | |||
from fastNLP.modules.decoder.CRF import allowed_transitions | |||
id2label = {0: 'B', 1: 'I', 2:'O'} | |||
expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), | |||
(2, 4), (3, 0), (3, 2)} | |||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label))) | |||
id2label = {0: 'B', 1:'M', 2:'E', 3:'S'} | |||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} | |||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES'))) | |||
id2label = {0: 'B', 1: 'I', 2:'O', 3: '<pad>', 4:"<unk>"} | |||
allowed_transitions(id2label) | |||
labels = ['O'] | |||
for label in ['X', 'Y']: | |||
for tag in 'BI': | |||
labels.append('{}-{}'.format(tag, label)) | |||
id2label = {idx:label for idx, label in enumerate(labels)} | |||
expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1), | |||
(2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3), | |||
(4, 4), (4, 6), (5, 0), (5, 1), (5, 3)} | |||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label))) | |||
labels = [] | |||
for label in ['X', 'Y']: | |||
for tag in 'BMES': | |||
labels.append('{}-{}'.format(tag, label)) | |||
id2label = {idx: label for idx, label in enumerate(labels)} | |||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), | |||
(3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), | |||
(7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} | |||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES'))) | |||
def test_case2(self): | |||
# 测试CRF能否避免解码出非法跃迁, 使用allennlp做了验证。 | |||
pass | |||
# import torch | |||
# from fastNLP.modules.decoder.CRF import seq_len_to_byte_mask | |||
# | |||
# labels = ['O'] | |||
# for label in ['X', 'Y']: | |||
# for tag in 'BI': | |||
# labels.append('{}-{}'.format(tag, label)) | |||
# id2label = {idx: label for idx, label in enumerate(labels)} | |||
# num_tags = len(id2label) | |||
# | |||
# from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions | |||
# allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BIO', id2label), | |||
# include_start_end_transitions=False) | |||
# batch_size = 3 | |||
# logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, 20, num_tags))).log() | |||
# trans_m = allen_CRF.transitions | |||
# seq_lens = torch.randint(1, 20, size=(batch_size,)) | |||
# seq_lens[-1] = 20 | |||
# mask = seq_len_to_byte_mask(seq_lens) | |||
# allen_res = allen_CRF.viterbi_tags(logits, mask) | |||
# | |||
# from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions | |||
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label)) | |||
# fast_CRF.trans_m = trans_m | |||
# fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True) | |||
# # score equal | |||
# self.assertListEqual([score for _, score in allen_res], fast_res[1]) | |||
# # seq equal | |||
# self.assertListEqual([_ for _, score in allen_res], fast_res[0]) | |||
# | |||
# | |||
# labels = [] | |||
# for label in ['X', 'Y']: | |||
# for tag in 'BMES': | |||
# labels.append('{}-{}'.format(tag, label)) | |||
# id2label = {idx: label for idx, label in enumerate(labels)} | |||
# num_tags = len(id2label) | |||
# | |||
# from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions | |||
# allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BMES', id2label), | |||
# include_start_end_transitions=False) | |||
# batch_size = 3 | |||
# logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, 20, num_tags))).log() | |||
# trans_m = allen_CRF.transitions | |||
# seq_lens = torch.randint(1, 20, size=(batch_size,)) | |||
# seq_lens[-1] = 20 | |||
# mask = seq_len_to_byte_mask(seq_lens) | |||
# allen_res = allen_CRF.viterbi_tags(logits, mask) | |||
# | |||
# from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions | |||
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, | |||
# encoding_type='BMES')) | |||
# fast_CRF.trans_m = trans_m | |||
# fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True) | |||
# # score equal | |||
# self.assertListEqual([score for _, score in allen_res], fast_res[1]) | |||
# # seq equal | |||
# self.assertListEqual([_ for _, score in allen_res], fast_res[0]) | |||