
Merge branch 'dev' into local-dev

# Conflicts:
#	fastNLP/core/dataset.py
tags/v0.3.0^2
FengZiYjun, 5 years ago
parent commit 1f4d784068
24 changed files with 1757 additions and 263 deletions
  1. fastNLP/api/api.py (+43, -22)
  2. fastNLP/api/processor.py (+56, -10)
  3. fastNLP/core/batch.py (+4, -4)
  4. fastNLP/core/dataset.py (+2, -2)
  5. fastNLP/core/losses.py (+1, -1)
  6. fastNLP/core/metrics.py (+366, -9)
  7. fastNLP/core/trainer.py (+5, -8)
  8. fastNLP/io/base_loader.py (+5, -1)
  9. fastNLP/io/config_io.py (+22, -27)
  10. fastNLP/io/dataset_loader.py (+85, -75)
  11. fastNLP/io/embed_loader.py (+5, -4)
  12. fastNLP/io/logger.py (+5, -4)
  13. fastNLP/io/model_io.py (+9, -9)
  14. fastNLP/models/sequence_modeling.py (+13, -6)
  15. fastNLP/modules/decoder/CRF.py (+162, -22)
  16. reproduction/chinese_word_segment/cws_io/cws_reader.py (+26, -5)
  17. reproduction/chinese_word_segment/models/cws_model.py (+23, -6)
  18. reproduction/chinese_word_segment/process/cws_processor.py (+211, -44)
  19. reproduction/chinese_word_segment/train_context.py (+229, -0)
  20. reproduction/pos_tag_model/pos_processor.py (+2, -0)
  21. reproduction/pos_tag_model/pos_reader.py (+68, -4)
  22. reproduction/pos_tag_model/train_pos_tag.py (+82, -0)
  23. test/core/test_metrics.py (+229, -0)
  24. test/modules/decoder/test_CRF.py (+104, -0)

fastNLP/api/api.py (+43, -22)

@@ -9,16 +9,15 @@ from fastNLP.core.dataset import DataSet

from fastNLP.api.model_zoo import load_url
from fastNLP.api.processor import ModelProcessor
from reproduction.chinese_word_segment.cws_io.cws_reader import ConlluCWSReader
from reproduction.pos_tag_model.pos_io.pos_reader import ConlluPOSReader
from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader
from reproduction.pos_tag_model.pos_reader import ConllPOSReader
from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag
from fastNLP.core.instance import Instance
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.batch import Batch
from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1
from fastNLP.api.pipeline import Pipeline
from fastNLP.core.metrics import SeqLabelEvaluator2
from fastNLP.core.tester import Tester


# TODO add pretrain urls
model_urls = {
@@ -29,6 +28,7 @@ model_urls = {
class API:
def __init__(self):
self.pipeline = None
self._dict = None

def predict(self, *args, **kwargs):
raise NotImplementedError
@@ -38,8 +38,8 @@ class API:
_dict = torch.load(path, map_location='cpu')
else:
_dict = load_url(path, map_location='cpu')
self.pipeline = _dict['pipeline']
self._dict = _dict
self.pipeline = _dict['pipeline']
for processor in self.pipeline.pipeline:
if isinstance(processor, ModelProcessor):
processor.set_model_device(device)
@@ -48,8 +48,10 @@ class API:
class POS(API):
"""FastNLP API for Part-Of-Speech tagging.

"""
:param str model_path: the path to the model.
:param str device: device name such as "cpu" or "cuda:0". Use the same notation as PyTorch.

"""
def __init__(self, model_path=None, device='cpu'):
super(POS, self).__init__()
if model_path is None:
@@ -75,12 +77,28 @@ class POS(API):

# 2. Build the dataset
dataset = DataSet()
dataset.add_field('words', sentence_list)
dataset.add_field("words", sentence_list)

# 3. Run the pipeline
self.pipeline(dataset)

output = dataset['word_pos_output'].content
def decode_tags(ins):
pred_tags = ins["tag"]
chars = ins["words"]
words = []
start_idx = 0
for idx, tag in enumerate(pred_tags):
if tag[0] == "S":
words.append(chars[start_idx:idx + 1] + "/" + tag[2:])
start_idx = idx + 1
elif tag[0] == "E":
words.append("".join(chars[start_idx:idx + 1]) + "/" + tag[2:])
start_idx = idx + 1
return words

dataset.apply(decode_tags, new_field_name="tag_output")

output = dataset.field_arrays["tag_output"].content
if isinstance(content, str):
return output[0]
elif isinstance(content, list):
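A rough illustration of the BMES-style decoding performed by decode_tags above, on a hypothetical instance (it assumes the "words" field holds the raw character string; tags and characters here are made up):

    ins = {"words": "编者按好", "tag": ["B-NN", "M-NN", "E-NN", "S-VA"]}
    # decode_tags(ins) returns ["编者按/NN", "好/VA"]:
    # the E tag closes the span opened by B, and the S tag emits a single-character word.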
@@ -95,9 +113,10 @@ class POS(API):
pipeline.append(tag_proc)
pp = Pipeline(pipeline)

reader = ConlluPOSReader()
reader = ConllPOSReader()
te_dataset = reader.load(filepath)

"""
evaluator = SeqLabelEvaluator2('word_seq_origin_len')
end_tagidx_set = set()
tag_proc.vocab.build_vocab()
@@ -108,15 +127,16 @@ class POS(API):
end_tagidx_set.add(value)
evaluator.end_tagidx_set = end_tagidx_set

default_valid_args = {"batch_size": 64,
"use_cuda": True, "evaluator": evaluator}

pp(te_dataset)
te_dataset.set_target(truth=True)

default_valid_args = {"batch_size": 64,
"use_cuda": True, "evaluator": evaluator,
"model": model, "data": te_dataset}

tester = Tester(**default_valid_args)

test_result = tester.test(model, te_dataset)
test_result = tester.test()

f1 = round(test_result['F'] * 100, 2)
pre = round(test_result['P'] * 100, 2)
@@ -124,6 +144,7 @@ class POS(API):
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec))

return f1, pre, rec
"""


class CWS(API):
@@ -168,7 +189,7 @@ class CWS(API):
pipeline.insert(1, tag_proc)
pp = Pipeline(pipeline)

reader = ConlluCWSReader()
reader = ConllCWSReader()

# te_filename = '/home/hyan/ctb3/test.conllx'
te_dataset = reader.load(filepath)
@@ -290,13 +311,13 @@ class Analyzer:


if __name__ == "__main__":
# pos_model_path = '../../reproduction/pos_tag_model/pos_crf.pkl'
# pos = POS(device='cpu')
# s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。' ,
# '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
# '那么这款无人机到底有多厉害?']
pos_model_path = '/home/zyfeng/fastnlp/reproduction/pos_tag_model/model_pp.pkl'
pos = POS(pos_model_path, device='cpu')
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
'那么这款无人机到底有多厉害?']
# print(pos.test('/Users/yh/Desktop/test_data/pos_test.conll'))
# print(pos.predict(s))
print(pos.predict(s))

# cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl'
# cws = CWS(device='cpu')
@@ -306,9 +327,9 @@ if __name__ == "__main__":
# print(cws.test('/Users/yh/Desktop/test_data/cws_test.conll'))
# print(cws.predict(s))

parser = Parser(device='cpu')
# parser = Parser(device='cpu')
# print(parser.test('/Users/yh/Desktop/test_data/parser_test2.conll'))
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
'那么这款无人机到底有多厉害?']
print(parser.predict(s))
# print(parser.predict(s))

fastNLP/api/processor.py (+56, -10)

@@ -11,6 +11,11 @@ from fastNLP.core.vocabulary import Vocabulary

class Processor(object):
def __init__(self, field_name, new_added_field_name):
"""

:param field_name: the field to process
:param new_added_field_name: if None, field_name is used, i.e. the original field is overwritten
"""
self.field_name = field_name
if new_added_field_name is None:
self.new_added_field_name = field_name
@@ -92,6 +97,11 @@ class FullSpaceToHalfSpaceProcessor(Processor):


class PreAppendProcessor(Processor):
"""
Prepend data (which should be a str) to the start of a field. The field must be a list. The new field is
[data] + instance[field_name]

"""
def __init__(self, data, field_name, new_added_field_name=None):
super(PreAppendProcessor, self).__init__(field_name, new_added_field_name)
self.data = data
@@ -102,6 +112,10 @@ class PreAppendProcessor(Processor):


class SliceProcessor(Processor):
"""
Take only part of a field's content, equivalent to instance[field_name][start:end:step]

"""
def __init__(self, start, end, step, field_name, new_added_field_name=None):
super(SliceProcessor, self).__init__(field_name, new_added_field_name)
for o in (start, end, step):
@@ -114,7 +128,17 @@ class SliceProcessor(Processor):


class Num2TagProcessor(Processor):
"""
Convert the numbers in a sentence to a given tag.

"""
def __init__(self, tag, field_name, new_added_field_name=None):
"""

:param tag: str, the tag that numbers are converted to
:param field_name:
:param new_added_field_name:
"""
super(Num2TagProcessor, self).__init__(field_name, new_added_field_name)
self.tag = tag
self.pattern = r'[-+]?([0-9]+[.]?[0-9]*)+[/eE]?[-+]?([0-9]+[.]?[0-9]*)'
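As a hedged sketch of how such a pattern can be applied (an illustration only, not necessarily the exact logic of Num2TagProcessor; the '<NUM>' tag and tokens are made up):

    import re

    pattern = r'[-+]?([0-9]+[.]?[0-9]*)+[/eE]?[-+]?([0-9]+[.]?[0-9]*)'
    tokens = ['价格', '3.14', '-2e10', 'abc']
    # replace every token that looks like a number with the tag
    converted = ['<NUM>' if re.match(pattern, token) else token for token in tokens]
    # converted == ['价格', '<NUM>', '<NUM>', 'abc']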
@@ -135,6 +159,10 @@ class Num2TagProcessor(Processor):


class IndexerProcessor(Processor):
"""
Given a vocabulary, convert the specified field to indices. The field should be a one-dimensional list, e.g.
['我', '是', xxx]
"""
def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False, is_input=True):

assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))
@@ -163,19 +191,19 @@ class IndexerProcessor(Processor):


class VocabProcessor(Processor):
"""Build vocabulary with a field in the data set.
"""
Build a vocabulary from the DataSets passed in.

"""

def __init__(self, field_name):
def __init__(self, field_name, min_freq=1, max_size=None):
super(VocabProcessor, self).__init__(field_name, None)
self.vocab = Vocabulary()
self.vocab = Vocabulary(min_freq=min_freq, max_size=max_size)

def process(self, *datasets):
for dataset in datasets:
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
self.vocab.update(ins[self.field_name])
dataset.apply(lambda ins: self.vocab.update(ins[self.field_name]))

def get_vocab(self):
self.vocab.build_vocab()
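For reference, a minimal sketch of the underlying Vocabulary calls used here (method and keyword names are taken from the code above; the example words and min_freq/max_size values are made up):

    from fastNLP.core.vocabulary import Vocabulary

    vocab = Vocabulary(min_freq=1, max_size=None)
    vocab.update(['编', '者', '按'])   # called once per instance's field content
    vocab.update(['编', '者'])
    vocab.build_vocab()                # finalize the word-to-index mapping, as get_vocab() does above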
@@ -183,6 +211,10 @@ class VocabProcessor(Processor):


class SeqLenProcessor(Processor):
"""
Add a sequence-length field based on another field, taking the size of that field's first dimension.

"""
def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True):
super(SeqLenProcessor, self).__init__(field_name, new_added_field_name)
self.is_input = is_input
@@ -195,10 +227,15 @@ class SeqLenProcessor(Processor):
return dataset


from fastNLP.core.utils import _build_args

class ModelProcessor(Processor):
def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32):
"""
Iterate the model over the data and drop the padding from the results.
Pass in a model; when process() is called with a dataset, this processor feeds the DataSet through Batch into model.predict or model.forward.
The model outputs are added to the dataset, with field names determined by the model output. If an output's shape is not (batch_size, ) or
(batch_size, 1), the sequence-length field is used to unpad it.
TODO: this class should drop its dependency on seq_lens.

:param seq_len_field_name:
:param batch_size:
@@ -211,13 +248,18 @@ class ModelProcessor(Processor):
def process(self, dataset):
self.model.eval()
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
data_iterator = Batch(dataset, batch_size=self.batch_size, sampler=SequentialSampler(), use_cuda=False)
data_iterator = Batch(dataset, batch_size=self.batch_size, sampler=SequentialSampler())

batch_output = defaultdict(list)
if hasattr(self.model, "predict"):
predict_func = self.model.predict
else:
predict_func = self.model.forward
with torch.no_grad():
for batch_x, _ in data_iterator:
prediction = self.model.predict(**batch_x)
seq_lens = batch_x[self.seq_len_field_name].cpu().numpy().tolist()
refined_batch_x = _build_args(predict_func, **batch_x)
prediction = predict_func(**refined_batch_x)
seq_lens = batch_x[self.seq_len_field_name].tolist()

for key, value in prediction.items():
tmp_batch = []
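The refined_batch_x step above relies on fastNLP's internal _build_args to keep only the keyword arguments that predict_func actually accepts. A minimal sketch of that behavior, assuming it is signature-based filtering (the real helper may differ in details):

    import inspect

    def _build_args_sketch(func, **kwargs):
        spec = inspect.getfullargspec(func)
        if spec.varkw is not None:
            # func accepts **kwargs: pass everything through
            return kwargs
        # otherwise keep only the names that appear in the signature
        return {name: value for name, value in kwargs.items() if name in spec.args}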
@@ -246,6 +288,10 @@ class ModelProcessor(Processor):


class Index2WordProcessor(Processor):
"""
Convert an index-valued field in a DataSet back to strings using the vocab.

"""
def __init__(self, vocab, field_name, new_added_field_name):
super(Index2WordProcessor, self).__init__(field_name, new_added_field_name)
self.vocab = vocab
@@ -266,5 +312,5 @@ class SetIsTargetProcessor(Processor):
def process(self, dataset):
set_dict = {name: self.default for name in dataset.get_all_fields().keys()}
set_dict.update(self.field_dict)
dataset.set_target(**set_dict)
dataset.set_target(*set_dict.keys())
return dataset

fastNLP/core/batch.py (+4, -4)

@@ -10,10 +10,10 @@ class Batch(object):
for batch_x, batch_y in Batch(data_set, batch_size=16, sampler=SequentialSampler()):
# ...

:param dataset: a DataSet object
:param batch_size: int, the size of the batch
:param sampler: a Sampler object
:param as_numpy: bool. If True, return Numpy array. Otherwise, return torch tensors.
:param DataSet dataset: a DataSet object
:param int batch_size: the size of the batch
:param Sampler sampler: a Sampler object
:param bool as_numpy: If True, return Numpy array. Otherwise, return torch tensors.

"""



fastNLP/core/dataset.py (+2, -2)

@@ -254,6 +254,8 @@ class DataSet(object):
:return results: if new_field_name is not passed, returned values of the function over all instances.
"""
results = [func(ins) for ins in self._inner_iter()]
if len(list(filter(lambda x: x is not None, results))) == 0 and not (new_field_name is None): # all None
raise ValueError("{} always return None.".format(get_func_signature(func=func)))

extra_param = {}
if 'is_input' in kwargs:
@@ -261,8 +263,6 @@ class DataSet(object):
if 'is_target' in kwargs:
extra_param['is_target'] = kwargs['is_target']
if new_field_name is not None:
if len(list(filter(lambda x: x is not None, results))) == 0: # all None
raise ValueError("{} always return None.".format(get_func_signature(func=func)))
if new_field_name in self.field_arrays:
# overwrite the field, keep same attributes
old_field = self.field_arrays[new_field_name]


fastNLP/core/losses.py (+1, -1)

@@ -250,7 +250,7 @@ class LossInForward(LossBase):

if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
if not isinstance(loss, torch.Tensor):
raise TypeError(f"loss excepts to be a torch.Tensor, got {type(loss)}")
raise TypeError(f"Loss excepted to be a torch.Tensor, got {type(loss)}")
raise RuntimeError(f"The size of loss excepts to be torch.Size([]), got {loss.size()}")

return loss


fastNLP/core/metrics.py (+366, -9)

@@ -10,6 +10,7 @@ from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import seq_lens_to_masks
from fastNLP.core.vocabulary import Vocabulary


class MetricBase(object):
@@ -80,11 +81,6 @@ class MetricBase(object):
f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the "
f"initialization parameters, or change its signature.")

# evaluate should not have varargs.
# if func_spect.varargs:
# raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.evaluate)}(Do not use "
# f"positional argument.).")

def get_metric(self, reset=True):
raise NotImplemented

@@ -108,10 +104,9 @@ class MetricBase(object):

This method will call self.evaluate method.
Before calling self.evaluate, it will first check the validity of output_dict, target_dict
(1) whether self.evaluate has varargs, which is not supported.
(2) whether params needed by self.evaluate is not included in output_dict,target_dict.
(3) whether params needed by self.evaluate duplicate in pred_dict, target_dict
(4) whether params in output_dict, target_dict are not used by evaluate.(Might cause warning)
(1) whether params needed by self.evaluate is not included in output_dict,target_dict.
(2) whether params needed by self.evaluate duplicate in pred_dict, target_dict
(3) whether params in output_dict, target_dict are not used by evaluate.(Might cause warning)
Besides, before passing params into self.evaluate, this function will filter out params from output_dict and
target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering
will be conducted.)
@@ -299,6 +294,368 @@ class AccuracyMetric(MetricBase):
self.total = 0
return evaluate_result

def bmes_tag_to_spans(tags, ignore_labels=None):
"""

:param tags: List[str],
:param ignore_labels: List[str], labels in this list are ignored
:return: List[Tuple[str, List[int, int]]]. [(label,[start, end])]
"""
ignore_labels = set(ignore_labels) if ignore_labels else set()

spans = []
prev_bmes_tag = None
for idx, tag in enumerate(tags):
tag = tag.lower()
bmes_tag, label = tag[:1], tag[2:]
if bmes_tag in ('b', 's'):
spans.append((label, [idx, idx]))
elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label==spans[-1][0]:
spans[-1][1][1] = idx
else:
spans.append((label, [idx, idx]))
prev_bmes_tag = bmes_tag
return [(span[0], (span[1][0], span[1][1]))
for span in spans
if span[0] not in ignore_labels
]

def bio_tag_to_spans(tags, ignore_labels=None):
"""

:param tags: List[str],
:param ignore_labels: List[str], labels in this list are ignored
:return: List[Tuple[str, List[int, int]]]. [(label,[start, end])]
"""
ignore_labels = set(ignore_labels) if ignore_labels else set()

spans = []
prev_bio_tag = None
for idx, tag in enumerate(tags):
tag = tag.lower()
bio_tag, label = tag[:1], tag[2:]
if bio_tag == 'b':
spans.append((label, [idx, idx]))
elif bio_tag == 'i' and prev_bio_tag in ('b', 'i') and label==spans[-1][0]:
spans[-1][1][1] = idx
elif bio_tag == 'o': # o tag does not count
pass
else:
spans.append((label, [idx, idx]))
prev_bio_tag = bio_tag
return [(span[0], (span[1][0], span[1][1]))
for span in spans
if span[0] not in ignore_labels
]
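
A quick illustration of the two span extractors above on toy tag sequences:

    bio_tag_to_spans(['B-PER', 'I-PER', 'O', 'B-LOC'])
    # -> [('per', (0, 1)), ('loc', (3, 3))]

    bmes_tag_to_spans(['B-NN', 'E-NN', 'S-VV'])
    # -> [('nn', (0, 1)), ('vv', (2, 2))]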


class SpanFPreRecMetric(MetricBase):
"""
For sequence labeling tasks, compute F, precision and recall at the span level.
The final metric result is
{
'f': xxx,   # 'f' is used so that an f_beta value can be computed later
'pre': xxx,
'rec': xxx
}
If only_gross=False, per-label statistics are also returned:
{
'f': xxx,
'pre': xxx,
'rec':xxx,
'f-label': xxx,
'pre-label': xxx,
'rec-label':xxx,
...
}

"""
def __init__(self, tag_vocab, pred=None, target=None, seq_lens=None, encoding_type='bio', ignore_labels=None,
only_gross=True, f_type='micro', beta=1):
"""

:param tag_vocab: Vocabulary, the vocabulary of tags. Supported tags are "B" (no label) or "B-xxx" (where xxx is a label, e.g. NN in POS);
during decoding, tags with the same xxx are treated as the same label, e.g. ['B-NN', 'E-NN'] is merged into a single 'NN'.
:param pred: str, the key used in evaluate() to fetch the prediction from the passed dict. If None, 'pred' is used.
:param target: str, the key used in evaluate() to fetch the target from the passed dict. If None, 'target' is used.
:param seq_lens: str, the key used in evaluate() to fetch the sequence lengths from the passed dict. If None, 'seq_lens' is used.
:param encoding_type: str, currently supports bio and bmes
:param ignore_labels: List[str]. Classes in this list are not used in the computation; e.g. passing ['NN'] for POS tagging excludes the
'NN' label.
:param only_gross: bool. Whether to compute only the overall f1, precision and recall; if False, per-label f1, pre and rec are returned
in addition to the overall values.
:param f_type: str, 'micro' or 'macro'. 'micro': first sum TP, FN and FP over all labels, then compute f, precision and recall;
'macro': compute f, precision and recall per label, then average them (each label's f has equal weight).
:param beta: float. The f_beta score, f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). Common values are beta=0.5, 1, 2. With 0.5,
precision is weighted more than recall; with 1 they are weighted equally; with 2, recall is weighted more than precision.
"""
encoding_type = encoding_type.lower()
if encoding_type not in ('bio', 'bmes'):
raise ValueError("Only support 'bio' or 'bmes' type.")
if not isinstance(tag_vocab, Vocabulary):
raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab)))
if f_type not in ('micro', 'macro'):
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type))

self.encoding_type = encoding_type
if self.encoding_type == 'bmes':
self.tag_to_span_func = bmes_tag_to_spans
elif self.encoding_type == 'bio':
self.tag_to_span_func = bio_tag_to_spans
self.ignore_labels = ignore_labels
self.f_type = f_type
self.beta = beta
self.beta_square = self.beta**2
self.only_gross = only_gross

super().__init__()
self._init_param_map(pred=pred, target=target, seq_lens=seq_lens)

self.tag_vocab = tag_vocab

self._true_positives = defaultdict(int)
self._false_positives = defaultdict(int)
self._false_negatives = defaultdict(int)

def evaluate(self, pred, target, seq_lens):
"""
A lot of the design ideas come from allennlp's span-based measure.
:param pred:
:param target:
:param seq_lens:
:return:
"""
if not isinstance(pred, torch.Tensor):
raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(pred)}.")
if not isinstance(target, torch.Tensor):
raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(target)}.")

if not isinstance(seq_lens, torch.Tensor):
raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(seq_lens)}.")

if pred.size() == target.size() and len(target.size()) == 2:
pass
elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2:
pred = pred.argmax(dim=-1)
num_classes = pred.size(-1)
if (target >= num_classes).any():
raise ValueError("A gold label passed to SpanBasedF1Metric contains an "
"id >= {}, the number of classes.".format(num_classes))
else:
raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have "
f"size:{pred.size()}, target should have size: {pred.size()} or "
f"{pred.size()[:-1]}, got {target.size()}.")

batch_size = pred.size(0)
for i in range(batch_size):
pred_tags = pred[i, :seq_lens[i]].tolist()
gold_tags = target[i, :seq_lens[i]].tolist()

pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags]
gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags]

pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels)
gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels)

for span in pred_spans:
if span in gold_spans:
self._true_positives[span[0]] += 1
gold_spans.remove(span)
else:
self._false_positives[span[0]] += 1
for span in gold_spans:
self._false_negatives[span[0]] += 1

def get_metric(self, reset=True):
evaluate_result = {}
if not self.only_gross or self.f_type=='macro':
tags = set(self._false_negatives.keys())
tags.update(set(self._false_positives.keys()))
tags.update(set(self._true_positives.keys()))
f_sum = 0
pre_sum = 0
rec_sum = 0
for tag in tags:
tp = self._true_positives[tag]
fn = self._false_negatives[tag]
fp = self._false_positives[tag]
f, pre, rec = self._compute_f_pre_rec(tp, fn, fp)
f_sum += f
pre_sum += pre
rec_sum += rec
if not self.only_gross and tag!='':  # tag!='' guards against the case with no tag
f_key = 'f-{}'.format(tag)
pre_key = 'pre-{}'.format(tag)
rec_key = 'rec-{}'.format(tag)
evaluate_result[f_key] = f
evaluate_result[pre_key] = pre
evaluate_result[rec_key] = rec

if self.f_type == 'macro':
evaluate_result['f'] = f_sum/len(tags)
evaluate_result['pre'] = pre_sum/len(tags)
evaluate_result['rec'] = rec_sum/len(tags)

if self.f_type == 'micro':
f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()),
sum(self._false_negatives.values()),
sum(self._false_positives.values()))
evaluate_result['f'] = round(f, 6)
evaluate_result['pre'] = round(pre, 6)
evaluate_result['rec'] = round(rec, 6)

if reset:
self._true_positives = defaultdict(int)
self._false_positives = defaultdict(int)
self._false_negatives = defaultdict(int)

return evaluate_result

def _compute_f_pre_rec(self, tp, fn, fp):
"""

:param tp: int, true positive
:param fn: int, false negative
:param fp: int, false positive
:return: (f, pre, rec)
"""
pre = tp / (fp + tp + 1e-13)
rec = tp / (fn + tp + 1e-13)
f = (1 + self.beta_square) * pre * rec / (self.beta_square * pre + rec + 1e-13)

return f, pre, rec
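Plugging in small numbers makes the formula concrete: with beta=1, tp=8, fn=4 and fp=2, precision is 8/10 = 0.8, recall is 8/12 ≈ 0.667, and f = 2 * 0.8 * 0.667 / (0.8 + 0.667) ≈ 0.727 (ignoring the 1e-13 smoothing terms).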

class BMESF1PreRecMetric(MetricBase):
"""
Compute f1, precision and recall under the BMES tagging scheme. Since illegal tag sequences such as "BS" may occur, the table below is
used for conversion. cur_B means the current tag is B and next_B means the next tag is B; cur_B=S means a current tag predicted as B is relabeled as S, and next_M=B means a next tag predicted as M is relabeled as B.
|       | next_B  | next_M   | next_E   | next_S  | end     |
|:-----:|:-------:|:--------:|:--------:|:-------:|:-------:|
| start | legal   | next_M=B | next_E=S | legal   | -       |
| cur_B | cur_B=S | legal    | legal    | cur_B=S | cur_B=S |
| cur_M | cur_M=E | legal    | legal    | cur_M=E | cur_M=E |
| cur_E | legal   | next_M=B | next_E=S | legal   | legal   |
| cur_S | legal   | next_M=B | next_E=S | legal   | legal   |
For example:
the prediction BSEMS is treated as SSSSS.

This metric does not check the validity of the target; make sure the target is valid.
pred should have shape (batch_size, max_len) or (batch_size, max_len, 4).
target has shape (batch_size, max_len).
seq_lens has shape (batch_size, ).

"""

def __init__(self, b_idx=0, m_idx=1, e_idx=2, s_idx=3, pred=None, target=None, seq_lens=None):
"""
Declare the idx corresponding to each of the four BMES tags. Any index that is not b_idx, m_idx, e_idx or s_idx is treated as s_idx.

:param b_idx: int, the tag idx of the Begin label.
:param m_idx: int, the tag idx of the Middle label.
:param e_idx: int, the tag idx of the End label.
:param s_idx: int, the tag idx of the Single label.
:param pred: str, the key used in evaluate() to fetch the prediction from the passed dict. If None, 'pred' is used.
:param target: str, the key used in evaluate() to fetch the target from the passed dict. If None, 'target' is used.
:param seq_lens: str, the key used in evaluate() to fetch the sequence lengths from the passed dict. If None, 'seq_lens' is used.
"""
super().__init__()

self._init_param_map(pred=pred, target=target, seq_lens=seq_lens)

self.yt_wordnum = 0
self.yp_wordnum = 0
self.corr_num = 0

self.b_idx = b_idx
self.m_idx = m_idx
self.e_idx = e_idx
self.s_idx = s_idx
# reconstruct the matrix described in __init__
self._valida_matrix = {
-1: [(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1)], # magic start idx
self.b_idx:[(0, self.s_idx), (-1, -1), (-1, -1), (0, self.s_idx), (0, self.s_idx)],
self.m_idx:[(0, self.e_idx), (-1, -1), (-1, -1), (0, self.e_idx), (0, self.e_idx)],
self.e_idx:[(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)],
self.s_idx:[(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)],
}

def _validate_tags(self, tags):
"""
Given a Tensor of tags, return legal tags.

:param tags: Tensor, shape: (seq_len, )
:return: a list of tags, corrected to be legal
"""
assert len(tags)!=0
assert isinstance(tags, torch.Tensor) and len(tags.size())==1
padded_tags = [-1, *tags.tolist(), -1]
for idx in range(len(padded_tags)-1):
cur_tag = padded_tags[idx]
if cur_tag not in self._valida_matrix:
cur_tag = self.s_idx
if padded_tags[idx+1] not in self._valida_matrix:
padded_tags[idx+1] = self.s_idx
next_tag = padded_tags[idx+1]
shift_tag = self._valida_matrix[cur_tag][next_tag]
if shift_tag[0]!=-1:
padded_tags[idx+shift_tag[0]] = shift_tag[1]

return padded_tags[1:-1]
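Tracing _validate_tags on the docstring's example (with the defaults b_idx=0, m_idx=1, e_idx=2, s_idx=3): the input B S E M S, i.e. torch.LongTensor([0, 3, 2, 1, 3]), is rewritten step by step to [3, 3, 3, 3, 3], i.e. S S S S S, matching the "BSEMS is treated as SSSSS" remark above.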

def evaluate(self, pred, target, seq_lens):
if not isinstance(pred, torch.Tensor):
raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(pred)}.")
if not isinstance(target, torch.Tensor):
raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(target)}.")

if not isinstance(seq_lens, torch.Tensor):
raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(seq_lens)}.")

if pred.size() == target.size() and len(target.size()) == 2:
pass
elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2:
pred = pred.argmax(dim=-1)
else:
raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have "
f"size:{pred.size()}, target should have size: {pred.size()} or "
f"{pred.size()[:-1]}, got {target.size()}.")

for idx in range(len(pred)):
seq_len = seq_lens[idx]
target_tags = target[idx][:seq_len].tolist()
pred_tags = pred[idx][:seq_len]
pred_tags = self._validate_tags(pred_tags)
start_idx = 0
for t_idx, (t_tag, p_tag) in enumerate(zip(target_tags, pred_tags)):
if t_tag in (self.s_idx, self.e_idx):
self.yt_wordnum += 1
corr_flag = True
for i in range(start_idx, t_idx+1):
if target_tags[i]!=pred_tags[i]:
corr_flag = False
if corr_flag:
self.corr_num += 1
start_idx = t_idx + 1
if p_tag in (self.s_idx, self.e_idx):
self.yp_wordnum += 1

def get_metric(self, reset=True):
P = self.corr_num / (self.yp_wordnum + 1e-12)
R = self.corr_num / (self.yt_wordnum + 1e-12)
F = 2 * P * R / (P + R + 1e-12)
evaluate_result = {'f': round(F, 6), 'pre':round(P, 6), 'rec': round(R, 6)}
if reset:
self.yp_wordnum = 0
self.yt_wordnum = 0
self.corr_num = 0
return evaluate_result


def _prepare_metrics(metrics):
"""


fastNLP/core/trainer.py (+5, -8)

@@ -27,7 +27,10 @@ from fastNLP.core.utils import get_func_signature


class Trainer(object):
"""
def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0),
check_code_level=0, metric_key=None, sampler=RandomSampler(), use_tqdm=True, use_cuda=False):
"""
:param DataSet train_data: the training data
:param torch.nn.modules.module model: a PyTorch model
:param LossBase loss: a loss object
@@ -48,16 +51,10 @@ class Trainer(object):
smaller, add "-" in front of the string. For example::

metric_key="-PPL" # language model gets better as perplexity gets smaller

:param BaseSampler sampler: method used to generate batch data.
:param bool use_tqdm: whether to use tqdm to show train progress.

"""

def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
validate_every=-1, dev_data=None, use_cuda=False, save_path=None,
optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0,
metric_key=None, sampler=RandomSampler(), use_tqdm=True):
"""
super(Trainer, self).__init__()

if not isinstance(train_data, DataSet):


fastNLP/io/base_loader.py (+5, -1)

@@ -3,7 +3,9 @@ import os


class BaseLoader(object):
"""Base loader for all loaders.

"""
def __init__(self):
super(BaseLoader, self).__init__()

@@ -32,7 +34,9 @@ class BaseLoader(object):


class DataLoaderRegister:
""""register for data sets"""
"""Register for all data sets.

"""
_readers = {}

@classmethod


fastNLP/io/config_io.py (+22, -27)

@@ -6,7 +6,11 @@ from fastNLP.io.base_loader import BaseLoader


class ConfigLoader(BaseLoader):
"""loader for configuration files"""
"""Loader for configuration.

:param str data_path: path to the config

"""

def __init__(self, data_path=None):
super(ConfigLoader, self).__init__()
@@ -19,13 +23,15 @@ class ConfigLoader(BaseLoader):

@staticmethod
def load_config(file_path, sections):
"""
:param file_path: the path of config file
:param sections: the dict of {section_name(string): Section instance}
Example:
"""Load section(s) of configuration into the ``sections`` provided. No returns.

:param str file_path: the path of config file
:param dict sections: the dict of ``{section_name(string): ConfigSection object}``
Example::

test_args = ConfigSection()
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})
:return: return nothing, but the value of attributes are saved in sessions
"""
assert isinstance(sections, dict)
cfg = configparser.ConfigParser()
@@ -60,9 +66,12 @@ class ConfigLoader(BaseLoader):


class ConfigSection(object):
"""ConfigSection is the data structure storing all key-value pairs in one section in a config file.

"""

def __init__(self):
pass
super(ConfigSection, self).__init__()

def __getitem__(self, key):
"""
@@ -132,25 +141,12 @@ class ConfigSection(object):
return self.__dict__


if __name__ == "__main__":
config = ConfigLoader('there is no data')

section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()}
"""
General and My can be found in config file, so the attr and
value will be updated
A cannot be found in config file, so nothing will be done
"""

config.load_config("../../test/data_for_tests/config", section)
for s in section:
print(s)
for attr in section[s].__dict__.keys():
print(s, attr, getattr(section[s], attr), type(getattr(section[s], attr)))


class ConfigSaver(object):
"""ConfigSaver is used to save config file and solve related conflicts.

:param str file_path: path to the config file

"""
def __init__(self, file_path):
self.file_path = file_path
if not os.path.exists(self.file_path):
@@ -244,9 +240,8 @@ class ConfigSaver(object):
def save_config_file(self, section_name, section):
"""This is the function to be called to change the config file with a single section and its name.

:param section_name: The name of section what needs to be changed and saved.
:param section: The section with key and value what needs to be changed and saved.
:return:
:param str section_name: The name of section what needs to be changed and saved.
:param ConfigSection section: The section with key and value what needs to be changed and saved.
"""
section_file = self._get_section(section_name)
if len(section_file.__dict__.keys()) == 0: # the section not in the file before


fastNLP/io/dataset_loader.py (+85, -75)

@@ -9,11 +9,12 @@ def convert_seq_dataset(data):
"""Create an DataSet instance that contains no labels.

:param data: list of list of strings, [num_examples, *].
::
[
[word_11, word_12, ...],
...
]
Example::

[
[word_11, word_12, ...],
...
]

:return: a DataSet.
"""
@@ -24,15 +25,16 @@ def convert_seq_dataset(data):


def convert_seq2tag_dataset(data):
"""Convert list of data into DataSet
"""Convert list of data into DataSet.

:param data: list of list of strings, [num_examples, *].
::
[
[ [word_11, word_12, ...], label_1 ],
[ [word_21, word_22, ...], label_2 ],
...
]
Example::

[
[ [word_11, word_12, ...], label_1 ],
[ [word_21, word_22, ...], label_2 ],
...
]

:return: a DataSet.
"""
@@ -43,15 +45,16 @@ def convert_seq2tag_dataset(data):


def convert_seq2seq_dataset(data):
"""Convert list of data into DataSet
"""Convert list of data into DataSet.

:param data: list of list of strings, [num_examples, *].
::
[
[ [word_11, word_12, ...], [label_1, label_1, ...] ],
[ [word_21, word_22, ...], [label_2, label_1, ...] ],
...
]
Example::

[
[ [word_11, word_12, ...], [label_1, label_1, ...] ],
[ [word_21, word_22, ...], [label_2, label_1, ...] ],
...
]

:return: a DataSet.
"""
@@ -62,20 +65,31 @@ def convert_seq2seq_dataset(data):


class DataSetLoader:
""""loader for data sets"""
"""Interface for all DataSetLoaders.

"""

def load(self, path):
""" load data in `path` into a dataset
"""Load data from a given file.

:param str path: file path
:return: a DataSet object
"""
raise NotImplementedError

def convert(self, data):
"""convert list of data into dataset
"""Optional operation to build a DataSet.

:param data: inner data structure (user-defined) to represent the data.
:return: a DataSet object
"""
raise NotImplementedError


class NativeDataSetLoader(DataSetLoader):
"""A simple example of DataSetLoader

"""
def __init__(self):
super(NativeDataSetLoader, self).__init__()

@@ -90,6 +104,9 @@ DataLoaderRegister.set_reader(NativeDataSetLoader, 'read_naive')


class RawDataSetLoader(DataSetLoader):
"""A simple example of raw data reader

"""
def __init__(self):
super(RawDataSetLoader, self).__init__()

@@ -108,37 +125,35 @@ DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata')


class POSDataSetLoader(DataSetLoader):
"""Dataset Loader for POS Tag datasets.

In these datasets, each line are divided by '\t'
while the first Col is the vocabulary and the second
Col is the label.
Different sentence are divided by an empty line.
e.g:
Tom label1
and label2
Jerry label1
. label3
(separated by an empty line)
Hello label4
world label5
! label3
In this file, there are two sentence "Tom and Jerry ."
and "Hello world !". Each word has its own label from label1
to label5.
"""Dataset Loader for a POS Tag dataset.

In these datasets, each line are divided by "\t". The first Col is the vocabulary and the second
Col is the label. Different sentence are divided by an empty line.
E.g::

Tom label1
and label2
Jerry label1
. label3
(separated by an empty line)
Hello label4
world label5
! label3

In this example, there are two sentences "Tom and Jerry ." and "Hello world !". Each word has its own label.
"""

def __init__(self):
super(POSDataSetLoader, self).__init__()

def load(self, data_path):
"""
:return data: three-level list
[
[ [word_11, word_12, ...], [label_1, label_1, ...] ],
[ [word_21, word_22, ...], [label_2, label_1, ...] ],
...
]
Example::
[
[ [word_11, word_12, ...], [label_1, label_1, ...] ],
[ [word_21, word_22, ...], [label_2, label_1, ...] ],
...
]
"""
with open(data_path, "r", encoding="utf-8") as f:
lines = f.readlines()
@@ -188,17 +203,17 @@ class TokenizeDataSetLoader(DataSetLoader):
super(TokenizeDataSetLoader, self).__init__()

def load(self, data_path, max_seq_len=32):
"""
load pku dataset for Chinese word segmentation
"""Load pku dataset for Chinese word segmentation.
CWS (Chinese Word Segmentation) pku training dataset format:
1. Each line is a sentence.
2. Each word in a sentence is separated by space.
1. Each line is a sentence.
2. Each word in a sentence is separated by space.
This function convert the pku dataset into three-level lists with labels <BMES>.
B: beginning of a word
M: middle of a word
E: ending of a word
S: single character
B: beginning of a word
M: middle of a word
E: ending of a word
S: single character

:param str data_path: path to the data set.
:param max_seq_len: int, the maximum length of a sequence. If a sequence is longer than it, split it into
several sequences.
:return: three-level lists
@@ -254,11 +269,9 @@ class ClassDataSetLoader(DataSetLoader):
@staticmethod
def parse(lines):
"""
Params
lines: lines from dataset
Return
list(list(list())): the three level of lists are
words, sentence, and dataset

:param lines: lines from dataset
:return: list(list(list())): the three level of lists are words, sentence, and dataset
"""
dataset = list()
for line in lines:
@@ -280,15 +293,9 @@ class ConllLoader(DataSetLoader):
"""loader for conll format files"""

def __init__(self):
"""
:param str data_path: the path to the conll data set
"""
super(ConllLoader, self).__init__()

def load(self, data_path):
"""
:return: list lines: all lines in a conll file
"""
with open(data_path, "r", encoding="utf-8") as f:
lines = f.readlines()
data = self.parse(lines)
@@ -320,8 +327,8 @@ class ConllLoader(DataSetLoader):
class LMDataSetLoader(DataSetLoader):
"""Language Model Dataset Loader

This loader produces data for language model training in a supervised way.
That means it has X and Y.
This loader produces data for language model training in a supervised way.
That means it has X and Y.

"""

@@ -467,6 +474,7 @@ class Conll2003Loader(DataSetLoader):

return dataset


class SNLIDataSetLoader(DataSetLoader):
"""A data set loader for SNLI data set.

@@ -478,8 +486,8 @@ class SNLIDataSetLoader(DataSetLoader):
def load(self, path_list):
"""

:param path_list: A list of file name, in the order of premise file, hypothesis file, and label file.
:return: data_set: A DataSet object.
:param list path_list: A list of file name, in the order of premise file, hypothesis file, and label file.
:return: A DataSet object.
"""
assert len(path_list) == 3
line_set = []
@@ -507,12 +515,14 @@ class SNLIDataSetLoader(DataSetLoader):
"""Convert a 3D list to a DataSet object.

:param data: A 3D tensor.
[
[ [premise_word_11, premise_word_12, ...], [hypothesis_word_11, hypothesis_word_12, ...], [label_1] ],
[ [premise_word_21, premise_word_22, ...], [hypothesis_word_21, hypothesis_word_22, ...], [label_2] ],
...
]
:return: data_set: A DataSet object.
Example::
[
[ [premise_word_11, premise_word_12, ...], [hypothesis_word_11, hypothesis_word_12, ...], [label_1] ],
[ [premise_word_21, premise_word_22, ...], [hypothesis_word_21, hypothesis_word_22, ...], [label_2] ],
...
]

:return: A DataSet object.
"""

data_set = DataSet()


fastNLP/io/embed_loader.py (+5, -4)

@@ -38,7 +38,7 @@ class EmbedLoader(BaseLoader):

:param str emb_file: the pre-trained embedding file path
:param str emb_type: the pre-trained embedding data format
:return dict embedding: `{str: np.array}`
:return: a dict of ``{str: np.array}``
"""
if emb_type == 'glove':
return EmbedLoader._load_glove(emb_file)
@@ -53,8 +53,9 @@ class EmbedLoader(BaseLoader):
:param str emb_file: the pre-trained embedding file path.
:param str emb_type: the pre-trained embedding format, support glove now
:param Vocabulary vocab: a mapping from word to index, can be provided by user or built from pre-trained embedding
:return embedding_tensor: Tensor of shape (len(word_dict), emb_dim)
vocab: input vocab or vocab built by pre-train
:return (embedding_tensor, vocab):
embedding_tensor - Tensor of shape (len(word_dict), emb_dim);
vocab - input vocab or vocab built by pre-train

"""
pretrain = EmbedLoader._load_pretrain(emb_file, emb_type)
@@ -95,7 +96,7 @@ class EmbedLoader(BaseLoader):
:param int emb_dim: the dimension of the embedding. Should be the same as pre-trained embedding.
:param str emb_file: the pre-trained embedding file path.
:param Vocabulary vocab: a mapping from word to index, can be provided by user or built from pre-trained embedding
:return numpy.ndarray embedding_matrix:
:return embedding_matrix: numpy.ndarray

"""
if vocab is None:


fastNLP/io/logger.py (+5, -4)

@@ -3,15 +3,16 @@ import os


def create_logger(logger_name, log_path, log_format=None, log_level=logging.INFO):
"""Return a logger.
"""Create a logger.

:param logger_name: str
:param log_path: str
:param str logger_name:
:param str log_path:
:param log_format:
:param log_level:
:return: logger

to use a logger:
To use a logger::

logger.debug("this is a debug message")
logger.info("this is a info message")
logger.warning("this is a warning message")


fastNLP/io/model_io.py (+9, -9)

@@ -13,10 +13,10 @@ class ModelLoader(BaseLoader):

@staticmethod
def load_pytorch(empty_model, model_path):
"""
Load model parameters from .pkl files into the empty PyTorch model.
"""Load model parameters from ".pkl" files into the empty PyTorch model.
:param empty_model: a PyTorch model with initialized parameters.
:param model_path: str, the path to the saved model.
:param str model_path: the path to the saved model.
"""
empty_model.load_state_dict(torch.load(model_path))

@@ -24,30 +24,30 @@ class ModelLoader(BaseLoader):
def load_pytorch_model(model_path):
"""Load the entire model.

:param str model_path: the path to the saved model.
"""
return torch.load(model_path)


class ModelSaver(object):
"""Save a model

:param str save_path: the path to the saving directory.
Example::

saver = ModelSaver("./save/model_ckpt_100.pkl")
saver.save_pytorch(model)

"""

def __init__(self, save_path):
"""

:param save_path: str, the path to the saving directory.
"""
self.save_path = save_path

def save_pytorch(self, model, param_only=True):
"""Save a pytorch model into .pkl file.
"""Save a pytorch model into ".pkl" file.

:param model: a PyTorch model
:param param_only: bool, whether only to save the model parameters or the entire model.
:param bool param_only: whether only to save the model parameters or the entire model.

"""
if param_only is True:


fastNLP/models/sequence_modeling.py (+13, -6)

@@ -1,8 +1,8 @@
import torch
import numpy as np

from fastNLP.models.base_model import BaseModel
from fastNLP.modules import decoder, encoder
from fastNLP.modules.decoder.CRF import allowed_transitions
from fastNLP.modules.utils import seq_mask


@@ -93,7 +93,7 @@ class AdvSeqLabel(SeqLabeling):
Advanced Sequence Labeling Model
"""

def __init__(self, args, emb=None):
def __init__(self, args, emb=None, id2words=None):
super(AdvSeqLabel, self).__init__(args)

vocab_size = args["vocab_size"]
@@ -105,7 +105,8 @@ class AdvSeqLabel(SeqLabeling):
self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb)
self.norm1 = torch.nn.LayerNorm(word_emb_dim)
# self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=2, dropout=dropout, bidirectional=True)
self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout, bidirectional=True, batch_first=True)
self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout,
bidirectional=True, batch_first=True)
self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3)
self.norm2 = torch.nn.LayerNorm(hidden_dim * 2 // 3)
# self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3)
@@ -113,7 +114,12 @@ class AdvSeqLabel(SeqLabeling):
self.drop = torch.nn.Dropout(dropout)
self.Linear2 = encoder.Linear(hidden_dim * 2 // 3, num_classes)

self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False)
if id2words is None:
self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False)
else:
self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False,
allowed_transitions=allowed_transitions(id2words,
encoding_type="bmes"))

def forward(self, word_seq, word_seq_origin_len, truth=None):
"""
@@ -178,6 +184,7 @@ class AdvSeqLabel(SeqLabeling):
assert 'loss' in kwargs
return kwargs['loss']


if __name__ == '__main__':
args = {
'vocab_size': 20,
@@ -208,11 +215,11 @@ if __name__ == '__main__':
res = model(word_seq, word_seq_len, truth)
loss = res['loss']
pred = res['predict']
print('loss: {} acc {}'.format(loss.item(), ((pred.data == truth).long().sum().float() / word_seq_len.sum().float())))
print('loss: {} acc {}'.format(loss.item(),
((pred.data == truth).long().sum().float() / word_seq_len.sum().float())))
optimizer.zero_grad()
loss.backward()
optimizer.step()
curidx = endidx
if curidx == len(data):
curidx = 0


fastNLP/modules/decoder/CRF.py (+162, -22)

@@ -15,30 +15,153 @@ def seq_len_to_byte_mask(seq_lens):
# return value: ByteTensor, batch_size x max_len
batch_size = seq_lens.size(0)
max_len = seq_lens.max()
broadcast_arange = torch.arange(max_len).view(1, -1).repeat(batch_size, 1)
mask = broadcast_arange.lt(seq_lens.float().view(-1, 1))
broadcast_arange = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device)
mask = broadcast_arange.float().lt(seq_lens.float().view(-1, 1))
return mask

def allowed_transitions(id2label, encoding_type='bio'):
"""

:param id2label: dict, keys are label indices, values are str tags or tag-labels. A value may be a bare tag such as "B" or "M",
or a tag-label such as "B-NN" or "M-NN"; the tag and the label must be separated by "-". id2label can usually be obtained via Vocabulary.get_id2word().
:param encoding_type: str, supports "bio" and "bmes".
:return: List[Tuple(int, int)], where each inner tuple is (from_tag_id, to_tag_id). The result takes start and end into account; e.g. in "BIO",
a sequence may start with B or O but not with I, so the result contains (start_idx, B_idx) and (start_idx, O_idx) but not (start_idx, I_idx).
start_idx=len(id2label), end_idx=len(id2label)+1.
"""
num_tags = len(id2label)
start_idx = num_tags
end_idx = num_tags + 1
encoding_type = encoding_type.lower()
allowed_trans = []
id_label_lst = list(id2label.items()) + [(start_idx, 'start'), (end_idx, 'end')]
def split_tag_label(from_label):
from_label = from_label.lower()
if from_label in ['start', 'end']:
from_tag = from_label
from_label = ''
else:
from_tag = from_label[:1]
from_label = from_label[2:]
return from_tag, from_label

for from_id, from_label in id_label_lst:
if from_label in ['<pad>', '<unk>']:
continue
from_tag, from_label = split_tag_label(from_label)
for to_id, to_label in id_label_lst:
if to_label in ['<pad>', '<unk>']:
continue
to_tag, to_label = split_tag_label(to_label)
if is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label):
allowed_trans.append((from_id, to_id))
return allowed_trans
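For example, with a toy BIO label map, the pairs below are (or are not) produced:

    allowed_transitions({0: 'B-NN', 1: 'I-NN', 2: 'O'}, encoding_type='bio')
    # the result contains (2, 0)  (O -> B-NN) and (3, 0)  (start -> B-NN, with start_idx == 3),
    # but not (2, 1)  (O -> I-NN) or (3, 1)  (start -> I-NN).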

def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label):
"""

:param encoding_type: str, 支持"BIO", "BMES"。
:param from_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag
:param from_label: str, 比如"PER", "LOC"等label
:param to_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag
:param to_label: str, 比如"PER", "LOC"等label
:return: bool,能否跃迁
"""
if to_tag=='start' or from_tag=='end':
return False
encoding_type = encoding_type.lower()
if encoding_type == 'bio':
"""
The first row is to_tag and the first column is from_tag. y: transition always allowed; -: allowed only when the labels match; n: not allowed.
+-------+---+---+---+-------+-----+
| | B | I | O | start | end |
+-------+---+---+---+-------+-----+
| B | y | - | y | n | y |
+-------+---+---+---+-------+-----+
| I | y | - | y | n | y |
+-------+---+---+---+-------+-----+
| O | y | n | y | n | y |
+-------+---+---+---+-------+-----+
| start | y | n | y | n | n |
+-------+---+---+---+-------+-----+
| end | n | n | n | n | n |
+-------+---+---+---+-------+-----+
"""
if from_tag == 'start':
return to_tag in ('b', 'o')
elif from_tag in ['b', 'i']:
return any([to_tag in ['end', 'b', 'o'], to_tag=='i' and from_label==to_label])
elif from_tag == 'o':
return to_tag in ['end', 'b', 'o']
else:
raise ValueError("Unexpect tag {}. Expect only 'B', 'I', 'O'.".format(from_tag))

elif encoding_type == 'bmes':
"""
The first row is to_tag and the first column is from_tag. y: transition always allowed; -: allowed only when the labels match; n: not allowed.
+-------+---+---+---+---+-------+-----+
| | B | M | E | S | start | end |
+-------+---+---+---+---+-------+-----+
| B | n | - | - | n | n | n |
+-------+---+---+---+---+-------+-----+
| M | n | - | - | n | n | n |
+-------+---+---+---+---+-------+-----+
| E | y | n | n | y | n | y |
+-------+---+---+---+---+-------+-----+
| S | y | n | n | y | n | y |
+-------+---+---+---+---+-------+-----+
| start | y | n | n | y | n | n |
+-------+---+---+---+---+-------+-----+
| end | n | n | n | n | n | n |
+-------+---+---+---+---+-------+-----+
"""
if from_tag == 'start':
return to_tag in ['b', 's']
elif from_tag == 'b':
return to_tag in ['m', 'e'] and from_label==to_label
elif from_tag == 'm':
return to_tag in ['m', 'e'] and from_label==to_label
elif from_tag in ['e', 's']:
return to_tag in ['b', 's', 'end']
else:
raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S'.".format(from_tag))
else:
raise ValueError("Only support BIO, BMES encoding type, got {}.".format(encoding_type))


class ConditionalRandomField(nn.Module):
def __init__(self, tag_size, include_start_end_trans=False ,initial_method = None):
def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None, initial_method=None):
"""
:param tag_size: int, num of tags
:param include_start_end_trans: bool, whether to include start/end tag

:param num_tags: int, the number of tags.
:param include_start_end_trans: bool, whether to include start and end transition scores
:param allowed_transitions: List[Tuple[from_tag_id(int), to_tag_id(int)]], the allowed transitions, obtainable via allowed_transitions().
If None, all transitions are legal.
:param initial_method:
"""

super(ConditionalRandomField, self).__init__()

self.include_start_end_trans = include_start_end_trans
self.tag_size = tag_size
self.num_tags = num_tags

# the meaning of entry in this matrix is (from_tag_id, to_tag_id) score
self.trans_m = nn.Parameter(torch.randn(tag_size, tag_size))
self.trans_m = nn.Parameter(torch.randn(num_tags, num_tags))
if self.include_start_end_trans:
self.start_scores = nn.Parameter(torch.randn(tag_size))
self.end_scores = nn.Parameter(torch.randn(tag_size))
self.start_scores = nn.Parameter(torch.randn(num_tags))
self.end_scores = nn.Parameter(torch.randn(num_tags))

if allowed_transitions is None:
constrain = torch.zeros(num_tags + 2, num_tags + 2)
else:
constrain = torch.ones(num_tags + 2, num_tags + 2) * -1000
for from_tag_id, to_tag_id in allowed_transitions:
constrain[from_tag_id, to_tag_id] = 0
self._constrain = nn.Parameter(constrain, requires_grad=False)

# self.reset_parameter()
initial_parameter(self, initial_method)

def reset_parameter(self):
nn.init.xavier_normal_(self.trans_m)
if self.include_start_end_trans:
@@ -49,7 +172,7 @@ class ConditionalRandomField(nn.Module):
"""
Computes the (batch_size,) denominator term for the log-likelihood, which is the
sum of the likelihoods across all possible state sequences.
:param logits:FloatTensor, max_len x batch_size x tag_size
:param logits:FloatTensor, max_len x batch_size x num_tags
:param mask:ByteTensor, max_len x batch_size
:return:FloatTensor, batch_size
"""
@@ -72,7 +195,7 @@ class ConditionalRandomField(nn.Module):
def _glod_score(self, logits, tags, mask):
"""
Compute the score for the gold path.
:param logits: FloatTensor, max_len x batch_size x tag_size
:param logits: FloatTensor, max_len x batch_size x num_tags
:param tags: LongTensor, max_len x batch_size
:param mask: ByteTensor, max_len x batch_size
:return:FloatTensor, batch_size
@@ -99,7 +222,7 @@ class ConditionalRandomField(nn.Module):
def forward(self, feats, tags, mask):
"""
Calculate the neg log likelihood
:param feats:FloatTensor, batch_size x max_len x tag_size
:param feats:FloatTensor, batch_size x max_len x num_tags
:param tags:LongTensor, batch_size x max_len
:param mask:ByteTensor batch_size x max_len
:return:FloatTensor, batch_size
@@ -112,13 +235,20 @@ class ConditionalRandomField(nn.Module):

return all_path_score - gold_path_score

def viterbi_decode(self, data, mask, get_score=False):
def viterbi_decode(self, data, mask, get_score=False, unpad=False):
"""
Given a feats matrix, return best decode path and best score.
:param data:FloatTensor, batch_size x max_len x tag_size
:param data:FloatTensor, batch_size x max_len x num_tags
:param mask:ByteTensor batch_size x max_len
:param get_score: bool, whether to output the decode score.
:return: scores, paths
:param unpad: bool, whether to unpad the result.
If False, a batch_size x max_len tensor is returned;
if True, a List[List[int]] is returned, where each List[int] holds the labels of one sequence, already unpadded,
i.e. the length of each List[int] is that sample's valid length.
:return: if get_score is False, the return value varies with unpad as described above;
if get_score is True, (paths, List[float]) is returned, where the first element is still the decoded paths (varying with unpad)
and the second List[float] gives the decoding score of each sequence.

"""
batch_size, seq_len, n_tags = data.size()
data = data.transpose(0, 1).data # L, B, H
@@ -127,19 +257,23 @@ class ConditionalRandomField(nn.Module):
# dp
vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
vscore = data[0]
transitions = self._constrain.data.clone()
transitions[:n_tags, :n_tags] += self.trans_m.data
if self.include_start_end_trans:
vscore += self.start_scores.view(1, -1)
transitions[n_tags, :n_tags] += self.start_scores.data
transitions[:n_tags, n_tags+1] += self.end_scores.data

vscore += transitions[n_tags, :n_tags]
trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data
for i in range(1, seq_len):
prev_score = vscore.view(batch_size, n_tags, 1)
cur_score = data[i].view(batch_size, 1, n_tags)
trans_score = self.trans_m.view(1, n_tags, n_tags).data
score = prev_score + trans_score + cur_score
best_score, best_dst = score.max(1)
vpath[i] = best_dst
vscore = best_score * mask[i].view(batch_size, 1) + vscore * (1 - mask[i]).view(batch_size, 1)

if self.include_start_end_trans:
vscore += self.end_scores.view(1, -1)
vscore += transitions[:n_tags, n_tags+1].view(1, -1)

# backtrace
batch_idx = torch.arange(batch_size, dtype=torch.long, device=data.device)
@@ -154,7 +288,13 @@ class ConditionalRandomField(nn.Module):
for i in range(seq_len - 1):
last_tags = vpath[idxes[i], batch_idx, last_tags]
ans[idxes[i+1], batch_idx] = last_tags

ans = ans.transpose(0, 1)
if unpad:
paths = []
for idx, seq_len in enumerate(lens):
paths.append(ans[idx, :seq_len+1].tolist())
else:
paths = ans
if get_score:
return ans_score, ans.transpose(0, 1)
return ans.transpose(0, 1)
return paths, ans_score.tolist()
return paths
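Putting the new pieces of this file together, a hedged usage sketch (shapes, feature values and the tag map are made up; the constructor and viterbi_decode signatures are the ones shown above):

    import torch
    from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions

    trans = allowed_transitions({0: 'B', 1: 'I', 2: 'O'}, encoding_type='bio')
    crf = ConditionalRandomField(num_tags=3, include_start_end_trans=False, allowed_transitions=trans)

    feats = torch.randn(2, 5, 3)                         # (batch_size, max_len, num_tags)
    mask = torch.ByteTensor([[1, 1, 1, 1, 0],
                             [1, 1, 1, 0, 0]])           # valid positions per sample
    paths = crf.viterbi_decode(feats, mask, unpad=True)  # List[List[int]], one unpadded tag path per sample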

reproduction/chinese_word_segment/cws_io/cws_reader.py (+26, -5)

@@ -6,6 +6,13 @@ from fastNLP.io.dataset_loader import DataSetLoader


def cut_long_sentence(sent, max_sample_length=200):
"""
Cut a sentence longer than max_sample_length into several pieces; cuts only happen at spaces, so the resulting pieces may be longer or shorter than max_sample_length.

:param sent: str.
:param max_sample_length: int.
:return: list of str.
"""
sent_no_space = sent.replace(' ', '')
cutted_sentence = []
if len(sent_no_space) > max_sample_length:
@@ -127,12 +134,26 @@ class POSCWSReader(DataSetLoader):
return dataset


class ConlluCWSReader(object):
# The returned DataSet contains two fields: words (list of list, the inner list holds characters) and tag (list of str, each str being a BMES tag).
class ConllCWSReader(object):
def __init__(self):
pass

def load(self, path, cut_long_sent=False):
"""
The returned DataSet contains only the raw_sentence field, whose content is a str.
The input is assumed to be in conll format, with sentences separated by an empty line and seven columns per line, e.g.
1 编者按 编者按 NN O 11 nmod:topic
2 : : PU O 11 punct
3 7月 7月 NT DATE 4 compound:nn
4 12日 12日 NT DATE 11 nmod:tmod
5 , , PU O 11 punct

1 这 这 DT O 3 det
2 款 款 M O 1 mark:clf
3 飞行 飞行 NN O 8 nsubj
4 从 从 P O 5 case
5 外型 外型 NN O 8 nmod:prep
"""
datalist = []
with open(path, 'r', encoding='utf-8') as f:
sample = []
@@ -150,10 +171,10 @@ class ConlluCWSReader(object):
ds = DataSet()
for sample in datalist:
# print(sample)
res = self.get_one(sample)
res = self.get_char_lst(sample)
if res is None:
continue
line = ' '.join(res)
line = ' '.join(res)
if cut_long_sent:
sents = cut_long_sentence(line)
else:
@@ -163,7 +184,7 @@ class ConlluCWSReader(object):

return ds

def get_one(self, sample):
def get_char_lst(self, sample):
if len(sample)==0:
return None
text = []


reproduction/chinese_word_segment/models/cws_model.py (+23, -6)

@@ -9,7 +9,7 @@ from reproduction.chinese_word_segment.utils import seq_lens_to_mask

class CWSBiLSTMEncoder(BaseModel):
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None,
hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1):
hidden_size=200, bidirectional=True, embed_drop_p=0.2, num_layers=1):
super().__init__()

self.input_size = 0
@@ -68,6 +68,7 @@ class CWSBiLSTMEncoder(BaseModel):
if not bigrams is None:
bigram_tensor = self.bigram_embedding(bigrams).view(batch_size, max_len, -1)
x_tensor = torch.cat([x_tensor, bigram_tensor], dim=2)
x_tensor = self.embedding_drop(x_tensor)
sorted_lens, sorted_indices = torch.sort(seq_lens, descending=True)
packed_x = nn.utils.rnn.pack_padded_sequence(x_tensor[sorted_indices], sorted_lens, batch_first=True)

@@ -120,10 +121,24 @@ class CWSBiLSTMSegApp(BaseModel):


from fastNLP.modules.decoder.CRF import ConditionalRandomField
from fastNLP.modules.decoder.CRF import allowed_transitions

class CWSBiLSTMCRF(BaseModel):
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None,
hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1, tag_size=4):
hidden_size=200, bidirectional=True, embed_drop_p=0.2, num_layers=1, tag_size=4):
"""
Uses the BMES tagging scheme by default.
:param vocab_num:
:param embed_dim:
:param bigram_vocab_num:
:param bigram_embed_dim:
:param num_bigram_per_char:
:param hidden_size:
:param bidirectional:
:param embed_drop_p:
:param num_layers:
:param tag_size:
"""
super(CWSBiLSTMCRF, self).__init__()

self.tag_size = tag_size
@@ -133,10 +148,12 @@ class CWSBiLSTMCRF(BaseModel):

size_layer = [hidden_size, 200, tag_size]
self.decoder_model = MLP(size_layer)
self.crf = ConditionalRandomField(tag_size=tag_size, include_start_end_trans=False)
allowed_trans = allowed_transitions({0:'b', 1:'m', 2:'e', 3:'s'}, encoding_type='bmes')
self.crf = ConditionalRandomField(num_tags=tag_size, include_start_end_trans=False,
allowed_transitions=allowed_trans)


def forward(self, chars, tags, seq_lens, bigrams=None):
def forward(self, chars, target, seq_lens, bigrams=None):
device = self.parameters().__next__().device
chars = chars.to(device).long()
if not bigrams is None:
@@ -147,7 +164,7 @@ class CWSBiLSTMCRF(BaseModel):
masks = seq_lens_to_mask(seq_lens)
feats = self.encoder_model(chars, bigrams, seq_lens)
feats = self.decoder_model(feats)
losses = self.crf(feats, tags, masks)
losses = self.crf(feats, target, masks)

pred_dict = {}
pred_dict['seq_lens'] = seq_lens
@@ -168,5 +185,5 @@ class CWSBiLSTMCRF(BaseModel):
feats = self.decoder_model(feats)
probs = self.crf.viterbi_decode(feats, masks, get_score=False)

return {'pred_tags': probs}
return {'pred': probs}


+ 211
- 44
reproduction/chinese_word_segment/process/cws_processor.py View File

@@ -1,17 +1,18 @@

import re


from fastNLP.core.field import SeqLabelField
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.dataset import DataSet
from fastNLP.api.processor import Processor
from fastNLP.core.dataset import DataSet
from fastNLP.core.vocabulary import Vocabulary
from reproduction.chinese_word_segment.process.span_converter import SpanConverter

_SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>'

class SpeicalSpanProcessor(Processor):
# This class replaces special spans in a sentence with their corresponding content.
"""
Replace the content of field_name in the DataSet using span_converter.

"""
def __init__(self, field_name, new_added_field_name=None):
super(SpeicalSpanProcessor, self).__init__(field_name, new_added_field_name)

@@ -20,11 +21,12 @@ class SpeicalSpanProcessor(Processor):

def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
def inner_proc(ins):
sentence = ins[self.field_name]
for span_converter in self.span_converters:
sentence = span_converter.find_certain_span_and_replace(sentence)
ins[self.new_added_field_name] = sentence
return sentence
dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name)

return dataset

@@ -34,17 +36,22 @@ class SpeicalSpanProcessor(Processor):
self.span_converters.append(converter)



class CWSCharSegProcessor(Processor):
"""
Split the field field_name of the DataSet into individual characters, e.g. "复旦大学 fudan" becomes ['复', '旦', '大', '学',
' ', 'f', 'u', ...]

"""
def __init__(self, field_name, new_added_field_name):
super(CWSCharSegProcessor, self).__init__(field_name, new_added_field_name)

def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
def inner_proc(ins):
sentence = ins[self.field_name]
chars = self._split_sent_into_chars(sentence)
ins[self.new_added_field_name] = chars
return chars
dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name)

return dataset

@@ -73,6 +80,10 @@ class CWSCharSegProcessor(Processor):


class CWSTagProcessor(Processor):
"""
Generate tags for word segmentation. This class is the base class.

"""
def __init__(self, field_name, new_added_field_name=None):
super(CWSTagProcessor, self).__init__(field_name, new_added_field_name)

@@ -107,18 +118,22 @@ class CWSTagProcessor(Processor):

def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
def inner_proc(ins):
sentence = ins[self.field_name]
tag_list = self._generate_tag(sentence)
ins[self.new_added_field_name] = tag_list
dataset.set_target(**{self.new_added_field_name:True})
dataset._set_need_tensor(**{self.new_added_field_name:True})
return tag_list
dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name)
dataset.set_target(self.new_added_field_name)
return dataset

def _tags_from_word_len(self, word_len):
raise NotImplementedError

class CWSBMESTagProcessor(CWSTagProcessor):
"""
Generate the corresponding BMES tags from the field field_name of the DataSet.

"""
def __init__(self, field_name, new_added_field_name=None):
super(CWSBMESTagProcessor, self).__init__(field_name, new_added_field_name)

@@ -137,6 +152,10 @@ class CWSBMESTagProcessor(CWSTagProcessor):
return tag_list
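The BMES scheme assigns one tag per character according to the length of the word it belongs to; a small illustrative sketch of that mapping (my own, not the processor's code):

def word_len_to_bmes(word_len):
    # single-character words are 'S'; longer words are 'B' + 'M' * (n - 2) + 'E'
    if word_len == 1:
        return ['S']
    return ['B'] + ['M'] * (word_len - 2) + ['E']

# word lengths [4, 1, 1, 2] (e.g. "复旦大学 是 好 学校") expand to
# ['B', 'M', 'M', 'E', 'S', 'S', 'B', 'E']
tags = [t for n in [4, 1, 1, 2] for t in word_len_to_bmes(n)]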

class CWSSegAppTagProcessor(CWSTagProcessor):
"""
Generate the corresponding SegApp tags from the field field_name of the DataSet.

"""
def __init__(self, field_name, new_added_field_name=None):
super(CWSSegAppTagProcessor, self).__init__(field_name, new_added_field_name)

@@ -151,6 +170,10 @@ class CWSSegAppTagProcessor(CWSTagProcessor):


class BigramProcessor(Processor):
"""
Base class for bigram generation.

"""
def __init__(self, field_name, new_added_fielf_name=None):

super(BigramProcessor, self).__init__(field_name, new_added_fielf_name)
@@ -158,22 +181,31 @@ class BigramProcessor(Processor):
def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))

for ins in dataset:
def inner_proc(ins):
characters = ins[self.field_name]
bigrams = self._generate_bigram(characters)
ins[self.new_added_field_name] = bigrams
return bigrams
dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name)

return dataset


def _generate_bigram(self, characters):
pass


class Pre2Post2BigramProcessor(BigramProcessor):
def __init__(self, field_name, new_added_fielf_name=None):
"""
This bigram processor generates bigrams as follows:
an original character list l = ['a', 'b', 'c'] is padded to L = ['SOS', 'SOS', 'a', 'b', 'c', 'EOS', 'EOS'], and the bigram list is
[L[idx-2], L[idx-1], L[idx+1], L[idx+2], L[idx-2]L[idx], L[idx-1]L[idx], L[idx]L[idx+1], L[idx]L[idx+2], ....]
i.e. every character gets eight bigrams; for 'a' in the example above these are
['SOS', 'SOS', 'b', 'c', 'SOSa', 'SOSa', 'ab', 'ac']
The returned bigrams form a single flat list, in which every group of 8 consecutive elements carries the bigram information of one character.

"""
def __init__(self, field_name, new_added_field_name=None):

super(BigramProcessor, self).__init__(field_name, new_added_fielf_name)
super(BigramProcessor, self).__init__(field_name, new_added_field_name)

def _generate_bigram(self, characters):
bigrams = []
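A hedged sketch of the eight-bigrams-per-character scheme described in the docstring above; the padding tokens 'SOS'/'EOS' follow the docstring, and the real _generate_bigram may differ in detail:

def pre2post2_bigrams(chars, sos='SOS', eos='EOS'):
    # pad with two SOS/EOS tokens so every position has two left and two right neighbours
    padded = [sos, sos] + list(chars) + [eos, eos]
    bigrams = []
    for i in range(2, len(padded) - 2):
        c = padded[i]
        bigrams.extend([
            padded[i - 2], padded[i - 1], padded[i + 1], padded[i + 2],  # 4 context tokens
            padded[i - 2] + c, padded[i - 1] + c,                        # 2 left bigrams
            c + padded[i + 1], c + padded[i + 2],                        # 2 right bigrams
        ])
    return bigrams  # flat list; every 8 consecutive entries describe one character

# pre2post2_bigrams('abc')[:8] -> ['SOS', 'SOS', 'b', 'c', 'SOSa', 'SOSa', 'ab', 'ac']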
@@ -197,20 +229,107 @@ class Pre2Post2BigramProcessor(BigramProcessor):
# At this point a vocabulary needs to be built, but the following problem came up:
# (1) going through a Processor would not return a dataset in this situation, so vocabulary construction is
# implemented in a different way instead of reusing Processor.
# TODO how can the vocab-building and indexing steps be unified?

class VocabIndexerProcessor(Processor):
"""
Build a Vocabulary from the DataSet and turn the field into numeric indices. The newly generated index field is stored
in new_added_filed_name; if new_added_field_name is not provided, the original field_name is overwritten.

"""
def __init__(self, field_name, new_added_filed_name=None, min_freq=1, max_size=None,
verbose=1, is_input=True):
"""

:param field_name: the field from which the vocabulary is built and on which indexing is performed
:param new_added_filed_name: name of the generated index field; if not given, field_name is overwritten.
:param min_freq: minimum number of occurrences a word needs in order to enter the Vocabulary.
:param max_size: maximum number of words allowed in the Vocabulary.
:param verbose: 0 prints nothing; 1 prints progress information.
:param bool is_input:
"""
super(VocabIndexerProcessor, self).__init__(field_name, new_added_filed_name)
self.min_freq = min_freq
self.max_size = max_size

self.verbose = verbose
self.is_input = is_input

def construct_vocab(self, *datasets):
"""
Build the vocabulary from the DataSet(s) passed in.

:param datasets: data of type DataSet, used to build the vocabulary
:return:
"""
self.vocab = Vocabulary(min_freq=self.min_freq, max_size=self.max_size)
for dataset in datasets:
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
dataset.apply(lambda ins: self.vocab.update(ins[self.field_name]))
self.vocab.build_vocab()
if self.verbose:
print("Vocabulary Constructed, has {} items.".format(len(self.vocab)))

def process(self, *datasets, only_index_dataset=None):
"""
If no Vocabulary has been built yet, build one from the DataSets in datasets; otherwise reuse the existing vocabulary.
Once the vocabulary is available, both datasets and only_index_dataset are indexed.

:param datasets: data of type DataSet
:param only_index_dataset: DataSet, or list of DataSet. These are only indexed; they are never used to build the vocabulary.
:return:
"""
if len(datasets)==0 and not hasattr(self,'vocab'):
raise RuntimeError("You have to construct vocabulary first. Or you have to pass datasets to construct it.")
if not hasattr(self, 'vocab'):
self.construct_vocab(*datasets)
else:
if self.verbose:
print("Using constructed vocabulary with {} items.".format(len(self.vocab)))
to_index_datasets = []
if len(datasets)!=0:
for dataset in datasets:
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
to_index_datasets.append(dataset)

if not (only_index_dataset is None):
if isinstance(only_index_dataset, list):
for dataset in only_index_dataset:
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
to_index_datasets.append(dataset)
elif isinstance(only_index_dataset, DataSet):
to_index_datasets.append(only_index_dataset)
else:
raise TypeError('Only DataSet or list of DataSet is allowed, not {}.'.format(type(only_index_dataset)))

for dataset in to_index_datasets:
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
dataset.apply(lambda ins: [self.vocab.to_index(token) for token in ins[self.field_name]],
new_field_name=self.new_added_field_name, is_input=self.is_input)
# Return a single dataset so that inference stays consistent with the other processors.
if len(to_index_datasets) == 1:
return to_index_datasets[0]

def set_vocab(self, vocab):
assert isinstance(vocab, Vocabulary), "Only fastNLP.core.Vocabulary is allowed, not {}.".format(type(vocab))
self.vocab = vocab

def delete_vocab(self):
del self.vocab

def get_vocab_size(self):
return len(self.vocab)
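A hedged usage sketch of the new VocabIndexerProcessor; the field names are hypothetical and the DataSet calls assume the add_field / field-array API used elsewhere in this commit:

from fastNLP.core.dataset import DataSet
from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor

ds = DataSet()
ds.add_field('chars_lst', [['复', '旦'], ['大', '学', '好']])

proc = VocabIndexerProcessor('chars_lst', new_added_filed_name='chars', min_freq=1)
ds = proc.process(ds)            # builds the vocabulary from ds, then indexes it
print(proc.get_vocab_size())     # number of vocabulary entries
print(ds['chars'].content)       # the character lists mapped to integer ids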

class VocabProcessor(Processor):
def __init__(self, field_name, min_count=1, max_vocab_size=None):
def __init__(self, field_name, min_freq=1, max_size=None):

super(VocabProcessor, self).__init__(field_name, None)
self.vocab = Vocabulary(min_freq=min_count, max_size=max_vocab_size)
self.vocab = Vocabulary(min_freq=min_freq, max_size=max_size)

def process(self, *datasets):
for dataset in datasets:
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
tokens = ins[self.field_name]
self.vocab.update(tokens)

dataset.apply(lambda ins: self.vocab.update(ins[self.field_name]))

def get_vocab(self):
self.vocab.build_vocab()
@@ -220,19 +339,6 @@ class VocabProcessor(Processor):
return len(self.vocab)


class SeqLenProcessor(Processor):
def __init__(self, field_name, new_added_field_name='seq_lens'):

super(SeqLenProcessor, self).__init__(field_name, new_added_field_name)

def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
length = len(ins[self.field_name])
ins[self.new_added_field_name] = length
dataset._set_need_tensor(**{self.new_added_field_name:True})
return dataset

class SegApp2OutputProcessor(Processor):
def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output'):
super(SegApp2OutputProcessor, self).__init__(None, None)
@@ -258,7 +364,32 @@ class SegApp2OutputProcessor(Processor):


class BMES2OutputProcessor(Processor):
def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output'):
"""
Decode the predicted tags under the BMES scheme. Since illegal tag pairs such as "BS" may occur, the table below is used
for correction; cur_B means the current tag is B, next_B means the following tag is B. Thus cur_B=S rewrites a current tag
predicted as B to S, and next_M=B rewrites a following tag predicted as M to B.
|       | next_B  | next_M   | next_E   | next_S  | end     |
|:-----:|:-------:|:--------:|:--------:|:-------:|:-------:|
| start | legal   | next_M=B | next_E=S | legal   | -       |
| cur_B | cur_B=S | legal    | legal    | cur_B=S | cur_B=S |
| cur_M | cur_M=E | legal    | legal    | cur_M=E | cur_M=E |
| cur_E | legal   | next_M=B | next_E=S | legal   | legal   |
| cur_S | legal   | next_M=B | next_E=S | legal   | legal   |
For example:
the prediction BSEMS would be corrected to SSSSS.

"""
def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output',
b_idx = 0, m_idx = 1, e_idx = 2, s_idx = 3):
"""

:param chars_field_name: field holding the characters
:param tag_field_name: field holding the predicted tags
:param new_added_field_name: field in which the converted output is stored
:param b_idx: int, tag idx of the Begin label.
:param m_idx: int, tag idx of the Middle label.
:param e_idx: int, tag idx of the End label.
:param s_idx: int, tag idx of the Single label
"""
super(BMES2OutputProcessor, self).__init__(None, None)

self.chars_field_name = chars_field_name
@@ -266,19 +397,55 @@ class BMES2OutputProcessor(Processor):

self.new_added_field_name = new_added_field_name

self.b_idx = b_idx
self.m_idx = m_idx
self.e_idx = e_idx
self.s_idx = s_idx
# Rebuild the matrix described in __init__
self._valida_matrix = {
-1: [(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1)], # magic start idx
self.b_idx:[(0, self.s_idx), (-1, -1), (-1, -1), (0, self.s_idx), (0, self.s_idx)],
self.m_idx:[(0, self.e_idx), (-1, -1), (-1, -1), (0, self.e_idx), (0, self.e_idx)],
self.e_idx:[(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)],
self.s_idx:[(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)],
}

def _validate_tags(self, tags):
"""
Given a list of tags, return a legal tag sequence.

:param tags: Tensor, shape: (seq_len, )
:return: the list with tags corrected so that the sequence is legal
"""
assert len(tags)!=0
padded_tags = [-1, *tags, -1]
for idx in range(len(padded_tags)-1):
cur_tag = padded_tags[idx]
if cur_tag not in self._valida_matrix:
cur_tag = self.s_idx
if padded_tags[idx+1] not in self._valida_matrix:
padded_tags[idx+1] = self.s_idx
next_tag = padded_tags[idx+1]
shift_tag = self._valida_matrix[cur_tag][next_tag]
if shift_tag[0]!=-1:
padded_tags[idx+shift_tag[0]] = shift_tag[1]

return padded_tags[1:-1]

def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
def inner_proc(ins):
pred_tags = ins[self.tag_field_name]
pred_tags = self._validate_tags(pred_tags)
chars = ins[self.chars_field_name]
words = []
start_idx = 0
for idx, tag in enumerate(pred_tags):
if tag==3:
# Mapping back to the original text is not handled here yet.
if tag==self.s_idx:
words.extend(chars[start_idx:idx+1])
start_idx = idx + 1
elif tag==2:
elif tag==self.e_idx:
words.append(''.join(chars[start_idx:idx+1]))
start_idx = idx + 1
ins[self.new_added_field_name] = ' '.join(words)
return ' '.join(words)
dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name)
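A hedged sketch of the final BMES-to-words step performed in inner_proc above, simplified by assuming the tags have already been made legal by _validate_tags:

def bmes_to_words(chars, tags, e_idx=2, s_idx=3):
    # a word ends at an 'E' tag; an 'S' tag is a single-character word
    words, start = [], 0
    for i, tag in enumerate(tags):
        if tag == s_idx:
            words.append(chars[i])
            start = i + 1
        elif tag == e_idx:
            words.append(''.join(chars[start:i + 1]))
            start = i + 1
    return ' '.join(words)

# bmes_to_words(list('今天天气好'), [0, 2, 0, 2, 3]) -> '今天 天气 好'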

+ 229
- 0
reproduction/chinese_word_segment/train_context.py View File

@@ -0,0 +1,229 @@

from fastNLP.api.pipeline import Pipeline
from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor
from fastNLP.api.processor import SeqLenProcessor
from reproduction.chinese_word_segment.process.cws_processor import CWSCharSegProcessor
from reproduction.chinese_word_segment.process.cws_processor import CWSBMESTagProcessor
from reproduction.chinese_word_segment.process.cws_processor import Pre2Post2BigramProcessor
from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor


from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader
from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMCRF

from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1

ds_name = 'msr'

tr_filename = '/home/hyan/ctb3/train.conllx'
dev_filename = '/home/hyan/ctb3/dev.conllx'


reader = ConllCWSReader()

tr_dataset = reader.load(tr_filename, cut_long_sent=True)
dev_dataset = reader.load(dev_filename)

print("Train {}. Dev: {}".format(len(tr_dataset), len(dev_dataset)))

# 1. Prepare the processors
fs2hs_proc = FullSpaceToHalfSpaceProcessor('raw_sentence')

char_proc = CWSCharSegProcessor('raw_sentence', 'chars_lst')
tag_proc = CWSBMESTagProcessor('raw_sentence', 'target')

bigram_proc = Pre2Post2BigramProcessor('chars_lst', 'bigrams_lst')

char_vocab_proc = VocabIndexerProcessor('chars_lst', new_added_filed_name='chars')
bigram_vocab_proc = VocabIndexerProcessor('bigrams_lst', new_added_filed_name='bigrams', min_freq=4)

seq_len_proc = SeqLenProcessor('chars')

# 2. Apply the processors
fs2hs_proc(tr_dataset)

char_proc(tr_dataset)
tag_proc(tr_dataset)
bigram_proc(tr_dataset)

char_vocab_proc(tr_dataset)
bigram_vocab_proc(tr_dataset)
seq_len_proc(tr_dataset)

# 2.1 Process dev_dataset
fs2hs_proc(dev_dataset)

char_proc(dev_dataset)
tag_proc(dev_dataset)
bigram_proc(dev_dataset)

char_vocab_proc(dev_dataset)
bigram_vocab_proc(dev_dataset)
seq_len_proc(dev_dataset)

dev_dataset.set_input('chars', 'bigrams', 'target')
tr_dataset.set_input('chars', 'bigrams', 'target')
dev_dataset.set_target('seq_lens')
tr_dataset.set_target('seq_lens')

print("Finish preparing data.")


# 3. The datasets are now ready for training
# TODO how should pretrained embeddings be handled?

import torch
from torch import optim


tag_size = tag_proc.tag_size

cws_model = CWSBiLSTMCRF(char_vocab_proc.get_vocab_size(), embed_dim=100,
bigram_vocab_num=bigram_vocab_proc.get_vocab_size(),
bigram_embed_dim=100, num_bigram_per_char=8,
hidden_size=200, bidirectional=True, embed_drop_p=0.2,
num_layers=1, tag_size=tag_size)
cws_model.cuda()

num_epochs = 5
optimizer = optim.Adagrad(cws_model.parameters(), lr=0.02)

from fastNLP.core.trainer import Trainer
from fastNLP.core.sampler import BucketSampler
from fastNLP.core.metrics import BMESF1PreRecMetric

metric = BMESF1PreRecMetric(target='tags')
trainer = Trainer(train_data=tr_dataset, model=cws_model, loss=None, metrics=metric, n_epochs=3,
batch_size=32, print_every=50, validate_every=-1, dev_data=dev_dataset, save_path=None,
optimizer=optimizer, check_code_level=0, metric_key='f', sampler=BucketSampler(), use_tqdm=True)

trainer.train()
exit(0)

#
# print_every = 50
# batch_size = 32
# tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)
# dev_batcher = Batch(dev_dataset, batch_size, SequentialSampler(), use_cuda=False)
# num_batch_per_epoch = len(tr_dataset) // batch_size
# best_f1 = 0
# best_epoch = 0
# for num_epoch in range(num_epochs):
# print('X' * 10 + ' Epoch: {}/{} '.format(num_epoch + 1, num_epochs) + 'X' * 10)
# sys.stdout.flush()
# avg_loss = 0
# with tqdm(total=num_batch_per_epoch, leave=True) as pbar:
# pbar.set_description_str('Epoch:%d' % (num_epoch + 1))
# cws_model.train()
# for batch_idx, (batch_x, batch_y) in enumerate(tr_batcher, 1):
# optimizer.zero_grad()
#
# tags = batch_y['tags'].long()
# pred_dict = cws_model(**batch_x, tags=tags) # B x L x tag_size
#
# seq_lens = pred_dict['seq_lens']
# masks = seq_lens_to_mask(seq_lens).float()
# tags = tags.to(seq_lens.device)
#
# loss = pred_dict['loss']
#
# # loss = torch.sum(loss_fn(pred_dict['pred_probs'].view(-1, tag_size),
# # tags.view(-1)) * masks.view(-1)) / torch.sum(masks)
# # loss = torch.mean(F.cross_entropy(probs.view(-1, 2), tags.view(-1)) * masks.float())
#
# avg_loss += loss.item()
#
# loss.backward()
# for group in optimizer.param_groups:
# for param in group['params']:
# param.grad.clamp_(-5, 5)
#
# optimizer.step()
#
# if batch_idx % print_every == 0:
# pbar.set_postfix_str('batch=%d, avg_loss=%.5f' % (batch_idx, avg_loss / print_every))
# avg_loss = 0
# pbar.update(print_every)
# tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)
# # validation set
# pre, rec, f1 = calculate_pre_rec_f1(cws_model, dev_batcher, type='bmes')
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1*100,
# pre*100,
# rec*100))
# if best_f1<f1:
# best_f1 = f1
# # cache the best parameters; they may be saved later
# best_state_dict = {
# key:value.clone() for key, value in
# cws_model.state_dict().items()
# }
# best_epoch = num_epoch
#
# cws_model.load_state_dict(best_state_dict)

# 4. Assemble the objects that need to be saved
pp = Pipeline()
pp.add_processor(fs2hs_proc)
# pp.add_processor(sp_proc)
pp.add_processor(char_proc)
pp.add_processor(tag_proc)
pp.add_processor(bigram_proc)
pp.add_processor(char_vocab_proc)
pp.add_processor(bigram_vocab_proc)
pp.add_processor(seq_len_proc)

# te_filename = '/hdd/fudanNLP/CWS/CWS_semiCRF/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name)
te_filename = '/home/hyan/ctb3/test.conllx'
te_dataset = reader.load(te_filename)
pp(te_dataset)

from fastNLP.core.tester import Tester

tester = Tester(data=te_dataset, model=cws_model, metrics=metric, batch_size=64, use_cuda=False,
verbose=1)
#
# batch_size = 64
# te_batcher = Batch(te_dataset, batch_size, SequentialSampler(), use_cuda=False)
# pre, rec, f1 = calculate_pre_rec_f1(cws_model, te_batcher, type='bmes')
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100,
# pre * 100,
# rec * 100))

# TODO it seems the test pipeline and the infer pipeline need to be distinguished here

test_context_dict = {'pipeline': pp,
'model': cws_model}
torch.save(test_context_dict, 'models/test_context_crf.pkl')


# 5. The pipeline for dev/inference
# 4. Assemble the objects that need to be saved

from fastNLP.api.processor import ModelProcessor
from reproduction.chinese_word_segment.process.cws_processor import BMES2OutputProcessor

model_proc = ModelProcessor(cws_model)
output_proc = BMES2OutputProcessor()

pp = Pipeline()
pp.add_processor(fs2hs_proc)
# pp.add_processor(sp_proc)
pp.add_processor(char_proc)
pp.add_processor(bigram_proc)
pp.add_processor(char_vocab_proc)
pp.add_processor(bigram_vocab_proc)
pp.add_processor(seq_len_proc)

pp.add_processor(model_proc)
pp.add_processor(output_proc)


# TODO it seems the test pipeline and the infer pipeline need to be distinguished here

infer_context_dict = {'pipeline': pp}
# torch.save(infer_context_dict, 'models/cws_crf.pkl')


# TODO how to map the output back to the original text still needs consideration:
# 1. the case where special tags need not be replaced back
# 2. the case where special tags must be replaced back
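For completeness, a hedged sketch of how the saved inference pipeline could be reloaded and applied to raw sentences, assuming the commented-out torch.save above had been run; the input sentence is only an example:

import torch
from fastNLP.core.dataset import DataSet

context = torch.load('models/cws_crf.pkl', map_location='cpu')  # path from the commented-out save above
infer_pp = context['pipeline']

ds = DataSet()
ds.add_field('raw_sentence', ['编者按:7月12日,这款飞行从外型上看很酷。'])
infer_pp(ds)
print(ds['output'].content)  # space-separated segmentation produced by BMES2OutputProcessor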

reproduction/pos_tag_model/process/pos_processor.py → reproduction/pos_tag_model/pos_processor.py View File

@@ -4,6 +4,7 @@ from collections import Counter
from fastNLP.api.processor import Processor
from fastNLP.core.dataset import DataSet


class CombineWordAndPosProcessor(Processor):
def __init__(self, word_field_name, pos_field_name):
super(CombineWordAndPosProcessor, self).__init__(None, None)
@@ -60,6 +61,7 @@ class CombineWordAndPosProcessor(Processor):

return dataset


class PosOutputStrProcessor(Processor):
def __init__(self, word_field_name, pos_field_name):
super(PosOutputStrProcessor, self).__init__(None, None)

reproduction/pos_tag_model/pos_io/pos_reader.py → reproduction/pos_tag_model/pos_reader.py View File

@@ -24,8 +24,8 @@ def cut_long_sentence(sent, max_sample_length=200):
return cutted_sentence


class ConlluPOSReader(object):
# The returned Dataset contains two fields: words (a list of lists; the inner list holds characters) and tag (a list of str, each str being a BMES tag).
class ConllPOSReader(object):
# The returned Dataset contains two fields: words (a list of lists; the inner list holds characters) and tag (a list of str, each str being a BIO tag).
def __init__(self):
pass

@@ -70,6 +70,70 @@ class ConlluPOSReader(object):

return ds



class ZhConllPOSReader(object):
# Reader for Chinese CoNLL-format data
def __init__(self):
pass

def load(self, path):
"""
The returned DataSet contains the following fields:
words: list of str,
tag: list of str, with BMES tags added; e.g. the original sequence ['VP', 'NN', 'NN', ..] is treated as ["S-VP", "B-NN", "M-NN", ..]
The input is assumed to be in CoNLL format: sentences are separated by blank lines and every line has seven columns, e.g.
1 编者按 编者按 NN O 11 nmod:topic
2 : : PU O 11 punct
3 7月 7月 NT DATE 4 compound:nn
4 12日 12日 NT DATE 11 nmod:tmod
5 , , PU O 11 punct

1 这 这 DT O 3 det
2 款 款 M O 1 mark:clf
3 飞行 飞行 NN O 8 nsubj
4 从 从 P O 5 case
5 外型 外型 NN O 8 nmod:prep
"""
datalist = []
with open(path, 'r', encoding='utf-8') as f:
sample = []
for line in f:
if line.startswith('\n'):
datalist.append(sample)
sample = []
elif line.startswith('#'):
continue
else:
sample.append(line.split('\t'))
if len(sample) > 0:
datalist.append(sample)

ds = DataSet()
for sample in datalist:
# print(sample)
res = self.get_one(sample)
if res is None:
continue
char_seq = []
pos_seq = []
for word, tag in zip(res[0], res[1]):
char_seq.extend(list(word))
if len(word)==1:
pos_seq.append('S-{}'.format(tag))
elif len(word)>1:
pos_seq.append('B-{}'.format(tag))
for _ in range(len(word)-2):
pos_seq.append('M-{}'.format(tag))
pos_seq.append('E-{}'.format(tag))
else:
raise ValueError("Zero length of word detected.")

ds.append(Instance(words=char_seq,
tag=pos_seq))

return ds

def get_one(self, sample):
if len(sample)==0:
return None
@@ -84,6 +148,6 @@ class ConlluPOSReader(object):
return text, pos_tags

if __name__ == '__main__':
reader = ConlluPOSReader()
reader = ZhConllPOSReader()
d = reader.load('/home/hyan/train.conllx')
print('reader')
print(d)
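A worked illustration (my own, mirroring the loop in ZhConllPOSReader.load above) of how word-level POS tags are expanded to character-level BMES-prefixed tags:

def expand_pos_tags(word_pos_pairs):
    # one character-level tag per character: S- for single-character words, otherwise B-/M-/E-
    char_seq, pos_seq = [], []
    for word, tag in word_pos_pairs:
        char_seq.extend(list(word))
        if len(word) == 1:
            pos_seq.append('S-{}'.format(tag))
        else:
            pos_seq.append('B-{}'.format(tag))
            pos_seq.extend('M-{}'.format(tag) for _ in range(len(word) - 2))
            pos_seq.append('E-{}'.format(tag))
    return char_seq, pos_seq

# expand_pos_tags([('编者按', 'NN'), (':', 'PU')])
# -> (['编', '者', '按', ':'], ['B-NN', 'M-NN', 'E-NN', 'S-PU'])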

+ 82
- 0
reproduction/pos_tag_model/train_pos_tag.py View File

@@ -0,0 +1,82 @@
import os
import sys

import torch

# in order to run fastNLP without installation
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))


from fastNLP.api.pipeline import Pipeline
from fastNLP.api.processor import SeqLenProcessor
from fastNLP.core.metrics import SpanFPreRecMetric
from fastNLP.core.trainer import Trainer
from fastNLP.io.config_io import ConfigLoader, ConfigSection
from fastNLP.models.sequence_modeling import AdvSeqLabel
from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor
from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader
from fastNLP.api.processor import ModelProcessor, Index2WordProcessor

cfgfile = './pos_tag.cfg'
pickle_path = "save"


def train():
# load config
train_param = ConfigSection()
model_param = ConfigSection()
ConfigLoader().load_config(cfgfile, {"train": train_param, "model": model_param})
print("config loaded")

# Data Loader
dataset = ZhConllPOSReader().load("/home/hyan/train.conllx")
print(dataset)
print("dataset transformed")

dataset.rename_field("tag", "truth")

vocab_proc = VocabIndexerProcessor("words", new_added_filed_name="word_seq")
tag_proc = VocabIndexerProcessor("truth")
seq_len_proc = SeqLenProcessor(field_name="word_seq", new_added_field_name="word_seq_origin_len", is_input=True)

vocab_proc(dataset)
tag_proc(dataset)
seq_len_proc(dataset)

dataset.set_input("word_seq", "word_seq_origin_len", "truth")
dataset.set_target("truth", "word_seq_origin_len")

print("processors defined")

# dataset.set_is_target(tag_ids=True)
model_param["vocab_size"] = vocab_proc.get_vocab_size()
model_param["num_classes"] = tag_proc.get_vocab_size()
print("vocab_size={} num_classes={}".format(model_param["vocab_size"], model_param["num_classes"]))

# define a model
model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word)

# call trainer to train
trainer = Trainer(dataset, model, loss=None, metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict",
target="truth",
seq_lens="word_seq_origin_len"),
dev_data=dataset, metric_key="f",
use_tqdm=False, use_cuda=True, print_every=20, n_epochs=1, save_path="./save")
trainer.train()

# save model & pipeline
model_proc = ModelProcessor(model, seq_len_field_name="word_seq_origin_len")
id2tag = Index2WordProcessor(tag_proc.vocab, "predict", "tag")

pp = Pipeline([vocab_proc, seq_len_proc, model_proc, id2tag])
save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab}
torch.save(save_dict, "model_pp.pkl")
print("pipeline saved")


def infer():
pass


if __name__ == "__main__":
train()

+ 229
- 0
test/core/test_metrics.py View File

@@ -4,6 +4,7 @@ import numpy as np
import torch

from fastNLP.core.metrics import AccuracyMetric
from fastNLP.core.metrics import BMESF1PreRecMetric
from fastNLP.core.metrics import pred_topk, accuracy_topk


@@ -132,6 +133,234 @@ class TestAccuracyMetric(unittest.TestCase):
return
self.assertTrue(True, False), "No exception catches."

class SpanF1PreRecMetric(unittest.TestCase):
def test_case1(self):
from fastNLP.core.metrics import bmes_tag_to_spans
from fastNLP.core.metrics import bio_tag_to_spans

bmes_lst = ['M-8', 'S-2', 'S-0', 'B-9', 'B-6', 'E-5', 'B-7', 'S-2', 'E-7', 'S-8']
bio_lst = ['O-8', 'O-2', 'B-0', 'O-9', 'I-6', 'I-5', 'I-7', 'I-2', 'I-7', 'O-8']
expect_bmes_res = set()
expect_bmes_res.update([('8', (0, 0)), ('2', (1, 1)), ('0', (2, 2)), ('9', (3, 3)), ('6', (4, 4)),
('5', (5, 5)), ('7', (6, 6)), ('2', (7, 7)), ('7', (8, 8)), ('8', (9, 9))])
expect_bio_res = set()
expect_bio_res.update([('7', (8, 8)), ('0', (2, 2)), ('2', (7, 7)), ('5', (5, 5)),
('6', (4, 4)), ('7', (6, 6))])
self.assertSetEqual(expect_bmes_res,set(bmes_tag_to_spans(bmes_lst)))
self.assertSetEqual(expect_bio_res, set(bio_tag_to_spans(bio_lst)))
# Verified against the corresponding allennlp functions; since the tests must not depend on allennlp, only the fixed example above is checked here.
# from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans
# from allennlp.data.dataset_readers.dataset_utils import bmes_tags_to_spans as allen_bmes_tags_to_spans
# for i in range(1000):
# strs = list(map(str, np.random.randint(100, size=1000)))
# bmes = list('bmes'.upper())
# bmes_strs = [str_ + '-' + tag for tag, str_ in zip(strs, np.random.choice(bmes, size=len(strs)))]
# bio = list('bio'.upper())
# bio_strs = [str_ + '-' + tag for tag, str_ in zip(strs, np.random.choice(bio, size=len(strs)))]
# self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs)))
# self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs)))

def test_case2(self):
# test tags without labels
from fastNLP.core.metrics import bmes_tag_to_spans
from fastNLP.core.metrics import bio_tag_to_spans

bmes_lst = ['B', 'E', 'B', 'S', 'B', 'M', 'E', 'M', 'B', 'E']
bio_lst = ['I', 'B', 'O', 'O', 'I', 'O', 'I', 'B', 'O', 'O']
expect_bmes_res = set()
expect_bmes_res.update([('', (0, 1)), ('', (2, 2)), ('', (3, 3)), ('', (4, 6)), ('', (7, 7)), ('', (8, 9))])
expect_bio_res = set()
expect_bio_res.update([('', (7, 7)), ('', (6, 6)), ('', (4, 4)), ('', (0, 0)), ('', (1, 1))])
self.assertSetEqual(expect_bmes_res,set(bmes_tag_to_spans(bmes_lst)))
self.assertSetEqual(expect_bio_res, set(bio_tag_to_spans(bio_lst)))
# Verified against the corresponding allennlp functions; since the tests must not depend on allennlp, only the fixed example above is checked here.
# from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans
# from allennlp.data.dataset_readers.dataset_utils import bmes_tags_to_spans as allen_bmes_tags_to_spans
# for i in range(1000):
# bmes = list('bmes'.upper())
# bmes_strs = np.random.choice(bmes, size=1000)
# bio = list('bio'.upper())
# bio_strs = np.random.choice(bio, size=100)
# self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs)))
# self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs)))
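For readers without the metrics source at hand, a hedged re-implementation of the BIO span extraction that the two fixed examples above exercise; it reproduces those expected sets but is not necessarily identical to fastNLP's bio_tag_to_spans:

def bio_spans_sketch(tags):
    # tags look like 'B-label', 'I-label' or 'O[-label]'; spans are (label, (start, end))
    spans, prev = [], ('O', None)
    for idx, tag in enumerate(tags):
        kind, _, label = tag.partition('-')
        kind = kind.upper()
        if kind == 'B':
            spans.append((label, [idx, idx]))
        elif kind == 'I':
            if prev[0] in ('B', 'I') and prev[1] == label:
                spans[-1][1][1] = idx                  # extend the running span
            else:
                spans.append((label, [idx, idx]))      # an orphan I opens a new span
        prev = (kind, label)
    return [(label, (s, e)) for label, (s, e) in spans]

assert set(bio_spans_sketch(['O-8', 'O-2', 'B-0', 'O-9', 'I-6', 'I-5', 'I-7', 'I-2', 'I-7', 'O-8'])) == \
    {('7', (8, 8)), ('0', (2, 2)), ('2', (7, 7)), ('5', (5, 5)), ('6', (4, 4)), ('7', (6, 6))}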

def test_case3(self):
from fastNLP.core.vocabulary import Vocabulary
from collections import Counter
from fastNLP.core.metrics import SpanFPreRecMetric
# check against allennlp that the f metric is computed correctly
#
def generate_allen_tags(encoding_type, number_labels=4):
vocab = {}
for i in range(number_labels):
label = str(i)
for tag in encoding_type:
if tag == 'O':
if tag not in vocab:
vocab['O'] = len(vocab) + 1
continue
vocab['{}-{}'.format(tag, label)] = len(vocab) + 1  # this value actually stands for the tag's count
return vocab

number_labels = 4
# bio tag
fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None)
fastnlp_bio_vocab.word_count = Counter(generate_allen_tags('BIO', number_labels))
fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False)
bio_sequence = torch.FloatTensor(
[[[-0.9543, -1.4357, -0.2365, 0.2438, 1.0312, -1.4302, 0.3011,
0.0470, 0.0971],
[-0.6638, -0.7116, -1.9804, 0.2787, -0.2732, -0.9501, -1.4523,
0.7987, -0.3970],
[0.2939, 0.8132, -0.0903, -2.8296, 0.2080, -0.9823, -0.1898,
0.6880, 1.4348],
[-0.1886, 0.0067, -0.6862, -0.4635, 2.2776, 0.0710, -1.6793,
-1.6876, -0.8917],
[-0.7663, 0.6377, 0.8669, 0.1237, 1.7628, 0.0313, -1.0824,
1.4217, 0.2622]],

[[0.1529, 0.7474, -0.9037, 1.5287, 0.2771, 0.2223, 0.8136,
1.3592, -0.8973],
[0.4515, -0.5235, 0.3265, -1.1947, 0.8308, 1.8754, -0.4887,
-0.4025, -0.3417],
[-0.7855, 0.1615, -0.1272, -1.9289, -0.5181, 1.9742, -0.9698,
0.2861, -0.3966],
[-0.8291, -0.8823, -1.1496, 0.2164, 1.3390, -0.3964, -0.5275,
0.0213, 1.4777],
[-1.1299, 0.0627, -0.1358, -1.5951, 0.4484, -0.6081, -1.9566,
1.3024, 0.2001]]]
)
bio_target = torch.LongTensor([[5., 0., 3., 3., 3.],
[5., 6., 8., 6., 0.]])
fastnlp_bio_metric({'pred': bio_sequence, 'seq_lens': torch.LongTensor([5, 5])}, {'target': bio_target})
expect_bio_res = {'pre-1': 0.24999999999999373, 'rec-1': 0.499999999999975, 'f-1': 0.33333333333327775,
'pre-2': 0.0, 'rec-2': 0.0, 'f-2': 0.0, 'pre-3': 0.0, 'rec-3': 0.0, 'f-3': 0.0, 'pre-0': 0.0,
'rec-0': 0.0, 'f-0': 0.0, 'pre': 0.12499999999999845, 'rec': 0.12499999999999845,
'f': 0.12499999999994846}
self.assertDictEqual(expect_bio_res, fastnlp_bio_metric.get_metric())

# bmes tag
bmes_sequence = torch.FloatTensor(
[[[0.6536, -0.7179, 0.6579, 1.2503, 0.4176, 0.6696, 0.2352,
-0.4085, 0.4084, -0.4185, 1.4172, -0.9162, -0.2679, 0.3332,
-0.3505, -0.6002],
[0.3238, -1.2378, -1.3304, -0.4903, 1.4518, -0.1868, -0.7641,
1.6199, -0.8877, 0.1449, 0.8995, -0.5810, 0.1041, 0.1002,
0.4439, 0.2514],
[-0.8362, 2.9526, 0.8008, 0.1193, 1.0488, 0.6670, 1.1696,
-1.1006, -0.8540, -0.1600, -0.9519, -0.2749, -0.4948, -1.4753,
0.5802, -0.0516],
[-0.8383, -1.7292, -1.4079, -1.5023, 0.5383, 0.6653, 0.3121,
4.1249, -0.4173, -0.2043, 1.7755, 1.1110, -1.7069, -0.0390,
-0.9242, -0.0333],
[0.9088, -0.4955, -0.5076, 0.3732, 0.0283, -0.0263, -1.0393,
0.7734, 1.0968, 0.4132, -1.3647, -0.5762, 0.6678, 0.8809,
-0.3779, -0.3195]],

[[-0.4638, -0.5939, -0.1052, -0.5573, 0.4600, -1.3484, 0.1753,
0.0685, 0.3663, -0.6789, 0.0097, 1.0327, -0.0212, -0.9957,
-0.1103, 0.4417],
[-0.2903, 0.9205, -1.5758, -1.0421, 0.2921, -0.2142, -0.3049,
-0.0879, -0.4412, -1.3195, -0.0657, -0.2986, 0.7214, 0.0631,
-0.6386, 0.2797],
[0.6440, -0.3748, 1.2912, -0.0170, 0.7447, 1.4075, -0.4947,
0.4123, -0.8447, -0.5502, 0.3520, -0.2832, 0.5019, -0.1522,
1.1237, -1.5385],
[0.2839, -0.7649, 0.9067, -0.1163, -1.3789, 0.2571, -1.3977,
-0.3680, -0.8902, -0.6983, -1.1583, 1.2779, 0.2197, 0.1376,
-0.0591, -0.2461],
[-0.2977, -1.8564, -0.5347, 1.0011, -1.1260, 0.4252, -2.0097,
2.6973, -0.8308, -1.4939, 0.9865, -0.3935, 0.2743, 0.1142,
-0.7344, -1.2046]]]
)
bmes_target = torch.LongTensor([[ 9., 6., 1., 9., 15.],
[ 6., 15., 6., 15., 5.]])

fastnlp_bmes_vocab = Vocabulary(unknown=None, padding=None)
fastnlp_bmes_vocab.word_count = Counter(generate_allen_tags('BMES', number_labels))
fastnlp_bmes_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bmes_vocab, only_gross=False, encoding_type='bmes')
fastnlp_bmes_metric({'pred': bmes_sequence, 'seq_lens': torch.LongTensor([20, 20])}, {'target': bmes_target})

expect_bmes_res = {'f-3': 0.6666666666665778, 'pre-3': 0.499999999999975, 'rec-3': 0.9999999999999001,
'f-0': 0.0, 'pre-0': 0.0, 'rec-0': 0.0, 'f-1': 0.33333333333327775,
'pre-1': 0.24999999999999373, 'rec-1': 0.499999999999975, 'f-2': 0.7499999999999314,
'pre-2': 0.7499999999999812, 'rec-2': 0.7499999999999812, 'f': 0.49999999999994504,
'pre': 0.499999999999995, 'rec': 0.499999999999995}

self.assertDictEqual(fastnlp_bmes_metric.get_metric(), expect_bmes_res)

# Verified against allennlp; since allennlp cannot be a dependency, the code below is commented out.
# from allennlp.data.vocabulary import Vocabulary as allen_Vocabulary
# from allennlp.training.metrics import SpanBasedF1Measure
# allen_bio_vocab = allen_Vocabulary({"tags": generate_allen_tags('BIO', number_labels)},
# non_padded_namespaces=['tags'])
# allen_bio_metric = SpanBasedF1Measure(allen_bio_vocab, 'tags')
# bio_sequence = torch.randn(size=(2, 20, 2 * number_labels + 1))
# bio_target = torch.randint(2 * number_labels + 1, size=(2, 20))
# allen_bio_metric(bio_sequence, bio_target, torch.ones(2, 20))
# fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None)
# fastnlp_bio_vocab.word_count = Counter(generate_allen_tags('BIO', number_labels))
# fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False)
#
# def convert_allen_res_to_fastnlp_res(metric_result):
# allen_result = {}
# key_map = {'f1-measure-overall': "f", "recall-overall": "rec", "precision-overall": "pre"}
# for key, value in metric_result.items():
# if key in key_map:
# key = key_map[key]
# else:
# label = key.split('-')[-1]
# if key.startswith('f1'):
# key = 'f-{}'.format(label)
# else:
# key = '{}-{}'.format(key[:3], label)
# allen_result[key] = value
# return allen_result
#
# # print(convert_allen_res_to_fastnlp_res(allen_bio_metric.get_metric()))
# # print(fastnlp_bio_metric.get_metric())
# self.assertDictEqual(convert_allen_res_to_fastnlp_res(allen_bio_metric.get_metric()),
# fastnlp_bio_metric.get_metric())
#
# allen_bmes_vocab = allen_Vocabulary({"tags": generate_allen_tags('BMES', number_labels)})
# allen_bmes_metric = SpanBasedF1Measure(allen_bmes_vocab, 'tags', label_encoding='BMES')
# fastnlp_bmes_vocab = Vocabulary(unknown=None, padding=None)
# fastnlp_bmes_vocab.word_count = Counter(generate_allen_tags('BMES', number_labels))
# fastnlp_bmes_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bmes_vocab, only_gross=False, encoding_type='bmes')
# bmes_sequence = torch.randn(size=(2, 20, 4 * number_labels))
# bmes_target = torch.randint(4 * number_labels, size=(2, 20))
# allen_bmes_metric(bmes_sequence, bmes_target, torch.ones(2, 20))
# fastnlp_bmes_metric({'pred': bmes_sequence, 'seq_lens': torch.LongTensor([20, 20])}, {'target': bmes_target})
#
# # print(convert_allen_res_to_fastnlp_res(allen_bmes_metric.get_metric()))
# # print(fastnlp_bmes_metric.get_metric())
# self.assertDictEqual(convert_allen_res_to_fastnlp_res(allen_bmes_metric.get_metric()),
# fastnlp_bmes_metric.get_metric())

class TestBMESF1PreRecMetric(unittest.TestCase):
def test_case1(self):
seq_lens = torch.LongTensor([4, 2])
pred = torch.randn(2, 4, 4)
target = torch.LongTensor([[0, 1, 2, 3],
[3, 3, 0, 0]])
pred_dict = {'pred': pred}
target_dict = {'target': target, 'seq_lens': seq_lens}

metric = BMESF1PreRecMetric()
metric(pred_dict, target_dict)
metric.get_metric()

def test_case2(self):
# two identical sequences should yield {f1: 1, precision: 1, recall: 1}
seq_lens = torch.LongTensor([4, 2])
target = torch.LongTensor([[0, 1, 2, 3],
[3, 3, 0, 0]])
pred_dict = {'pred': target}
target_dict = {'target': target, 'seq_lens': seq_lens}

metric = BMESF1PreRecMetric()
metric(pred_dict, target_dict)
self.assertDictEqual(metric.get_metric(), {'f1': 0.999999, 'precision': 1.0, 'recall': 1.0})

class TestUsefulFunctions(unittest.TestCase):
# test some of the seemingly useful helper functions in metrics.py


+ 104
- 0
test/modules/decoder/test_CRF.py View File

@@ -0,0 +1,104 @@

import unittest


class TestCRF(unittest.TestCase):
def test_case1(self):
# check that allowed_transitions() behaves correctly
from fastNLP.modules.decoder.CRF import allowed_transitions

id2label = {0: 'B', 1: 'I', 2:'O'}
expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2),
(2, 4), (3, 0), (3, 2)}
self.assertSetEqual(expected_res, set(allowed_transitions(id2label)))

id2label = {0: 'B', 1:'M', 2:'E', 3:'S'}
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)}
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES')))

id2label = {0: 'B', 1: 'I', 2:'O', 3: '<pad>', 4:"<unk>"}
allowed_transitions(id2label)

labels = ['O']
for label in ['X', 'Y']:
for tag in 'BI':
labels.append('{}-{}'.format(tag, label))
id2label = {idx:label for idx, label in enumerate(labels)}
expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1),
(2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3),
(4, 4), (4, 6), (5, 0), (5, 1), (5, 3)}
self.assertSetEqual(expected_res, set(allowed_transitions(id2label)))

labels = []
for label in ['X', 'Y']:
for tag in 'BMES':
labels.append('{}-{}'.format(tag, label))
id2label = {idx: label for idx, label in enumerate(labels)}
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4),
(3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0),
(7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)}
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES')))
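To make the BMES expected_res above easier to read, a small hedged reconstruction of the rule it encodes; ids 0-3 are B/M/E/S, and the two extra ids appear to be virtual start (num_tags) and end (num_tags + 1) states:

B, M, E, S, START, END = 0, 1, 2, 3, 4, 5
allowed = set()
for frm, tos in [(B, [M, E]), (M, [M, E]),            # B and M must be followed by M or E
                 (E, [B, S, END]), (S, [B, S, END]),  # E and S may end, or start a new word
                 (START, [B, S])]:                    # a sequence can only start with B or S
    allowed.update((frm, to) for to in tos)
assert allowed == {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5),
                   (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)}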

def test_case2(self):
# test that the CRF avoids decoding illegal transitions; verified against allennlp.
pass
# import torch
# from fastNLP.modules.decoder.CRF import seq_len_to_byte_mask
#
# labels = ['O']
# for label in ['X', 'Y']:
# for tag in 'BI':
# labels.append('{}-{}'.format(tag, label))
# id2label = {idx: label for idx, label in enumerate(labels)}
# num_tags = len(id2label)
#
# from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions
# allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BIO', id2label),
# include_start_end_transitions=False)
# batch_size = 3
# logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, 20, num_tags))).log()
# trans_m = allen_CRF.transitions
# seq_lens = torch.randint(1, 20, size=(batch_size,))
# seq_lens[-1] = 20
# mask = seq_len_to_byte_mask(seq_lens)
# allen_res = allen_CRF.viterbi_tags(logits, mask)
#
# from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label))
# fast_CRF.trans_m = trans_m
# fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True)
# # score equal
# self.assertListEqual([score for _, score in allen_res], fast_res[1])
# # seq equal
# self.assertListEqual([_ for _, score in allen_res], fast_res[0])
#
#
# labels = []
# for label in ['X', 'Y']:
# for tag in 'BMES':
# labels.append('{}-{}'.format(tag, label))
# id2label = {idx: label for idx, label in enumerate(labels)}
# num_tags = len(id2label)
#
# from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions
# allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BMES', id2label),
# include_start_end_transitions=False)
# batch_size = 3
# logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, 20, num_tags))).log()
# trans_m = allen_CRF.transitions
# seq_lens = torch.randint(1, 20, size=(batch_size,))
# seq_lens[-1] = 20
# mask = seq_len_to_byte_mask(seq_lens)
# allen_res = allen_CRF.viterbi_tags(logits, mask)
#
# from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label,
# encoding_type='BMES'))
# fast_CRF.trans_m = trans_m
# fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True)
# # score equal
# self.assertListEqual([score for _, score in allen_res], fast_res[1])
# # seq equal
# self.assertListEqual([_ for _, score in allen_res], fast_res[0])


