@@ -0,0 +1,5 @@ | |||||
ignore: | |||||
- "reproduction" # ignore folders and all its contents | |||||
- "setup.py" | |||||
- "docs" | |||||
- "tutorials" |
@@ -1,7 +1,8 @@ | |||||
fastNLP上手教程 | |||||
fastNLP 10分钟上手教程 | |||||
=============== | =============== | ||||
教程原文见 https://github.com/fastnlp/fastNLP/blob/master/tutorials/fastnlp_10min_tutorial.ipynb | |||||
fastNLP提供方便的数据预处理,训练和测试模型的功能 | fastNLP提供方便的数据预处理,训练和测试模型的功能 | ||||
DataSet & Instance | DataSet & Instance | ||||
@@ -2,6 +2,8 @@ | |||||
FastNLP 1分钟上手教程 | FastNLP 1分钟上手教程 | ||||
===================== | ===================== | ||||
教程原文见 https://github.com/fastnlp/fastNLP/blob/master/tutorials/fastnlp_1min_tutorial.ipynb | |||||
step 1 | step 1 | ||||
------ | ------ | ||||
@@ -0,0 +1,5 @@ | |||||
fastNLP 进阶教程 | |||||
=============== | |||||
教程原文见 https://github.com/fastnlp/fastNLP/blob/master/tutorials/fastnlp_advanced_tutorial/advance_tutorial.ipynb | |||||
@@ -0,0 +1,5 @@ | |||||
fastNLP 开发者指南 | |||||
=============== | |||||
原文见 https://github.com/fastnlp/fastNLP/blob/master/tutorials/tutorial_for_developer.md | |||||
@@ -5,6 +5,7 @@ Installation | |||||
.. contents:: | .. contents:: | ||||
:local: | :local: | ||||
Make sure your environment satisfies https://github.com/fastnlp/fastNLP/blob/master/requirements.txt . | |||||
Run the following commands to install fastNLP package: | Run the following commands to install fastNLP package: | ||||
@@ -6,4 +6,6 @@ Quickstart | |||||
../tutorials/fastnlp_1_minute_tutorial | ../tutorials/fastnlp_1_minute_tutorial | ||||
../tutorials/fastnlp_10tmin_tutorial | ../tutorials/fastnlp_10tmin_tutorial | ||||
../tutorials/fastnlp_advanced_tutorial | |||||
../tutorials/fastnlp_developer_guide | |||||
@@ -18,26 +18,27 @@ print(cws.predict(text)) | |||||
# ['编者 按 : 7月 12日 , 英国 航空 航天 系统 公司 公布 了 该 公司 研制 的 第一 款 高 科技 隐形 无人 机雷电 之 神 。', '这 款 飞行 从 外型 上 来 看 酷似 电影 中 的 太空 飞行器 , 据 英国 方面 介绍 , 可以 实现 洲际 远程 打击 。', '那么 这 款 无人 机 到底 有 多 厉害 ?'] | # ['编者 按 : 7月 12日 , 英国 航空 航天 系统 公司 公布 了 该 公司 研制 的 第一 款 高 科技 隐形 无人 机雷电 之 神 。', '这 款 飞行 从 外型 上 来 看 酷似 电影 中 的 太空 飞行器 , 据 英国 方面 介绍 , 可以 实现 洲际 远程 打击 。', '那么 这 款 无人 机 到底 有 多 厉害 ?'] | ||||
``` | ``` | ||||
### 中文分词+词性标注 | |||||
### 词性标注 | |||||
```python | ```python | ||||
text = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||||
'那么这款无人机到底有多厉害?'] | |||||
# 输入已分词序列 | |||||
text = [['编者', '按:', '7月', '12日', ',', '英国', '航空', '航天', '系统', '公司', '公布', '了', '该', '公司', | |||||
'研制', '的', '第一款', '高科技', '隐形', '无人机', '雷电之神', '。'], | |||||
['那么', '这', '款', '无人机', '到底', '有', '多', '厉害', '?']] | |||||
from fastNLP.api import POS | from fastNLP.api import POS | ||||
pos = POS(device='cpu') | pos = POS(device='cpu') | ||||
print(pos.predict(text)) | print(pos.predict(text)) | ||||
# [['编者/NN', '按/P', ':/PU', '7月/NT', '12日/NR', ',/PU', '英国/NR', '航空/NN', '航天/NN', '系统/NN', '公司/NN', '公布/VV', '了/AS', '该/DT', '公司/NN', '研制/VV', '的/DEC', '第一/OD', '款高/NN', '科技/NN', '隐形/NN', '无/VE', '人机/NN', '雷电/NN', '之/DEG', '神/NN', '。/PU'], ['这/DT', '款/NN', '飞行/VV', '从/P', '外型/NN', '上/LC', '来/MSP', '看/VV', '酷似/VV', '电影/NN', '中/LC', '的/DEG', '太空/NN', '飞行器/NN', ',/PU', '据/P', '英国/NR', '方面/NN', '介绍/VV', ',/PU', '可以/VV', '实现/VV', '洲际/NN', '远程/NN', '打击/NN', '。/PU'], ['那么/AD', '这/DT', '款/NN', '无/VE', '人机/NN', '到底/AD', '有/VE', '多/CD', '厉害/NN', '?/PU']] | |||||
# [['编者/NN', '按:/NN', '7月/NT', '12日/NT', ',/PU', '英国/NR', '航空/NN', '航天/NN', '系统/NN', '公司/NN', '公布/VV', '了/AS', '该/DT', '公司/NN', '研制/VV', '的/DEC', '第一款/NN', '高科技/NN', '隐形/AD', '无人机/VV', '雷电之神/NN', '。/PU'], ['那么/AD', '这/DT', '款/NN', '无人机/VV', '到底/AD', '有/VE', '多/AD', '厉害/VA', '?/PU']] | |||||
``` | ``` | ||||
### 中文分词+词性标注+句法分析 | |||||
### 句法分析 | |||||
```python | ```python | ||||
text = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||||
'那么这款无人机到底有多厉害?'] | |||||
text = [['编者', '按:', '7月', '12日', ',', '英国', '航空', '航天', '系统', '公司', '公布', '了', '该', '公司', | |||||
'研制', '的', '第一款', '高科技', '隐形', '无人机', '雷电之神', '。'], | |||||
['那么', '这', '款', '无人机', '到底', '有', '多', '厉害', '?']] | |||||
from fastNLP.api import Parser | from fastNLP.api import Parser | ||||
parser = Parser(device='cpu') | parser = Parser(device='cpu') | ||||
print(parser.predict(text)) | print(parser.predict(text)) | ||||
# [['12/nsubj', '12/prep', '2/punct', '5/nn', '2/pobj', '12/punct', '11/nn', '11/nn', '11/nn', '11/nn', '2/pobj', '0/root', '12/asp', '15/det', '16/nsubj', '21/rcmod', '16/cpm', '21/nummod', '21/nn', '21/nn', '22/top', '12/ccomp', '24/nn', '26/assmod', '24/assm', '22/dobj', '12/punct'], ['2/det', '8/xsubj', '8/mmod', '8/prep', '6/lobj', '4/plmod', '8/prtmod', '0/root', '8/ccomp', '11/lobj', '14/assmod', '11/assm', '14/nn', '9/dobj', '8/punct', '22/prep', '18/nn', '19/nsubj', '16/pccomp', '22/punct', '22/mmod', '8/dep', '25/nn', '25/nn', '22/dobj', '8/punct'], ['4/advmod', '3/det', '4/nsubj', '0/root', '4/dobj', '7/advmod', '4/conj', '9/nummod', '7/dobj', '4/punct']] | |||||
# [['2/nn', '4/nn', '4/nn', '20/tmod', '11/punct', '10/nn', '10/nn', '10/nn', '10/nn', '11/nsubj', '20/dep', '11/asp', '14/det', '15/nsubj', '18/rcmod', '15/cpm', '18/nn', '11/dobj', '20/advmod', '0/root', '20/dobj', '20/punct'], ['4/advmod', '3/det', '8/xsubj', '8/dep', '8/advmod', '8/dep', '8/advmod', '0/root', '8/punct']] | |||||
``` | ``` | ||||
完整样例见`examples.py` | 完整样例见`examples.py` |
@@ -9,9 +9,7 @@ from fastNLP.core.dataset import DataSet | |||||
from fastNLP.api.utils import load_url | from fastNLP.api.utils import load_url | ||||
from fastNLP.api.processor import ModelProcessor | from fastNLP.api.processor import ModelProcessor | ||||
from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader | |||||
from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader | |||||
from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag | |||||
from fastNLP.io.dataset_loader import ConllCWSReader, ConllxDataLoader | |||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.api.pipeline import Pipeline | from fastNLP.api.pipeline import Pipeline | ||||
from fastNLP.core.metrics import SpanFPreRecMetric | from fastNLP.core.metrics import SpanFPreRecMetric | ||||
@@ -19,9 +17,9 @@ from fastNLP.api.processor import IndexerProcessor | |||||
# TODO add pretrain urls | # TODO add pretrain urls | ||||
model_urls = { | model_urls = { | ||||
"cws": "http://123.206.98.91:8888/download/cws_crf_1_11-457fc899.pkl", | |||||
"pos": "http://123.206.98.91:8888/download/pos_tag_model_20190108-f3c60ee5.pkl", | |||||
"parser": "http://123.206.98.91:8888/download/biaffine_parser-3a2f052c.pkl" | |||||
"cws": "http://123.206.98.91:8888/download/cws_lstm_ctb9_1_20-09908656.pkl", | |||||
"pos": "http://123.206.98.91:8888/download/pos_tag_model_20190119-43f8b435.pkl", | |||||
"parser": "http://123.206.98.91:8888/download/parser_20190204-c72ca5c0.pkl" | |||||
} | } | ||||
@@ -31,6 +29,16 @@ class API: | |||||
self._dict = None | self._dict = None | ||||
def predict(self, *args, **kwargs): | def predict(self, *args, **kwargs): | ||||
"""Do prediction for the given input. | |||||
""" | |||||
raise NotImplementedError | |||||
def test(self, file_path): | |||||
"""Test performance over the given data set. | |||||
:param str file_path: | |||||
:return: a dictionary of metric values | |||||
""" | |||||
raise NotImplementedError | raise NotImplementedError | ||||
def load(self, path, device): | def load(self, path, device): | ||||
@@ -69,12 +77,11 @@ class POS(API): | |||||
if not hasattr(self, "pipeline"): | if not hasattr(self, "pipeline"): | ||||
raise ValueError("You have to load model first.") | raise ValueError("You have to load model first.") | ||||
sentence_list = [] | |||||
sentence_list = content | |||||
# 1. 检查sentence的类型 | # 1. 检查sentence的类型 | ||||
if isinstance(content, str): | |||||
sentence_list.append(content) | |||||
elif isinstance(content, list): | |||||
sentence_list = content | |||||
for sentence in sentence_list: | |||||
if not all((type(obj) == str for obj in sentence)): | |||||
raise ValueError("Input must be list of list of string.") | |||||
# 2. 组建dataset | # 2. 组建dataset | ||||
dataset = DataSet() | dataset = DataSet() | ||||
@@ -83,36 +90,28 @@ class POS(API): | |||||
# 3. 使用pipeline | # 3. 使用pipeline | ||||
self.pipeline(dataset) | self.pipeline(dataset) | ||||
def decode_tags(ins): | |||||
pred_tags = ins["tag"] | |||||
chars = ins["words"] | |||||
words = [] | |||||
start_idx = 0 | |||||
for idx, tag in enumerate(pred_tags): | |||||
if tag[0] == "S": | |||||
words.append(chars[start_idx:idx + 1] + "/" + tag[2:]) | |||||
start_idx = idx + 1 | |||||
elif tag[0] == "E": | |||||
words.append("".join(chars[start_idx:idx + 1]) + "/" + tag[2:]) | |||||
start_idx = idx + 1 | |||||
return words | |||||
dataset.apply(decode_tags, new_field_name="tag_output") | |||||
output = dataset.field_arrays["tag_output"].content | |||||
def merge_tag(words_list, tags_list): | |||||
rtn = [] | |||||
for words, tags in zip(words_list, tags_list): | |||||
rtn.append([w + "/" + t for w, t in zip(words, tags)]) | |||||
return rtn | |||||
output = dataset.field_arrays["tag"].content | |||||
if isinstance(content, str): | if isinstance(content, str): | ||||
return output[0] | return output[0] | ||||
elif isinstance(content, list): | elif isinstance(content, list): | ||||
return output | |||||
return merge_tag(content, output) | |||||
def test(self, file_path): | def test(self, file_path): | ||||
test_data = ZhConllPOSReader().load(file_path) | |||||
test_data = ConllxDataLoader().load(file_path) | |||||
tag_vocab = self._dict["tag_vocab"] | |||||
pipeline = self._dict["pipeline"] | |||||
save_dict = self._dict | |||||
tag_vocab = save_dict["tag_vocab"] | |||||
pipeline = save_dict["pipeline"] | |||||
index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False) | index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False) | ||||
pipeline.pipeline = [index_tag] + pipeline.pipeline | pipeline.pipeline = [index_tag] + pipeline.pipeline | ||||
test_data.rename_field("pos_tags", "tag") | |||||
pipeline(test_data) | pipeline(test_data) | ||||
test_data.set_target("truth") | test_data.set_target("truth") | ||||
prediction = test_data.field_arrays["predict"].content | prediction = test_data.field_arrays["predict"].content | ||||
@@ -226,7 +225,7 @@ class CWS(API): | |||||
rec = eval_res['BMESF1PreRecMetric']['rec'] | rec = eval_res['BMESF1PreRecMetric']['rec'] | ||||
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) | # print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) | ||||
return f1, pre, rec | |||||
return {"F1": f1, "precision": pre, "recall": rec} | |||||
class Parser(API): | class Parser(API): | ||||
@@ -251,6 +250,7 @@ class Parser(API): | |||||
dataset.add_field('wp', pos_out) | dataset.add_field('wp', pos_out) | ||||
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[0] for w in x['wp']], new_field_name='words') | dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[0] for w in x['wp']], new_field_name='words') | ||||
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[1] for w in x['wp']], new_field_name='pos') | dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[1] for w in x['wp']], new_field_name='pos') | ||||
dataset.rename_field("words", "raw_words") | |||||
# 3. 使用pipeline | # 3. 使用pipeline | ||||
self.pipeline(dataset) | self.pipeline(dataset) | ||||
@@ -260,31 +260,74 @@ class Parser(API): | |||||
# output like: [['2/top', '0/root', '4/nn', '2/dep']] | # output like: [['2/top', '0/root', '4/nn', '2/dep']] | ||||
return dataset.field_arrays['output'].content | return dataset.field_arrays['output'].content | ||||
def test(self, filepath): | |||||
data = ConllxDataLoader().load(filepath) | |||||
ds = DataSet() | |||||
for ins1, ins2 in zip(add_seg_tag(data), data): | |||||
ds.append(Instance(words=ins1[0], tag=ins1[1], | |||||
gold_words=ins2[0], gold_pos=ins2[1], | |||||
gold_heads=ins2[2], gold_head_tags=ins2[3])) | |||||
def load_test_file(self, path): | |||||
def get_one(sample): | |||||
sample = list(map(list, zip(*sample))) | |||||
if len(sample) == 0: | |||||
return None | |||||
for w in sample[7]: | |||||
if w == '_': | |||||
print('Error Sample {}'.format(sample)) | |||||
return None | |||||
# return word_seq, pos_seq, head_seq, head_tag_seq | |||||
return sample[1], sample[3], list(map(int, sample[6])), sample[7] | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
data = [get_one(sample) for sample in datalist] | |||||
data_list = list(filter(lambda x: x is not None, data)) | |||||
return data_list | |||||
def test(self, filepath): | |||||
data = self.load_test_file(filepath) | |||||
def convert(data): | |||||
BOS = '<BOS>' | |||||
dataset = DataSet() | |||||
for sample in data: | |||||
word_seq = [BOS] + sample[0] | |||||
pos_seq = [BOS] + sample[1] | |||||
heads = [0] + sample[2] | |||||
head_tags = [BOS] + sample[3] | |||||
dataset.append(Instance(raw_words=word_seq, | |||||
pos=pos_seq, | |||||
gold_heads=heads, | |||||
arc_true=heads, | |||||
tags=head_tags)) | |||||
return dataset | |||||
ds = convert(data) | |||||
pp = self.pipeline | pp = self.pipeline | ||||
for p in pp: | for p in pp: | ||||
if p.field_name == 'word_list': | if p.field_name == 'word_list': | ||||
p.field_name = 'gold_words' | p.field_name = 'gold_words' | ||||
elif p.field_name == 'pos_list': | elif p.field_name == 'pos_list': | ||||
p.field_name = 'gold_pos' | p.field_name = 'gold_pos' | ||||
# ds.rename_field("words", "raw_words") | |||||
# ds.rename_field("tag", "pos") | |||||
pp(ds) | pp(ds) | ||||
head_cor, label_cor, total = 0, 0, 0 | head_cor, label_cor, total = 0, 0, 0 | ||||
for ins in ds: | for ins in ds: | ||||
head_gold = ins['gold_heads'] | head_gold = ins['gold_heads'] | ||||
head_pred = ins['heads'] | |||||
head_pred = ins['arc_pred'] | |||||
length = len(head_gold) | length = len(head_gold) | ||||
total += length | total += length | ||||
for i in range(length): | for i in range(length): | ||||
head_cor += 1 if head_pred[i] == head_gold[i] else 0 | head_cor += 1 if head_pred[i] == head_gold[i] else 0 | ||||
uas = head_cor / total | uas = head_cor / total | ||||
print('uas:{:.2f}'.format(uas)) | |||||
# print('uas:{:.2f}'.format(uas)) | |||||
for p in pp: | for p in pp: | ||||
if p.field_name == 'gold_words': | if p.field_name == 'gold_words': | ||||
@@ -292,7 +335,7 @@ class Parser(API): | |||||
elif p.field_name == 'gold_pos': | elif p.field_name == 'gold_pos': | ||||
p.field_name = 'pos_list' | p.field_name = 'pos_list' | ||||
return uas | |||||
return {"USA": round(uas, 5)} | |||||
class Analyzer: | class Analyzer: | ||||
@@ -15,15 +15,42 @@ def chinese_word_segmentation(): | |||||
print(cws.predict(text)) | print(cws.predict(text)) | ||||
def chinese_word_segmentation_test(): | |||||
cws = CWS(device='cpu') | |||||
print(cws.test("../../test/data_for_tests/zh_sample.conllx")) | |||||
def pos_tagging(): | def pos_tagging(): | ||||
# 输入已分词序列 | |||||
text = [['编者', '按:', '7月', '12日', ',', '英国', '航空', '航天', '系统', '公司', '公布', '了', '该', '公司', | |||||
'研制', '的', '第一款', '高科技', '隐形', '无人机', '雷电之神', '。'], | |||||
['那么', '这', '款', '无人机', '到底', '有', '多', '厉害', '?']] | |||||
pos = POS(device='cpu') | pos = POS(device='cpu') | ||||
print(pos.predict(text)) | print(pos.predict(text)) | ||||
def pos_tagging_test(): | |||||
pos = POS(device='cpu') | |||||
print(pos.test("../../test/data_for_tests/zh_sample.conllx")) | |||||
def syntactic_parsing(): | def syntactic_parsing(): | ||||
text = [['编者', '按:', '7月', '12日', ',', '英国', '航空', '航天', '系统', '公司', '公布', '了', '该', '公司', | |||||
'研制', '的', '第一款', '高科技', '隐形', '无人机', '雷电之神', '。'], | |||||
['那么', '这', '款', '无人机', '到底', '有', '多', '厉害', '?']] | |||||
parser = Parser(device='cpu') | parser = Parser(device='cpu') | ||||
print(parser.predict(text)) | print(parser.predict(text)) | ||||
def syntactic_parsing_test(): | |||||
parser = Parser(device='cpu') | |||||
print(parser.test("../../test/data_for_tests/zh_sample.conllx")) | |||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
# chinese_word_segmentation() | |||||
# chinese_word_segmentation_test() | |||||
# pos_tagging() | |||||
# pos_tagging_test() | |||||
syntactic_parsing() | syntactic_parsing() | ||||
# syntactic_parsing_test() |
@@ -102,6 +102,7 @@ class PreAppendProcessor(Processor): | |||||
[data] + instance[field_name] | [data] + instance[field_name] | ||||
""" | """ | ||||
def __init__(self, data, field_name, new_added_field_name=None): | def __init__(self, data, field_name, new_added_field_name=None): | ||||
super(PreAppendProcessor, self).__init__(field_name, new_added_field_name) | super(PreAppendProcessor, self).__init__(field_name, new_added_field_name) | ||||
self.data = data | self.data = data | ||||
@@ -116,6 +117,7 @@ class SliceProcessor(Processor): | |||||
从某个field中只取部分内容。等价于instance[field_name][start:end:step] | 从某个field中只取部分内容。等价于instance[field_name][start:end:step] | ||||
""" | """ | ||||
def __init__(self, start, end, step, field_name, new_added_field_name=None): | def __init__(self, start, end, step, field_name, new_added_field_name=None): | ||||
super(SliceProcessor, self).__init__(field_name, new_added_field_name) | super(SliceProcessor, self).__init__(field_name, new_added_field_name) | ||||
for o in (start, end, step): | for o in (start, end, step): | ||||
@@ -132,6 +134,7 @@ class Num2TagProcessor(Processor): | |||||
将一句话中的数字转换为某个tag。 | 将一句话中的数字转换为某个tag。 | ||||
""" | """ | ||||
def __init__(self, tag, field_name, new_added_field_name=None): | def __init__(self, tag, field_name, new_added_field_name=None): | ||||
""" | """ | ||||
@@ -163,6 +166,7 @@ class IndexerProcessor(Processor): | |||||
给定一个vocabulary , 将指定field转换为index形式。指定field应该是一维的list,比如 | 给定一个vocabulary , 将指定field转换为index形式。指定field应该是一维的list,比如 | ||||
['我', '是', xxx] | ['我', '是', xxx] | ||||
""" | """ | ||||
def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False, is_input=True): | def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False, is_input=True): | ||||
assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) | assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) | ||||
@@ -215,6 +219,7 @@ class SeqLenProcessor(Processor): | |||||
根据某个field新增一个sequence length的field。取该field的第一维 | 根据某个field新增一个sequence length的field。取该field的第一维 | ||||
""" | """ | ||||
def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True): | def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True): | ||||
super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) | super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) | ||||
self.is_input = is_input | self.is_input = is_input | ||||
@@ -229,6 +234,7 @@ class SeqLenProcessor(Processor): | |||||
from fastNLP.core.utils import _build_args | from fastNLP.core.utils import _build_args | ||||
class ModelProcessor(Processor): | class ModelProcessor(Processor): | ||||
def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32): | def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32): | ||||
""" | """ | ||||
@@ -292,6 +298,7 @@ class Index2WordProcessor(Processor): | |||||
将DataSet中某个为index的field根据vocab转换为str | 将DataSet中某个为index的field根据vocab转换为str | ||||
""" | """ | ||||
def __init__(self, vocab, field_name, new_added_field_name): | def __init__(self, vocab, field_name, new_added_field_name): | ||||
super(Index2WordProcessor, self).__init__(field_name, new_added_field_name) | super(Index2WordProcessor, self).__init__(field_name, new_added_field_name) | ||||
self.vocab = vocab | self.vocab = vocab | ||||
@@ -303,7 +310,6 @@ class Index2WordProcessor(Processor): | |||||
class SetTargetProcessor(Processor): | class SetTargetProcessor(Processor): | ||||
# TODO; remove it. | |||||
def __init__(self, *fields, flag=True): | def __init__(self, *fields, flag=True): | ||||
super(SetTargetProcessor, self).__init__(None, None) | super(SetTargetProcessor, self).__init__(None, None) | ||||
self.fields = fields | self.fields = fields | ||||
@@ -313,6 +319,7 @@ class SetTargetProcessor(Processor): | |||||
dataset.set_target(*self.fields, flag=self.flag) | dataset.set_target(*self.fields, flag=self.flag) | ||||
return dataset | return dataset | ||||
class SetInputProcessor(Processor): | class SetInputProcessor(Processor): | ||||
def __init__(self, *fields, flag=True): | def __init__(self, *fields, flag=True): | ||||
super(SetInputProcessor, self).__init__(None, None) | super(SetInputProcessor, self).__init__(None, None) | ||||
@@ -322,3 +329,103 @@ class SetInputProcessor(Processor): | |||||
def process(self, dataset): | def process(self, dataset): | ||||
dataset.set_input(*self.fields, flag=self.flag) | dataset.set_input(*self.fields, flag=self.flag) | ||||
return dataset | return dataset | ||||
class VocabIndexerProcessor(Processor): | |||||
""" | |||||
根据DataSet创建Vocabulary,并将其用数字index。新生成的index的field会被放在new_added_filed_name, 如果没有提供 | |||||
new_added_field_name, 则覆盖原有的field_name. | |||||
""" | |||||
def __init__(self, field_name, new_added_filed_name=None, min_freq=1, max_size=None, | |||||
verbose=0, is_input=True): | |||||
""" | |||||
:param field_name: 从哪个field_name创建词表,以及对哪个field_name进行index操作 | |||||
:param new_added_filed_name: index时,生成的index field的名称,如果不传入,则覆盖field_name. | |||||
:param min_freq: 创建的Vocabulary允许的单词最少出现次数. | |||||
:param max_size: 创建的Vocabulary允许的最大的单词数量 | |||||
:param verbose: 0, 不输出任何信息;1,输出信息 | |||||
:param bool is_input: | |||||
""" | |||||
super(VocabIndexerProcessor, self).__init__(field_name, new_added_filed_name) | |||||
self.min_freq = min_freq | |||||
self.max_size = max_size | |||||
self.verbose = verbose | |||||
self.is_input = is_input | |||||
def construct_vocab(self, *datasets): | |||||
""" | |||||
使用传入的DataSet创建vocabulary | |||||
:param datasets: DataSet类型的数据,用于构建vocabulary | |||||
:return: | |||||
""" | |||||
self.vocab = Vocabulary(min_freq=self.min_freq, max_size=self.max_size) | |||||
for dataset in datasets: | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
dataset.apply(lambda ins: self.vocab.update(ins[self.field_name])) | |||||
self.vocab.build_vocab() | |||||
if self.verbose: | |||||
print("Vocabulary Constructed, has {} items.".format(len(self.vocab))) | |||||
def process(self, *datasets, only_index_dataset=None): | |||||
""" | |||||
若还未建立Vocabulary,则使用dataset中的DataSet建立vocabulary;若已经有了vocabulary则使用已有的vocabulary。得到vocabulary | |||||
后,则会index datasets与only_index_dataset。 | |||||
:param datasets: DataSet类型的数据 | |||||
:param only_index_dataset: DataSet, or list of DataSet. 该参数中的内容只会被用于index,不会被用于生成vocabulary。 | |||||
:return: | |||||
""" | |||||
if len(datasets) == 0 and not hasattr(self, 'vocab'): | |||||
raise RuntimeError("You have to construct vocabulary first. Or you have to pass datasets to construct it.") | |||||
if not hasattr(self, 'vocab'): | |||||
self.construct_vocab(*datasets) | |||||
else: | |||||
if self.verbose: | |||||
print("Using constructed vocabulary with {} items.".format(len(self.vocab))) | |||||
to_index_datasets = [] | |||||
if len(datasets) != 0: | |||||
for dataset in datasets: | |||||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||||
to_index_datasets.append(dataset) | |||||
if not (only_index_dataset is None): | |||||
if isinstance(only_index_dataset, list): | |||||
for dataset in only_index_dataset: | |||||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||||
to_index_datasets.append(dataset) | |||||
elif isinstance(only_index_dataset, DataSet): | |||||
to_index_datasets.append(only_index_dataset) | |||||
else: | |||||
raise TypeError('Only DataSet or list of DataSet is allowed, not {}.'.format(type(only_index_dataset))) | |||||
for dataset in to_index_datasets: | |||||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||||
dataset.apply(lambda ins: [self.vocab.to_index(token) for token in ins[self.field_name]], | |||||
new_field_name=self.new_added_field_name, is_input=self.is_input) | |||||
# 只返回一个,infer时为了跟其他processor保持一致 | |||||
if len(to_index_datasets) == 1: | |||||
return to_index_datasets[0] | |||||
def set_vocab(self, vocab): | |||||
assert isinstance(vocab, Vocabulary), "Only fastNLP.core.Vocabulary is allowed, not {}.".format(type(vocab)) | |||||
self.vocab = vocab | |||||
def delete_vocab(self): | |||||
del self.vocab | |||||
def get_vocab_size(self): | |||||
return len(self.vocab) | |||||
def set_verbose(self, verbose): | |||||
""" | |||||
设置processor verbose状态。 | |||||
:param verbose: int, 0,不输出任何信息;1,输出vocab 信息。 | |||||
:return: | |||||
""" | |||||
self.verbose = verbose |
@@ -2,7 +2,7 @@ import numpy as np | |||||
import torch | import torch | ||||
from fastNLP.core.sampler import RandomSampler | from fastNLP.core.sampler import RandomSampler | ||||
import torch.multiprocessing as mp | |||||
class Batch(object): | class Batch(object): | ||||
"""Batch is an iterable object which iterates over mini-batches. | """Batch is an iterable object which iterates over mini-batches. | ||||
@@ -16,10 +16,11 @@ class Batch(object): | |||||
:param int batch_size: the size of the batch | :param int batch_size: the size of the batch | ||||
:param Sampler sampler: a Sampler object | :param Sampler sampler: a Sampler object | ||||
:param bool as_numpy: If True, return Numpy array. Otherwise, return torch tensors. | :param bool as_numpy: If True, return Numpy array. Otherwise, return torch tensors. | ||||
:param bool prefetch: If True, use multiprocessing to fetch next batch when training. | |||||
:param str or torch.device device: the batch's device, if as_numpy is True, device is ignored. | |||||
""" | """ | ||||
def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False): | |||||
def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False, prefetch=False): | |||||
self.dataset = dataset | self.dataset = dataset | ||||
self.batch_size = batch_size | self.batch_size = batch_size | ||||
self.sampler = sampler | self.sampler = sampler | ||||
@@ -28,16 +29,12 @@ class Batch(object): | |||||
self.curidx = 0 | self.curidx = 0 | ||||
self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0) | self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0) | ||||
self.cur_batch_indices = None | self.cur_batch_indices = None | ||||
self.prefetch = prefetch | |||||
self.lengths = 0 | |||||
def __iter__(self): | |||||
self.idx_list = self.sampler(self.dataset) | |||||
self.curidx = 0 | |||||
self.lengths = self.dataset.get_length() | |||||
return self | |||||
def __next__(self): | |||||
def fetch_one(self): | |||||
if self.curidx >= len(self.idx_list): | if self.curidx >= len(self.idx_list): | ||||
raise StopIteration | |||||
return None | |||||
else: | else: | ||||
endidx = min(self.curidx + self.batch_size, len(self.idx_list)) | endidx = min(self.curidx + self.batch_size, len(self.idx_list)) | ||||
batch_x, batch_y = {}, {} | batch_x, batch_y = {}, {} | ||||
@@ -48,7 +45,7 @@ class Batch(object): | |||||
for field_name, field in self.dataset.get_all_fields().items(): | for field_name, field in self.dataset.get_all_fields().items(): | ||||
if field.is_target or field.is_input: | if field.is_target or field.is_input: | ||||
batch = field.get(indices) | batch = field.get(indices) | ||||
if not self.as_numpy: | |||||
if not self.as_numpy and field.padder is not None: | |||||
batch = to_tensor(batch, field.dtype) | batch = to_tensor(batch, field.dtype) | ||||
if field.is_target: | if field.is_target: | ||||
batch_y[field_name] = batch | batch_y[field_name] = batch | ||||
@@ -56,9 +53,29 @@ class Batch(object): | |||||
batch_x[field_name] = batch | batch_x[field_name] = batch | ||||
self.curidx = endidx | self.curidx = endidx | ||||
return batch_x, batch_y | return batch_x, batch_y | ||||
def __iter__(self): | |||||
""" | |||||
Iterate on dataset, fetch batch data. Fetch process don't block the iterate process | |||||
:return: | |||||
""" | |||||
if self.prefetch: | |||||
return run_batch_iter(self) | |||||
def batch_iter(): | |||||
self.init_iter() | |||||
while 1: | |||||
res = self.fetch_one() | |||||
if res is None: | |||||
break | |||||
yield res | |||||
return batch_iter() | |||||
def init_iter(self): | |||||
self.idx_list = self.sampler(self.dataset) | |||||
self.curidx = 0 | |||||
self.lengths = self.dataset.get_length() | |||||
def __len__(self): | def __len__(self): | ||||
return self.num_batches | return self.num_batches | ||||
@@ -67,8 +84,50 @@ class Batch(object): | |||||
def to_tensor(batch, dtype): | def to_tensor(batch, dtype): | ||||
if dtype in (int, np.int8, np.int16, np.int32, np.int64): | |||||
batch = torch.LongTensor(batch) | |||||
if dtype in (float, np.float32, np.float64): | |||||
batch = torch.FloatTensor(batch) | |||||
try: | |||||
if dtype in (int, np.int8, np.int16, np.int32, np.int64): | |||||
batch = torch.LongTensor(batch) | |||||
if dtype in (float, np.float32, np.float64): | |||||
batch = torch.FloatTensor(batch) | |||||
except: | |||||
pass | |||||
return batch | return batch | ||||
def run_fetch(batch, q): | |||||
batch.init_iter() | |||||
# print('start fetch') | |||||
while 1: | |||||
res = batch.fetch_one() | |||||
# print('fetch one') | |||||
q.put(res) | |||||
if res is None: | |||||
# print('fetch done, waiting processing') | |||||
q.join() | |||||
break | |||||
# print('fetch exit') | |||||
def run_batch_iter(batch): | |||||
q = mp.JoinableQueue(maxsize=10) | |||||
fetch_p = mp.Process(target=run_fetch, args=(batch, q)) | |||||
fetch_p.daemon = True | |||||
fetch_p.start() | |||||
# print('fork fetch process') | |||||
while 1: | |||||
try: | |||||
res = q.get(timeout=1) | |||||
q.task_done() | |||||
# print('get fetched') | |||||
if res is None: | |||||
break | |||||
yield res | |||||
except Exception as e: | |||||
if fetch_p.is_alive(): | |||||
continue | |||||
else: | |||||
break | |||||
fetch_p.terminate() | |||||
fetch_p.join() | |||||
# print('iter done') | |||||
@@ -1,3 +1,11 @@ | |||||
import os | |||||
import torch | |||||
from tensorboardX import SummaryWriter | |||||
from fastNLP.io.model_io import ModelSaver, ModelLoader | |||||
class Callback(object): | class Callback(object): | ||||
"""An Interface for all callbacks. | """An Interface for all callbacks. | ||||
@@ -7,38 +15,42 @@ class Callback(object): | |||||
def __init__(self): | def __init__(self): | ||||
super(Callback, self).__init__() | super(Callback, self).__init__() | ||||
self.trainer = None # 在Trainer内部被重新赋值 | |||||
def before_train(self): | |||||
def on_train_begin(self): | |||||
# before the main training loop | # before the main training loop | ||||
pass | pass | ||||
def before_epoch(self, cur_epoch, total_epoch): | |||||
def on_epoch_begin(self, cur_epoch, total_epoch): | |||||
# at the beginning of each epoch | # at the beginning of each epoch | ||||
pass | pass | ||||
def before_batch(self, batch_x, batch_y, indices): | |||||
def on_batch_begin(self, batch_x, batch_y, indices): | |||||
# at the beginning of each step/mini-batch | # at the beginning of each step/mini-batch | ||||
pass | pass | ||||
def before_loss(self, batch_y, predict_y): | |||||
def on_loss_begin(self, batch_y, predict_y): | |||||
# after data_forward, and before loss computation | # after data_forward, and before loss computation | ||||
pass | pass | ||||
def before_backward(self, loss, model): | |||||
def on_backward_begin(self, loss, model): | |||||
# after loss computation, and before gradient backward | # after loss computation, and before gradient backward | ||||
pass | pass | ||||
def after_backward(self, model): | |||||
def on_backward_end(self, model): | |||||
pass | pass | ||||
def after_step(self, optimizer): | |||||
def on_step_end(self, optimizer): | |||||
pass | pass | ||||
def after_batch(self, *args): | |||||
def on_batch_end(self, *args): | |||||
# at the end of each step/mini-batch | # at the end of each step/mini-batch | ||||
pass | pass | ||||
def after_valid(self, eval_result, metric_key, optimizer): | |||||
def on_valid_begin(self): | |||||
pass | |||||
def on_valid_end(self, eval_result, metric_key, optimizer): | |||||
""" | """ | ||||
每次执行验证机的evaluation后会调用。传入eval_result | 每次执行验证机的evaluation后会调用。传入eval_result | ||||
@@ -49,7 +61,7 @@ class Callback(object): | |||||
""" | """ | ||||
pass | pass | ||||
def after_epoch(self, cur_epoch, n_epoch, optimizer): | |||||
def on_epoch_end(self, cur_epoch, n_epoch, optimizer): | |||||
""" | """ | ||||
每个epoch结束将会调用该方法 | 每个epoch结束将会调用该方法 | ||||
@@ -60,7 +72,7 @@ class Callback(object): | |||||
""" | """ | ||||
pass | pass | ||||
def after_train(self, model): | |||||
def on_train_end(self, model): | |||||
""" | """ | ||||
训练结束,调用该方法 | 训练结束,调用该方法 | ||||
@@ -69,16 +81,16 @@ class Callback(object): | |||||
""" | """ | ||||
pass | pass | ||||
def on_exception(self, exception, model, indices): | |||||
def on_exception(self, exception, model): | |||||
""" | """ | ||||
当训练过程出现异常,会触发该方法 | 当训练过程出现异常,会触发该方法 | ||||
:param exception: 某种类型的Exception,比如KeyboardInterrupt等 | :param exception: 某种类型的Exception,比如KeyboardInterrupt等 | ||||
:param model: 传入Trainer的模型 | :param model: 传入Trainer的模型 | ||||
:param indices: 当前batch的index | |||||
:return: | :return: | ||||
""" | """ | ||||
pass | pass | ||||
def transfer(func): | def transfer(func): | ||||
"""装饰器,将对CallbackManager的调用转发到各个Callback子类. | """装饰器,将对CallbackManager的调用转发到各个Callback子类. | ||||
@@ -125,91 +137,95 @@ class CallbackManager(Callback): | |||||
raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.") | raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.") | ||||
@transfer | @transfer | ||||
def before_train(self): | |||||
def on_train_begin(self): | |||||
pass | |||||
@transfer | |||||
def on_epoch_begin(self, cur_epoch, total_epoch): | |||||
pass | pass | ||||
@transfer | @transfer | ||||
def before_epoch(self, cur_epoch, total_epoch): | |||||
def on_batch_begin(self, batch_x, batch_y, indices): | |||||
pass | pass | ||||
@transfer | @transfer | ||||
def before_batch(self, batch_x, batch_y, indices): | |||||
def on_loss_begin(self, batch_y, predict_y): | |||||
pass | pass | ||||
@transfer | @transfer | ||||
def before_loss(self, batch_y, predict_y): | |||||
def on_backward_begin(self, loss, model): | |||||
pass | pass | ||||
@transfer | @transfer | ||||
def before_backward(self, loss, model): | |||||
def on_backward_end(self, model): | |||||
pass | pass | ||||
@transfer | @transfer | ||||
def after_backward(self, model): | |||||
def on_step_end(self, optimizer): | |||||
pass | pass | ||||
@transfer | @transfer | ||||
def after_step(self, optimizer): | |||||
def on_batch_end(self): | |||||
pass | pass | ||||
@transfer | @transfer | ||||
def after_batch(self): | |||||
def on_valid_begin(self): | |||||
pass | pass | ||||
@transfer | @transfer | ||||
def after_valid(self, eval_result, metric_key, optimizer): | |||||
def on_valid_end(self, eval_result, metric_key, optimizer): | |||||
pass | pass | ||||
@transfer | @transfer | ||||
def after_epoch(self, cur_epoch, n_epoch, optimizer): | |||||
def on_epoch_end(self, cur_epoch, n_epoch, optimizer): | |||||
pass | pass | ||||
@transfer | @transfer | ||||
def after_train(self, model): | |||||
def on_train_end(self, model): | |||||
pass | pass | ||||
@transfer | @transfer | ||||
def on_exception(self, exception, model, indices): | |||||
def on_exception(self, exception, model): | |||||
pass | pass | ||||
class DummyCallback(Callback): | class DummyCallback(Callback): | ||||
def before_train(self, *arg): | |||||
def on_train_begin(self, *arg): | |||||
print(arg) | print(arg) | ||||
def after_epoch(self, cur_epoch, n_epoch, optimizer): | |||||
def on_epoch_end(self, cur_epoch, n_epoch, optimizer): | |||||
print(cur_epoch, n_epoch, optimizer) | print(cur_epoch, n_epoch, optimizer) | ||||
class EchoCallback(Callback): | class EchoCallback(Callback): | ||||
def before_train(self): | |||||
def on_train_begin(self): | |||||
print("before_train") | print("before_train") | ||||
def before_epoch(self, cur_epoch, total_epoch): | |||||
def on_epoch_begin(self, cur_epoch, total_epoch): | |||||
print("before_epoch") | print("before_epoch") | ||||
def before_batch(self, batch_x, batch_y, indices): | |||||
def on_batch_begin(self, batch_x, batch_y, indices): | |||||
print("before_batch") | print("before_batch") | ||||
def before_loss(self, batch_y, predict_y): | |||||
def on_loss_begin(self, batch_y, predict_y): | |||||
print("before_loss") | print("before_loss") | ||||
def before_backward(self, loss, model): | |||||
def on_backward_begin(self, loss, model): | |||||
print("before_backward") | print("before_backward") | ||||
def after_batch(self): | |||||
def on_batch_end(self): | |||||
print("after_batch") | print("after_batch") | ||||
def after_epoch(self, cur_epoch, n_epoch, optimizer): | |||||
def on_epoch_end(self, cur_epoch, n_epoch, optimizer): | |||||
print("after_epoch") | print("after_epoch") | ||||
def after_train(self, model): | |||||
def on_train_end(self, model): | |||||
print("after_train") | print("after_train") | ||||
class GradientClipCallback(Callback): | class GradientClipCallback(Callback): | ||||
def __init__(self, parameters=None, clip_value=1, clip_type='norm'): | def __init__(self, parameters=None, clip_value=1, clip_type='norm'): | ||||
""" | |||||
每次backward前,将parameter的gradient clip到某个范围。 | |||||
"""每次backward前,将parameter的gradient clip到某个范围。 | |||||
:param parameters: None, torch.Tensor或List[torch.Tensor], 一般通过model.parameters()获得。如果为None则默认对Trainer | :param parameters: None, torch.Tensor或List[torch.Tensor], 一般通过model.parameters()获得。如果为None则默认对Trainer | ||||
的model中所有参数进行clip | 的model中所有参数进行clip | ||||
@@ -231,12 +247,229 @@ class GradientClipCallback(Callback): | |||||
self.parameters = parameters | self.parameters = parameters | ||||
self.clip_value = clip_value | self.clip_value = clip_value | ||||
def after_backward(self, model): | |||||
def on_backward_end(self, model): | |||||
self.clip_fun(model.parameters(), self.clip_value) | self.clip_fun(model.parameters(), self.clip_value) | ||||
class CallbackException(BaseException): | |||||
def __init__(self, msg): | |||||
super(CallbackException, self).__init__(msg) | |||||
class EarlyStopError(CallbackException): | |||||
def __init__(self, msg): | |||||
super(EarlyStopError, self).__init__(msg) | |||||
class EarlyStopCallback(Callback): | |||||
def __init__(self, patience): | |||||
""" | |||||
:param int patience: 停止之前等待的epoch数 | |||||
""" | |||||
super(EarlyStopCallback, self).__init__() | |||||
self.trainer = None # override by CallbackManager | |||||
self.patience = patience | |||||
self.wait = 0 | |||||
self.epoch = 0 | |||||
def on_valid_end(self, eval_result, metric_key, optimizer): | |||||
self.epoch += 1 | |||||
if not self.trainer._better_eval_result(eval_result): | |||||
# current result is getting worse | |||||
if self.wait == self.patience: | |||||
raise EarlyStopError("Early stopping raised.") | |||||
else: | |||||
self.wait += 1 | |||||
else: | |||||
self.wait = 0 | |||||
def on_exception(self, exception, model): | |||||
if isinstance(exception, EarlyStopError): | |||||
print("Early Stopping triggered in epoch {}!".format(self.epoch)) | |||||
else: | |||||
raise exception # 抛出陌生Error | |||||
class LRScheduler(Callback): | |||||
def __init__(self, lr_scheduler): | |||||
"""对PyTorch LR Scheduler的包装 | |||||
:param lr_scheduler: PyTorch的lr_scheduler | |||||
""" | |||||
super(LRScheduler, self).__init__() | |||||
import torch.optim | |||||
if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler): | |||||
self.scheduler = lr_scheduler | |||||
else: | |||||
raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. Got {type(lr_scheduler)}.") | |||||
def on_epoch_begin(self, cur_epoch, total_epoch): | |||||
self.scheduler.step() | |||||
print("scheduler step ", "lr=", self.trainer.optimizer.param_groups[0]["lr"]) | |||||
class ControlC(Callback): | |||||
def __init__(self, quit_all): | |||||
""" | |||||
:param quit_all: 若为True,则检测到control+C 直接退出程序;否则只退出Trainer | |||||
""" | |||||
super(ControlC, self).__init__() | |||||
if type(quit_all) != bool: | |||||
raise ValueError("In KeyBoardInterrupt, quit_all arguemnt must be a bool.") | |||||
self.quit_all = quit_all | |||||
def on_exception(self, exception, model): | |||||
if isinstance(exception, KeyboardInterrupt): | |||||
if self.quit_all is True: | |||||
import sys | |||||
sys.exit(0) # 直接退出程序 | |||||
else: | |||||
pass | |||||
else: | |||||
raise exception # 抛出陌生Error | |||||
class SmoothValue(object): | |||||
def __init__(self, beta: float): | |||||
self.beta, self.n, self.mov_avg = beta, 0, 0 | |||||
self.smooth = None | |||||
def add_value(self, val: float) -> None: | |||||
"Add `val` to calculate updated smoothed value." | |||||
self.n += 1 | |||||
self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val | |||||
self.smooth = self.mov_avg / (1 - self.beta ** self.n) | |||||
class LRFinder(Callback): | |||||
def __init__(self, n_batch, start_lr=1e-6, end_lr=10): | |||||
"""用第一个 epoch 找最佳的学习率,从第二个epoch开始应用它 | |||||
:param n_batch: 一个epoch内的iteration数 | |||||
:param start_lr: 学习率下界 | |||||
:param end_lr: 学习率上界 | |||||
""" | |||||
super(LRFinder, self).__init__() | |||||
self.start_lr, self.end_lr = start_lr, end_lr | |||||
self.num_it = n_batch | |||||
self.stop = False | |||||
self.best_loss = 0. | |||||
self.best_lr = None | |||||
self.loss_history = [] | |||||
self.smooth_value = SmoothValue(0.8) | |||||
self.opt = None | |||||
scale = (self.end_lr - self.start_lr) / self.num_it | |||||
self.lr_gen = (self.start_lr + scale * (step + 1) for step in range(self.num_it)) | |||||
self.find = None | |||||
self.loader = ModelLoader() | |||||
def on_epoch_begin(self, cur_epoch, total_epoch): | |||||
if cur_epoch == 1: | |||||
self.opt = self.trainer.optimizer # pytorch optimizer | |||||
self.opt.param_groups[0]["lr"] = self.start_lr | |||||
# save model | |||||
ModelSaver("tmp").save_pytorch(self.trainer.model, param_only=True) | |||||
self.find = True | |||||
def on_backward_begin(self, loss, model): | |||||
if self.find: | |||||
if torch.isnan(loss) or self.stop is True: | |||||
self.stop = True | |||||
return | |||||
loss_val = loss.detach().cpu().data | |||||
self.loss_history.append(loss_val) | |||||
self.smooth_value.add_value(loss_val) | |||||
if self.best_loss == 0. or self.smooth_value.smooth < self.best_loss: | |||||
self.best_loss = self.smooth_value.smooth | |||||
self.best_lr = self.opt.param_groups[0]["lr"] | |||||
def on_batch_end(self, *args): | |||||
if self.find: | |||||
lr = next(self.lr_gen, None) | |||||
if lr is None or self.stop is True or self.loss_history[-1] > 4 * self.best_loss: | |||||
self.stop = True | |||||
return | |||||
self.opt.param_groups[0]["lr"] = lr | |||||
# self.loader.load_pytorch(self.trainer.model, "tmp") | |||||
def on_epoch_end(self, cur_epoch, n_epoch, optimizer): | |||||
if cur_epoch == 1: | |||||
self.opt.param_groups[0]["lr"] = self.best_lr | |||||
self.find = False | |||||
# reset model | |||||
ModelLoader().load_pytorch(self.trainer.model, "tmp") | |||||
print("Model reset. \nFind best lr={}".format(self.best_lr)) | |||||
class TensorboardCallback(Callback): | |||||
""" | |||||
接受以下一个或多个字符串作为参数: | |||||
- "model" | |||||
- "loss" | |||||
- "metric" | |||||
""" | |||||
def __init__(self, *options): | |||||
super(TensorboardCallback, self).__init__() | |||||
args = {"model", "loss", "metric"} | |||||
for opt in options: | |||||
if opt not in args: | |||||
raise ValueError("Unrecognized argument {}. Expect one of {}".format(opt, args)) | |||||
self.options = options | |||||
self._summary_writer = None | |||||
self.graph_added = False | |||||
def on_train_begin(self): | |||||
save_dir = self.trainer.save_path | |||||
if save_dir is None: | |||||
path = os.path.join("./", 'tensorboard_logs_{}'.format(self.trainer.start_time)) | |||||
else: | |||||
path = os.path.join(save_dir, 'tensorboard_logs_{}'.format(self.trainer.start_time)) | |||||
self._summary_writer = SummaryWriter(path) | |||||
def on_batch_begin(self, batch_x, batch_y, indices): | |||||
if "model" in self.options and self.graph_added is False: | |||||
# tesorboardX 这里有大bug,暂时没法画模型图 | |||||
# from fastNLP.core.utils import _build_args | |||||
# inputs = _build_args(self.trainer.model, **batch_x) | |||||
# args = tuple([value for value in inputs.values()]) | |||||
# args = args[0] if len(args) == 1 else args | |||||
# self._summary_writer.add_graph(self.trainer.model, torch.zeros(32, 2)) | |||||
self.graph_added = True | |||||
def on_backward_begin(self, loss, model): | |||||
if "loss" in self.options: | |||||
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.trainer.step) | |||||
if "model" in self.options: | |||||
for name, param in self.trainer.model.named_parameters(): | |||||
if param.requires_grad: | |||||
self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.trainer.step) | |||||
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.trainer.step) | |||||
self._summary_writer.add_scalar(name + "_grad_mean", param.grad.mean(), | |||||
global_step=self.trainer.step) | |||||
def on_valid_end(self, eval_result, metric_key, optimizer): | |||||
if "metric" in self.options: | |||||
for name, metric in eval_result.items(): | |||||
for metric_key, metric_val in metric.items(): | |||||
self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val, | |||||
global_step=self.trainer.step) | |||||
def on_train_end(self, model): | |||||
self._summary_writer.close() | |||||
del self._summary_writer | |||||
def on_exception(self, exception, model): | |||||
if hasattr(self, "_summary_writer"): | |||||
self._summary_writer.close() | |||||
del self._summary_writer | |||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()]) | manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()]) | ||||
manager.before_train(10, 11, 12) | |||||
manager.on_train_begin(10, 11, 12) | |||||
# print(manager.after_epoch()) | # print(manager.after_epoch()) |
@@ -2,6 +2,7 @@ import _pickle as pickle | |||||
import numpy as np | import numpy as np | ||||
from fastNLP.core.fieldarray import AutoPadder | |||||
from fastNLP.core.fieldarray import FieldArray | from fastNLP.core.fieldarray import FieldArray | ||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.core.utils import get_func_signature | from fastNLP.core.utils import get_func_signature | ||||
@@ -88,12 +89,13 @@ class DataSet(object): | |||||
raise RuntimeError(f"Start index {idx.start} out of range 0-{len(self)-1}") | raise RuntimeError(f"Start index {idx.start} out of range 0-{len(self)-1}") | ||||
data_set = DataSet() | data_set = DataSet() | ||||
for field in self.field_arrays.values(): | for field in self.field_arrays.values(): | ||||
data_set.add_field(name=field.name, | |||||
fields=field.content[idx], | |||||
padding_val=field.padding_val, | |||||
is_input=field.is_input, | |||||
is_target=field.is_target) | |||||
data_set.add_field(name=field.name, fields=field.content[idx], padder=field.padder, | |||||
is_input=field.is_input, is_target=field.is_target) | |||||
return data_set | return data_set | ||||
elif isinstance(idx, str): | |||||
if idx not in self: | |||||
raise KeyError("No such field called {} in DataSet.".format(idx)) | |||||
return self.field_arrays[idx] | |||||
else: | else: | ||||
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | ||||
@@ -144,19 +146,23 @@ class DataSet(object): | |||||
if len(self.field_arrays) == 0: | if len(self.field_arrays) == 0: | ||||
# DataSet has no field yet | # DataSet has no field yet | ||||
for name, field in ins.fields.items(): | for name, field in ins.fields.items(): | ||||
self.field_arrays[name] = FieldArray(name, [field]) | |||||
field = field.tolist() if isinstance(field, np.ndarray) else field | |||||
self.field_arrays[name] = FieldArray(name, [field]) # 第一个样本,必须用list包装起来 | |||||
else: | else: | ||||
assert len(self.field_arrays) == len(ins.fields) | |||||
if len(self.field_arrays) != len(ins.fields): | |||||
raise ValueError( | |||||
"DataSet object has {} fields, but attempt to append an Instance object with {} fields." | |||||
.format(len(self.field_arrays), len(ins.fields))) | |||||
for name, field in ins.fields.items(): | for name, field in ins.fields.items(): | ||||
assert name in self.field_arrays | assert name in self.field_arrays | ||||
self.field_arrays[name].append(field) | self.field_arrays[name].append(field) | ||||
def add_field(self, name, fields, padding_val=0, is_input=False, is_target=False): | |||||
def add_field(self, name, fields, padder=AutoPadder(pad_val=0), is_input=False, is_target=False): | |||||
"""Add a new field to the DataSet. | """Add a new field to the DataSet. | ||||
:param str name: the name of the field. | :param str name: the name of the field. | ||||
:param fields: a list of int, float, or other objects. | :param fields: a list of int, float, or other objects. | ||||
:param int padding_val: integer for padding. | |||||
:param int padder: PadBase对象,如何对该Field进行padding。大部分情况使用默认值即可 | |||||
:param bool is_input: whether this field is model input. | :param bool is_input: whether this field is model input. | ||||
:param bool is_target: whether this field is label or target. | :param bool is_target: whether this field is label or target. | ||||
""" | """ | ||||
@@ -164,8 +170,8 @@ class DataSet(object): | |||||
if len(self) != len(fields): | if len(self) != len(fields): | ||||
raise RuntimeError(f"The field to append must have the same size as dataset. " | raise RuntimeError(f"The field to append must have the same size as dataset. " | ||||
f"Dataset size {len(self)} != field size {len(fields)}") | f"Dataset size {len(self)} != field size {len(fields)}") | ||||
self.field_arrays[name] = FieldArray(name, fields, padding_val=padding_val, is_target=is_target, | |||||
is_input=is_input) | |||||
self.field_arrays[name] = FieldArray(name, fields, is_target=is_target, is_input=is_input, | |||||
padder=padder) | |||||
def delete_field(self, name): | def delete_field(self, name): | ||||
"""Delete a field based on the field name. | """Delete a field based on the field name. | ||||
@@ -229,6 +235,25 @@ class DataSet(object): | |||||
else: | else: | ||||
raise KeyError("{} is not a valid field name.".format(name)) | raise KeyError("{} is not a valid field name.".format(name)) | ||||
def set_padder(self, field_name, padder): | |||||
""" | |||||
为field_name设置padder | |||||
:param field_name: str, 设置field的padding方式为padder | |||||
:param padder: PadderBase类型或None. 设置为None即删除padder。即对该field不进行padding操作. | |||||
:return: | |||||
""" | |||||
self.field_arrays[field_name].set_padder(padder) | |||||
def set_pad_val(self, field_name, pad_val): | |||||
""" | |||||
为某个 | |||||
:param field_name: str,修改该field的pad_val | |||||
:param pad_val: int,该field的padder会以pad_val作为padding index | |||||
:return: | |||||
""" | |||||
self.field_arrays[field_name].set_pad_val(pad_val) | |||||
def get_input_name(self): | def get_input_name(self): | ||||
"""Get all field names with `is_input` as True. | """Get all field names with `is_input` as True. | ||||
@@ -254,7 +279,7 @@ class DataSet(object): | |||||
:return results: if new_field_name is not passed, returned values of the function over all instances. | :return results: if new_field_name is not passed, returned values of the function over all instances. | ||||
""" | """ | ||||
results = [func(ins) for ins in self._inner_iter()] | results = [func(ins) for ins in self._inner_iter()] | ||||
if len(list(filter(lambda x: x is not None, results))) == 0 and not (new_field_name is None): # all None | |||||
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | |||||
raise ValueError("{} always return None.".format(get_func_signature(func=func))) | raise ValueError("{} always return None.".format(get_func_signature(func=func))) | ||||
extra_param = {} | extra_param = {} | ||||
@@ -270,12 +295,11 @@ class DataSet(object): | |||||
extra_param['is_input'] = old_field.is_input | extra_param['is_input'] = old_field.is_input | ||||
if 'is_target' not in extra_param: | if 'is_target' not in extra_param: | ||||
extra_param['is_target'] = old_field.is_target | extra_param['is_target'] = old_field.is_target | ||||
self.add_field(name=new_field_name, | |||||
fields=results, | |||||
padding_val=old_field.padding_val, | |||||
**extra_param) | |||||
self.add_field(name=new_field_name, fields=results, is_input=extra_param["is_input"], | |||||
is_target=extra_param["is_target"]) | |||||
else: | else: | ||||
self.add_field(name=new_field_name, fields=results, **extra_param) | |||||
self.add_field(name=new_field_name, fields=results, is_input=extra_param.get("is_input", None), | |||||
is_target=extra_param.get("is_target", None)) | |||||
else: | else: | ||||
return results | return results | ||||
@@ -314,8 +338,17 @@ class DataSet(object): | |||||
for field_name in self.field_arrays: | for field_name in self.field_arrays: | ||||
train_set.field_arrays[field_name].is_input = self.field_arrays[field_name].is_input | train_set.field_arrays[field_name].is_input = self.field_arrays[field_name].is_input | ||||
train_set.field_arrays[field_name].is_target = self.field_arrays[field_name].is_target | train_set.field_arrays[field_name].is_target = self.field_arrays[field_name].is_target | ||||
train_set.field_arrays[field_name].padder = self.field_arrays[field_name].padder | |||||
train_set.field_arrays[field_name].dtype = self.field_arrays[field_name].dtype | |||||
train_set.field_arrays[field_name].pytype = self.field_arrays[field_name].pytype | |||||
train_set.field_arrays[field_name].content_dim = self.field_arrays[field_name].content_dim | |||||
dev_set.field_arrays[field_name].is_input = self.field_arrays[field_name].is_input | dev_set.field_arrays[field_name].is_input = self.field_arrays[field_name].is_input | ||||
dev_set.field_arrays[field_name].is_target = self.field_arrays[field_name].is_target | dev_set.field_arrays[field_name].is_target = self.field_arrays[field_name].is_target | ||||
dev_set.field_arrays[field_name].padder = self.field_arrays[field_name].padder | |||||
dev_set.field_arrays[field_name].dtype = self.field_arrays[field_name].dtype | |||||
dev_set.field_arrays[field_name].pytype = self.field_arrays[field_name].pytype | |||||
dev_set.field_arrays[field_name].content_dim = self.field_arrays[field_name].content_dim | |||||
return train_set, dev_set | return train_set, dev_set | ||||
@@ -1,51 +1,168 @@ | |||||
import numpy as np | import numpy as np | ||||
class PadderBase: | |||||
""" | |||||
所有padder都需要继承这个类,并覆盖__call__()方法。 | |||||
用于对batch进行padding操作。传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前deepcopy一份。 | |||||
""" | |||||
def __init__(self, pad_val=0, **kwargs): | |||||
self.pad_val = pad_val | |||||
def set_pad_val(self, pad_val): | |||||
self.pad_val = pad_val | |||||
def __call__(self, contents, field_name, field_ele_dtype): | |||||
""" | |||||
传入的是List内容。假设有以下的DataSet。 | |||||
from fastNLP import DataSet | |||||
from fastNLP import Instance | |||||
dataset = DataSet() | |||||
dataset.append(Instance(word='this is a demo', length=4, | |||||
chars=[['t', 'h', 'i', 's'], ['i', 's'], ['a'], ['d', 'e', 'm', 'o']])) | |||||
dataset.append(Instance(word='another one', length=2, | |||||
chars=[['a', 'n', 'o', 't', 'h', 'e', 'r'], ['o', 'n', 'e']])) | |||||
# 如果batch_size=2, 下面只是用str的方式看起来更直观一点,但实际上可能word和chars在pad时都已经为index了。 | |||||
word这个field的pad_func会接收到的内容会是 | |||||
[ | |||||
'this is a demo', | |||||
'another one' | |||||
] | |||||
length这个field的pad_func会接收到的内容会是 | |||||
[4, 2] | |||||
chars这个field的pad_func会接收到的内容会是 | |||||
[ | |||||
[['t', 'h', 'i', 's'], ['i', 's'], ['a'], ['d', 'e', 'm', 'o']], | |||||
[['a', 'n', 'o', 't', 'h', 'e', 'r'], ['o', 'n', 'e']] | |||||
] | |||||
即把每个instance中某个field的内容合成一个List传入 | |||||
:param contents: List[element]。传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 | |||||
deepcopy一份。 | |||||
:param field_name: str, field的名称,帮助定位错误 | |||||
:param field_ele_dtype: np.int64, np.float64, np.str. 该field的内层list元素的类型。辅助判断是否pad,大多数情况用不上 | |||||
:return: List[padded_element]或np.array([padded_element]) | |||||
""" | |||||
raise NotImplementedError | |||||
class AutoPadder(PadderBase): | |||||
""" | |||||
根据contents的数据自动判定是否需要做padding。 | |||||
(1) 如果元素类型(元素类型是指field中最里层List的元素的数据类型, 可以通过FieldArray.dtype查看,比如['This', 'is', ...]的元素类 | |||||
型为np.str, [[1,2], ...]的元素类型为np.int64)的数据不为(np.int64, np.float64)则不会进行padding | |||||
(2) 如果元素类型为(np.int64, np.float64), | |||||
(2.1) 如果该field的内容只有一个,比如为sequence_length, 则不进行padding | |||||
(2.2) 如果该field的内容为List, 那么会将Batch中的List pad为一样长。若该List下还有里层的List需要padding,请使用其它padder。 | |||||
如果某个instance中field为[1, 2, 3],则可以pad; 若为[[1,2], [3,4, ...]]则不能进行pad | |||||
""" | |||||
def __init__(self, pad_val=0): | |||||
""" | |||||
:param pad_val: int, padding的位置使用该index | |||||
""" | |||||
super().__init__(pad_val=pad_val) | |||||
def _is_two_dimension(self, contents): | |||||
""" | |||||
判断contents是不是只有两个维度。[[1,2], [3]]是两个维度. [[[1,2], [3, 4, 5]], [[4,5]]]有三个维度 | |||||
:param contents: | |||||
:return: | |||||
""" | |||||
value = contents[0] | |||||
if isinstance(value , (np.ndarray, list)): | |||||
value = value[0] | |||||
if isinstance(value, (np.ndarray, list)): | |||||
return False | |||||
return True | |||||
return False | |||||
def __call__(self, contents, field_name, field_ele_dtype): | |||||
if not is_iterable(contents[0]): | |||||
array = np.array([content for content in contents], dtype=field_ele_dtype) | |||||
elif field_ele_dtype in (np.int64, np.float64) and self._is_two_dimension(contents): | |||||
max_len = max([len(content) for content in contents]) | |||||
array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype) | |||||
for i, content in enumerate(contents): | |||||
array[i][:len(content)] = content | |||||
else: # should only be str | |||||
array = np.array([content for content in contents]) | |||||
return array | |||||
class FieldArray(object): | class FieldArray(object): | ||||
"""``FieldArray`` is the collection of ``Instance``s of the same field. | """``FieldArray`` is the collection of ``Instance``s of the same field. | ||||
It is the basic element of ``DataSet`` class. | It is the basic element of ``DataSet`` class. | ||||
:param str name: the name of the FieldArray | :param str name: the name of the FieldArray | ||||
:param list content: a list of int, float, str or np.ndarray, or a list of list of one, or a np.ndarray. | :param list content: a list of int, float, str or np.ndarray, or a list of list of one, or a np.ndarray. | ||||
:param int padding_val: the integer for padding. Default: 0. | |||||
:param bool is_target: If True, this FieldArray is used to compute loss. | :param bool is_target: If True, this FieldArray is used to compute loss. | ||||
:param bool is_input: If True, this FieldArray is used to the model input. | :param bool is_input: If True, this FieldArray is used to the model input. | ||||
:param padder: PadderBase类型。大多数情况下都不需要设置该值,除非需要在多个维度上进行padding(比如英文中对character进行padding) | |||||
""" | """ | ||||
def __init__(self, name, content, padding_val=0, is_target=None, is_input=None): | |||||
def __init__(self, name, content, is_target=None, is_input=None, padder=AutoPadder(pad_val=0)): | |||||
"""DataSet在初始化时会有两类方法对FieldArray操作: | |||||
1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray: | |||||
1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) | |||||
1.2) 二维array DataSet({"x": np.array([[1, 2], [3, 4]])}) | |||||
1.3) 三维list DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]}) | |||||
1.4) list of array: DataSet({"x": [np.array([1,2,3]), np.array([1,2,3])]}) | |||||
2) 如果DataSet使用list of Instance 初始化,那么在append中会先对第一个样本初始化FieldArray; | |||||
然后后面的样本使用FieldArray.append进行添加。 | |||||
2.1) 一维list DataSet([Instance(x=[1, 2, 3, 4])]) | |||||
2.2) 一维array DataSet([Instance(x=np.array([1, 2, 3, 4]))]) | |||||
2.3) 二维list DataSet([Instance(x=[[1, 2], [3, 4]])]) | |||||
2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))]) | |||||
类型检查(dtype check)发生在当该field被设置为is_input或者is_target时。 | |||||
""" | |||||
self.name = name | self.name = name | ||||
if isinstance(content, list): | if isinstance(content, list): | ||||
content = content | |||||
# 如果DataSet使用dict初始化, content 可能是二维list/二维array/三维list | |||||
# 如果DataSet使用list of Instance 初始化, content可能是 [list]/[array]/[2D list] | |||||
for idx, item in enumerate(content): | |||||
# 这是使用list of Instance 初始化时第一个样本:FieldArray(name, [field]) | |||||
# 将[np.array] 转化为 list of list | |||||
# 也可以支持[array, array, array]的情况 | |||||
if isinstance(item, np.ndarray): | |||||
content[idx] = content[idx].tolist() | |||||
elif isinstance(content, np.ndarray): | elif isinstance(content, np.ndarray): | ||||
content = content.tolist() # convert np.ndarray into 2-D list | content = content.tolist() # convert np.ndarray into 2-D list | ||||
else: | else: | ||||
raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content))) | raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content))) | ||||
self.content = content | |||||
self.padding_val = padding_val | |||||
if len(content) == 0: | |||||
raise RuntimeError("Cannot initialize FieldArray with empty list.") | |||||
self._is_target = None | |||||
self._is_input = None | |||||
self.content = content # 1维 或 2维 或 3维 list, 形状可能不对齐 | |||||
self.content_dim = None # 表示content是多少维的list | |||||
self.set_padder(padder) | |||||
self.BASIC_TYPES = (int, float, str) # content中可接受的Python基本类型,这里没有np.array | |||||
self.BASIC_TYPES = (int, float, str, np.ndarray) | |||||
self.is_2d_list = False | |||||
self.pytype = None # int, float, str, or np.ndarray | |||||
self.dtype = None # np.int64, np.float64, np.str | |||||
self.pytype = None | |||||
self.dtype = None | |||||
self._is_input = None | |||||
self._is_target = None | |||||
if is_input is not None: | |||||
if is_input is not None or is_target is not None: | |||||
self.is_input = is_input | self.is_input = is_input | ||||
if is_target is not None: | |||||
self.is_target = is_target | self.is_target = is_target | ||||
def _set_dtype(self): | |||||
self.pytype = self._type_detection(self.content) | |||||
self.dtype = self._map_to_np_type(self.pytype) | |||||
@property | @property | ||||
def is_input(self): | def is_input(self): | ||||
return self._is_input | return self._is_input | ||||
@is_input.setter | @is_input.setter | ||||
def is_input(self, value): | def is_input(self, value): | ||||
""" | |||||
当 field_array.is_input = True / False 时被调用 | |||||
""" | |||||
if value is True: | if value is True: | ||||
self.pytype = self._type_detection(self.content) | |||||
self.dtype = self._map_to_np_type(self.pytype) | |||||
self._set_dtype() | |||||
self._is_input = value | self._is_input = value | ||||
@property | @property | ||||
@@ -54,46 +171,99 @@ class FieldArray(object): | |||||
@is_target.setter | @is_target.setter | ||||
def is_target(self, value): | def is_target(self, value): | ||||
""" | |||||
当 field_array.is_target = True / False 时被调用 | |||||
""" | |||||
if value is True: | if value is True: | ||||
self.pytype = self._type_detection(self.content) | |||||
self.dtype = self._map_to_np_type(self.pytype) | |||||
self._set_dtype() | |||||
self._is_target = value | self._is_target = value | ||||
def _type_detection(self, content): | def _type_detection(self, content): | ||||
""" | |||||
:param content: a list of int, float, str or np.ndarray, or a list of list of one. | |||||
:return type: one of int, float, str, np.ndarray | |||||
"""当该field被设置为is_input或者is_target时被调用 | |||||
""" | """ | ||||
if isinstance(content, list) and len(content) > 0 and isinstance(content[0], list): | |||||
# content is a 2-D list | |||||
if not all(isinstance(_, list) for _ in content): # strict check 2-D list | |||||
raise TypeError("Please provide 2-D list.") | |||||
type_set = set([self._type_detection(x) for x in content]) | |||||
if len(type_set) == 2 and int in type_set and float in type_set: | |||||
type_set = {float} | |||||
elif len(type_set) > 1: | |||||
raise TypeError("Cannot create FieldArray with more than one type. Provided {}".format(type_set)) | |||||
self.is_2d_list = True | |||||
if len(content) == 0: | |||||
raise RuntimeError("Empty list in Field {}.".format(self.name)) | |||||
type_set = set([type(item) for item in content]) | |||||
if list in type_set: | |||||
if len(type_set) > 1: | |||||
# list 跟 非list 混在一起 | |||||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||||
# >1维list | |||||
inner_type_set = set() | |||||
for l in content: | |||||
[inner_type_set.add(type(obj)) for obj in l] | |||||
if list not in inner_type_set: | |||||
# 二维list | |||||
self.content_dim = 2 | |||||
return self._basic_type_detection(inner_type_set) | |||||
else: | |||||
if len(inner_type_set) == 1: | |||||
# >2维list | |||||
inner_inner_type_set = set() | |||||
for _2d_list in content: | |||||
for _1d_list in _2d_list: | |||||
[inner_inner_type_set.add(type(obj)) for obj in _1d_list] | |||||
if list in inner_inner_type_set: | |||||
raise RuntimeError("FieldArray cannot handle 4-D or more-D list.") | |||||
# 3维list | |||||
self.content_dim = 3 | |||||
return self._basic_type_detection(inner_inner_type_set) | |||||
else: | |||||
# list 跟 非list 混在一起 | |||||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, inner_type_set)) | |||||
else: | |||||
# 一维list | |||||
for content_type in type_set: | |||||
if content_type not in self.BASIC_TYPES: | |||||
raise RuntimeError("Unexpected data type in Field '{}'. Expect one of {}. Got {}.".format( | |||||
self.name, self.BASIC_TYPES, content_type)) | |||||
self.content_dim = 1 | |||||
return self._basic_type_detection(type_set) | |||||
def _basic_type_detection(self, type_set): | |||||
""" | |||||
:param type_set: a set of Python types | |||||
:return: one of self.BASIC_TYPES | |||||
""" | |||||
if len(type_set) == 1: | |||||
return type_set.pop() | return type_set.pop() | ||||
elif isinstance(content, list): | |||||
# content is a 1-D list | |||||
if len(content) == 0: | |||||
# the old error is not informative enough. | |||||
raise RuntimeError("Cannot create FieldArray with an empty list. Or one element in the list is empty.") | |||||
type_set = set([type(item) for item in content]) | |||||
if len(type_set) == 1 and tuple(type_set)[0] in self.BASIC_TYPES: | |||||
return type_set.pop() | |||||
elif len(type_set) == 2 and float in type_set and int in type_set: | |||||
elif len(type_set) == 2: | |||||
# 有多个basic type; 可能需要up-cast | |||||
if float in type_set and int in type_set: | |||||
# up-cast int to float | # up-cast int to float | ||||
return float | return float | ||||
else: | else: | ||||
raise TypeError("Cannot create FieldArray with type {}".format(*type_set)) | |||||
# str 跟 int 或者 float 混在一起 | |||||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||||
else: | else: | ||||
raise TypeError("Cannot create FieldArray with type {}".format(type(content))) | |||||
# str, int, float混在一起 | |||||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||||
def _1d_list_check(self, val): | |||||
"""如果不是1D list就报错 | |||||
""" | |||||
type_set = set((type(obj) for obj in val)) | |||||
if any(obj not in self.BASIC_TYPES for obj in type_set): | |||||
raise ValueError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||||
self._basic_type_detection(type_set) | |||||
# otherwise: _basic_type_detection will raise error | |||||
return True | |||||
def _2d_list_check(self, val): | |||||
"""如果不是2D list 就报错 | |||||
""" | |||||
type_set = set(type(obj) for obj in val) | |||||
if list(type_set) != [list]: | |||||
raise ValueError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||||
inner_type_set = set() | |||||
for l in val: | |||||
for obj in l: | |||||
inner_type_set.add(type(obj)) | |||||
self._basic_type_detection(inner_type_set) | |||||
return True | |||||
@staticmethod | @staticmethod | ||||
def _map_to_np_type(basic_type): | def _map_to_np_type(basic_type): | ||||
@@ -108,38 +278,39 @@ class FieldArray(object): | |||||
:param val: int, float, str, or a list of one. | :param val: int, float, str, or a list of one. | ||||
""" | """ | ||||
if self.is_target is True or self.is_input is True: | |||||
# only check type when used as target or input | |||||
if isinstance(val, list): | |||||
pass | |||||
elif isinstance(val, tuple): # 确保最外层是list | |||||
val = list(val) | |||||
elif isinstance(val, np.ndarray): | |||||
val = val.tolist() | |||||
elif any((isinstance(val, t) for t in self.BASIC_TYPES)): | |||||
pass | |||||
else: | |||||
raise RuntimeError( | |||||
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) | |||||
val_type = type(val) | |||||
if val_type == list: # shape check | |||||
if self.is_2d_list is False: | |||||
raise RuntimeError("Cannot append a list into a 1-D FieldArray. Please provide an element.") | |||||
if self.is_input is True or self.is_target is True: | |||||
if type(val) == list: | |||||
if len(val) == 0: | if len(val) == 0: | ||||
raise RuntimeError("Cannot append an empty list.") | |||||
val_list_type = set([type(_) for _ in val]) # type check | |||||
if len(val_list_type) == 2 and int in val_list_type and float in val_list_type: | |||||
# up-cast int to float | |||||
val_type = float | |||||
elif len(val_list_type) == 1: | |||||
val_type = val_list_type.pop() | |||||
raise ValueError("Cannot append an empty list.") | |||||
if self.content_dim == 2 and self._1d_list_check(val): | |||||
# 1维list检查 | |||||
pass | |||||
elif self.content_dim == 3 and self._2d_list_check(val): | |||||
# 2维list检查 | |||||
pass | |||||
else: | else: | ||||
raise TypeError("Cannot append a list of {}".format(val_list_type)) | |||||
raise RuntimeError( | |||||
"Dimension not matched: expect dim={}, got {}.".format(self.content_dim - 1, val)) | |||||
elif type(val) in self.BASIC_TYPES and self.content_dim == 1: | |||||
# scalar检查 | |||||
if type(val) == float and self.pytype == int: | |||||
self.pytype = float | |||||
self.dtype = self._map_to_np_type(self.pytype) | |||||
else: | else: | ||||
if self.is_2d_list is True: | |||||
raise RuntimeError("Cannot append a non-list into a 2-D list. Please provide a list.") | |||||
if val_type == float and self.pytype == int: | |||||
# up-cast | |||||
self.pytype = float | |||||
self.dtype = self._map_to_np_type(self.pytype) | |||||
elif val_type == int and self.pytype == float: | |||||
pass | |||||
elif val_type == self.pytype: | |||||
pass | |||||
else: | |||||
raise TypeError("Cannot append type {} into type {}".format(val_type, self.pytype)) | |||||
raise RuntimeError( | |||||
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) | |||||
self.content.append(val) | self.content.append(val) | ||||
def __getitem__(self, indices): | def __getitem__(self, indices): | ||||
@@ -149,28 +320,44 @@ class FieldArray(object): | |||||
assert isinstance(idx, int) | assert isinstance(idx, int) | ||||
self.content[idx] = val | self.content[idx] = val | ||||
def get(self, indices): | |||||
def get(self, indices, pad=True): | |||||
"""Fetch instances based on indices. | """Fetch instances based on indices. | ||||
:param indices: an int, or a list of int. | :param indices: an int, or a list of int. | ||||
:param pad: bool, 是否对返回的结果进行padding。 | |||||
:return: | :return: | ||||
""" | """ | ||||
if isinstance(indices, int): | if isinstance(indices, int): | ||||
return self.content[indices] | return self.content[indices] | ||||
if self.is_input is False and self.is_target is False: | if self.is_input is False and self.is_target is False: | ||||
raise RuntimeError("Please specify either is_input or is_target is True for {}".format(self.name)) | raise RuntimeError("Please specify either is_input or is_target is True for {}".format(self.name)) | ||||
batch_size = len(indices) | |||||
if not is_iterable(self.content[0]): | |||||
array = np.array([self.content[i] for i in indices], dtype=self.dtype) | |||||
elif self.dtype in (np.int64, np.float64): | |||||
max_len = max([len(self.content[i]) for i in indices]) | |||||
array = np.full((batch_size, max_len), self.padding_val, dtype=self.dtype) | |||||
for i, idx in enumerate(indices): | |||||
array[i][:len(self.content[idx])] = self.content[idx] | |||||
else: # should only be str | |||||
array = np.array([self.content[i] for i in indices]) | |||||
return array | |||||
contents = [self.content[i] for i in indices] | |||||
if self.padder is None or pad is False: | |||||
return np.array(contents) | |||||
else: | |||||
return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype) | |||||
def set_padder(self, padder): | |||||
""" | |||||
设置padding方式 | |||||
:param padder: PadderBase类型或None. 设置为None即删除padder. | |||||
:return: | |||||
""" | |||||
if padder is not None: | |||||
assert isinstance(padder, PadderBase), "padder must be of type PadderBase." | |||||
self.padder = padder | |||||
def set_pad_val(self, pad_val): | |||||
""" | |||||
修改padder的pad_val. | |||||
:param pad_val: int。 | |||||
:return: | |||||
""" | |||||
if self.padder is not None: | |||||
self.padder.set_pad_val(pad_val) | |||||
def __len__(self): | def __len__(self): | ||||
"""Returns the size of FieldArray. | """Returns the size of FieldArray. | ||||
@@ -186,3 +373,80 @@ def is_iterable(content): | |||||
except TypeError: | except TypeError: | ||||
return False | return False | ||||
return True | return True | ||||
class EngChar2DPadder(PadderBase): | |||||
""" | |||||
用于为英语执行character级别的2D padding操作。对应的field内容应该为[['T', 'h', 'i', 's'], ['a'], ['d', 'e', 'm', 'o']](这里为 | |||||
了更直观,把它们写为str,但实际使用时它们应该是character的index)。 | |||||
padded过后的batch内容,形状为(batch_size, max_sentence_length, max_word_length). max_sentence_length最大句子长度。 | |||||
max_word_length最长的word的长度 | |||||
""" | |||||
def __init__(self, pad_val=0, pad_length=0): | |||||
""" | |||||
:param pad_val: int, padding的位置使用该index | |||||
:param pad_length: int, 如果为0则取一个batch中最大的单词长度作为padding长度。如果为大于0的数,则将所有单词的长度都pad或截 | |||||
取到该长度. | |||||
""" | |||||
super().__init__(pad_val=pad_val) | |||||
self.pad_length = pad_length | |||||
def _exactly_three_dims(self, contents, field_name): | |||||
""" | |||||
检查传入的contents是否刚好是3维,如果不是3维就报错。理论上,第一个维度是batch,第二个维度是word,第三个维度是character | |||||
:param contents: | |||||
:param field_name: str | |||||
:return: | |||||
""" | |||||
if not isinstance(contents, list): | |||||
raise TypeError("contents should be a list, not {}.".format(type(contents))) | |||||
value = contents[0] | |||||
try: | |||||
value = value[0] | |||||
except: | |||||
raise ValueError("Field:{} only has one dimension.".format(field_name)) | |||||
try: | |||||
value = value[0] | |||||
except: | |||||
raise ValueError("Field:{} only has two dimensions.".format(field_name)) | |||||
if is_iterable(value): | |||||
raise ValueError("Field:{} has more than 3 dimension.".format(field_name)) | |||||
def __call__(self, contents, field_name, field_ele_dtype): | |||||
""" | |||||
期望输入类似于 | |||||
[ | |||||
[[0, 2], [2, 3, 4], ..], | |||||
[[9, 8, 2, 4], [1, 2,], ...], | |||||
.... | |||||
] | |||||
:param contents: | |||||
:param field_name: | |||||
:param field_ele_dtype | |||||
:return: | |||||
""" | |||||
if field_ele_dtype not in (np.int64, np.float64): | |||||
raise TypeError('dtype of Field:{} should be np.int64 or np.float64 to do 2D padding, get {}.'.format( | |||||
field_name, field_ele_dtype | |||||
)) | |||||
self._exactly_three_dims(contents, field_name) | |||||
if self.pad_length < 1: | |||||
max_char_length = max(max([[len(char_lst) for char_lst in word_lst] for word_lst in contents])) | |||||
else: | |||||
max_char_length = self.pad_length | |||||
max_sent_length = max(len(word_lst) for word_lst in contents) | |||||
batch_size = len(contents) | |||||
dtype = type(contents[0][0][0]) | |||||
padded_array = np.full((batch_size, max_sent_length, max_char_length), fill_value=self.pad_val, | |||||
dtype=dtype) | |||||
for b_idx, word_lst in enumerate(contents): | |||||
for c_idx, char_lst in enumerate(word_lst): | |||||
chars = char_lst[:max_char_length] | |||||
padded_array[b_idx, c_idx, :len(chars)] = chars | |||||
return padded_array |
@@ -11,6 +11,10 @@ class Instance(object): | |||||
""" | """ | ||||
def __init__(self, **fields): | def __init__(self, **fields): | ||||
""" | |||||
:param fields: 可能是一维或者二维的 list or np.array | |||||
""" | |||||
self.fields = fields | self.fields = fields | ||||
def add_field(self, field_name, field): | def add_field(self, field_name, field): | ||||
@@ -32,5 +36,5 @@ class Instance(object): | |||||
def __repr__(self): | def __repr__(self): | ||||
s = '\'' | s = '\'' | ||||
return "{" + ",\n".join( | return "{" + ",\n".join( | ||||
"\'" + field_name + "\': " + str(self.fields[field_name]) +\ | |||||
"\'" + field_name + "\': " + str(self.fields[field_name]) + \ | |||||
f" type={(str(type(self.fields[field_name]))).split(s)[1]}" for field_name in self.fields) + "}" | f" type={(str(type(self.fields[field_name]))).split(s)[1]}" for field_name in self.fields) + "}" |
@@ -1,7 +1,11 @@ | |||||
from collections import defaultdict | |||||
import torch | import torch | ||||
from fastNLP.core.batch import Batch | |||||
from fastNLP.core.sampler import SequentialSampler | |||||
from fastNLP.core import Batch | |||||
from fastNLP.core import DataSet | |||||
from fastNLP.core import SequentialSampler | |||||
from fastNLP.core.utils import _build_args | |||||
class Predictor(object): | class Predictor(object): | ||||
@@ -13,37 +17,55 @@ class Predictor(object): | |||||
Currently, Predictor does not support GPU. | Currently, Predictor does not support GPU. | ||||
""" | """ | ||||
def __init__(self): | |||||
def __init__(self, network): | |||||
if not isinstance(network, torch.nn.Module): | |||||
raise ValueError( | |||||
"Only fastNLP.models.BaseModel or torch.nn,Module is allowed, not {}".format(type(network))) | |||||
self.network = network | |||||
self.batch_size = 1 | self.batch_size = 1 | ||||
self.batch_output = [] | self.batch_output = [] | ||||
def predict(self, network, data): | |||||
def predict(self, data, seq_len_field_name=None): | |||||
"""Perform inference using the trained model. | """Perform inference using the trained model. | ||||
:param network: a PyTorch model (cpu) | |||||
:param data: a DataSet object. | :param data: a DataSet object. | ||||
:param str seq_len_field_name: field name indicating sequence lengths | |||||
:return: list of batch outputs | :return: list of batch outputs | ||||
""" | """ | ||||
# turn on the testing mode; clean up the history | |||||
self.mode(network, test=True) | |||||
batch_output = [] | |||||
if not isinstance(data, DataSet): | |||||
raise ValueError("Only Dataset class is allowed, not {}.".format(type(data))) | |||||
if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays: | |||||
raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data)) | |||||
data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False) | |||||
self.network.eval() | |||||
batch_output = defaultdict(list) | |||||
data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False, | |||||
prefetch=False) | |||||
for batch_x, _ in data_iterator: | |||||
with torch.no_grad(): | |||||
prediction = self.data_forward(network, batch_x) | |||||
batch_output.append(prediction) | |||||
if hasattr(self.network, "predict"): | |||||
predict_func = self.network.predict | |||||
else: | |||||
predict_func = self.network.forward | |||||
return batch_output | |||||
with torch.no_grad(): | |||||
for batch_x, _ in data_iterator: | |||||
refined_batch_x = _build_args(predict_func, **batch_x) | |||||
prediction = predict_func(**refined_batch_x) | |||||
def mode(self, network, test=True): | |||||
if test: | |||||
network.eval() | |||||
else: | |||||
network.train() | |||||
if seq_len_field_name is not None: | |||||
seq_lens = batch_x[seq_len_field_name].tolist() | |||||
for key, value in prediction.items(): | |||||
value = value.cpu().numpy() | |||||
if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1): | |||||
batch_output[key].extend(value.tolist()) | |||||
else: | |||||
if seq_len_field_name is not None: | |||||
tmp_batch = [] | |||||
for idx, seq_len in enumerate(seq_lens): | |||||
tmp_batch.append(value[idx, :seq_len]) | |||||
batch_output[key].extend(tmp_batch) | |||||
else: | |||||
batch_output[key].append(value) | |||||
def data_forward(self, network, x): | |||||
"""Forward through network.""" | |||||
y = network(**x) | |||||
return y | |||||
return batch_output |
@@ -5,7 +5,6 @@ from datetime import timedelta | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
from tensorboardX import SummaryWriter | |||||
from torch import nn | from torch import nn | ||||
try: | try: | ||||
@@ -14,7 +13,7 @@ except: | |||||
from fastNLP.core.utils import pseudo_tqdm as tqdm | from fastNLP.core.utils import pseudo_tqdm as tqdm | ||||
from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
from fastNLP.core.callback import CallbackManager | |||||
from fastNLP.core.callback import CallbackManager, CallbackException | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.losses import _prepare_losser | from fastNLP.core.losses import _prepare_losser | ||||
from fastNLP.core.metrics import _prepare_metrics | from fastNLP.core.metrics import _prepare_metrics | ||||
@@ -34,8 +33,8 @@ from fastNLP.core.utils import get_func_signature | |||||
class Trainer(object): | class Trainer(object): | ||||
def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, | def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, | ||||
validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0), | validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0), | ||||
check_code_level=0, metric_key=None, sampler=RandomSampler(), use_tqdm=True, use_cuda=False, | |||||
callbacks=None): | |||||
check_code_level=0, metric_key=None, sampler=RandomSampler(), prefetch=False, use_tqdm=True, | |||||
use_cuda=False, callbacks=None): | |||||
""" | """ | ||||
:param DataSet train_data: the training data | :param DataSet train_data: the training data | ||||
:param torch.nn.modules.module model: a PyTorch model | :param torch.nn.modules.module model: a PyTorch model | ||||
@@ -46,20 +45,23 @@ class Trainer(object): | |||||
:param int print_every: step interval to print next training information. Default: -1(no print). | :param int print_every: step interval to print next training information. Default: -1(no print). | ||||
:param int validate_every: step interval to do next validation. Default: -1(validate every epoch). | :param int validate_every: step interval to do next validation. Default: -1(validate every epoch). | ||||
:param DataSet dev_data: the validation data | :param DataSet dev_data: the validation data | ||||
:param bool use_cuda: whether to use CUDA in training. | |||||
:param str save_path: file path to save models | :param str save_path: file path to save models | ||||
:param Optimizer optimizer: an optimizer object | :param Optimizer optimizer: an optimizer object | ||||
:param int check_code_level: level of FastNLP code checker. -1: don't check, 0: ignore. 1: warning. 2: strict.\\ | :param int check_code_level: level of FastNLP code checker. -1: don't check, 0: ignore. 1: warning. 2: strict.\\ | ||||
`ignore` will not check unused field; `warning` when warn if some field are not used; `strict` means | `ignore` will not check unused field; `warning` when warn if some field are not used; `strict` means | ||||
it will raise error if some field are not used. | |||||
it will raise error if some field are not used. 检查的原理是通过使用很小的batch(默认两个sample)来检查代码是 | |||||
否能够运行,但是这个过程理论上不会修改任何参数,只是会检查能否运行。但如果(1)模型中存在将batch_size写为某个 | |||||
固定值的情况;(2)模型中存在累加前向计算次数的,可能会多计算几次。以上情况建议将check_code_level设置为-1 | |||||
:param str metric_key: a single indicator used to decide the best model based on metric results. It must be one | :param str metric_key: a single indicator used to decide the best model based on metric results. It must be one | ||||
of the keys returned by the FIRST metric in `metrics`. If the overall result gets better if the indicator gets | of the keys returned by the FIRST metric in `metrics`. If the overall result gets better if the indicator gets | ||||
smaller, add "-" in front of the string. For example:: | smaller, add "-" in front of the string. For example:: | ||||
metric_key="-PPL" # language model gets better as perplexity gets smaller | metric_key="-PPL" # language model gets better as perplexity gets smaller | ||||
:param BaseSampler sampler: method used to generate batch data. | :param BaseSampler sampler: method used to generate batch data. | ||||
:param prefetch: bool, 是否使用额外的进程对产生batch数据。 | |||||
:param bool use_tqdm: whether to use tqdm to show train progress. | :param bool use_tqdm: whether to use tqdm to show train progress. | ||||
:param callbacks: List[Callback]. 用于在train过程中起调节作用的回调函数。比如early stop,negative sampling等可以 | |||||
通过callback机制实现。 | |||||
""" | """ | ||||
super(Trainer, self).__init__() | super(Trainer, self).__init__() | ||||
@@ -114,7 +116,11 @@ class Trainer(object): | |||||
self.print_every = int(print_every) | self.print_every = int(print_every) | ||||
self.validate_every = int(validate_every) if validate_every!=0 else -1 | self.validate_every = int(validate_every) if validate_every!=0 else -1 | ||||
self.best_metric_indicator = None | self.best_metric_indicator = None | ||||
self.best_dev_epoch = None | |||||
self.best_dev_step = None | |||||
self.best_dev_perf = None | |||||
self.sampler = sampler | self.sampler = sampler | ||||
self.prefetch = prefetch | |||||
self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) | self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) | ||||
if isinstance(optimizer, torch.optim.Optimizer): | if isinstance(optimizer, torch.optim.Optimizer): | ||||
@@ -175,32 +181,26 @@ class Trainer(object): | |||||
""" | """ | ||||
results = {} | results = {} | ||||
if self.n_epochs <= 0: | |||||
print(f"training epoch is {self.n_epochs}, nothing was done.") | |||||
results['seconds'] = 0. | |||||
return results | |||||
try: | try: | ||||
if torch.cuda.is_available() and self.use_cuda: | if torch.cuda.is_available() and self.use_cuda: | ||||
self.model = self.model.cuda() | self.model = self.model.cuda() | ||||
self._model_device = self.model.parameters().__next__().device | self._model_device = self.model.parameters().__next__().device | ||||
self._mode(self.model, is_test=False) | self._mode(self.model, is_test=False) | ||||
self.start_time = str(datetime.now().strftime('%Y-%m-%d %H-%M-%S')) | |||||
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) | |||||
start_time = time.time() | start_time = time.time() | ||||
print("training epochs started " + self.start_time, flush=True) | print("training epochs started " + self.start_time, flush=True) | ||||
if self.save_path is None: | |||||
class psudoSW: | |||||
def __getattr__(self, item): | |||||
def pass_func(*args, **kwargs): | |||||
pass | |||||
return pass_func | |||||
self._summary_writer = psudoSW() | |||||
else: | |||||
path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time)) | |||||
self._summary_writer = SummaryWriter(path) | |||||
self.callback_manager.before_train() | |||||
self._train() | |||||
self.callback_manager.after_train(self.model) | |||||
try: | |||||
self.callback_manager.on_train_begin() | |||||
self._train() | |||||
self.callback_manager.on_train_end(self.model) | |||||
except (CallbackException, KeyboardInterrupt) as e: | |||||
self.callback_manager.on_exception(e, self.model) | |||||
if self.dev_data is not None: | if self.dev_data is not None: | ||||
print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + | print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + | ||||
@@ -216,8 +216,7 @@ class Trainer(object): | |||||
else: | else: | ||||
print("Fail to reload best model.") | print("Fail to reload best model.") | ||||
finally: | finally: | ||||
self._summary_writer.close() | |||||
del self._summary_writer | |||||
pass | |||||
results['seconds'] = round(time.time() - start_time, 2) | results['seconds'] = round(time.time() - start_time, 2) | ||||
return results | return results | ||||
@@ -229,42 +228,36 @@ class Trainer(object): | |||||
inner_tqdm = tqdm | inner_tqdm = tqdm | ||||
self.step = 0 | self.step = 0 | ||||
start = time.time() | start = time.time() | ||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False) | |||||
total_steps = data_iterator.num_batches * self.n_epochs | |||||
total_steps = (len(self.train_data) // self.batch_size + int( | |||||
len(self.train_data) % self.batch_size != 0)) * self.n_epochs | |||||
with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: | with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: | ||||
avg_loss = 0 | avg_loss = 0 | ||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | |||||
prefetch=self.prefetch) | |||||
for epoch in range(1, self.n_epochs+1): | for epoch in range(1, self.n_epochs+1): | ||||
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | ||||
# early stopping | # early stopping | ||||
self.callback_manager.before_epoch(epoch, self.n_epochs) | |||||
self.callback_manager.on_epoch_begin(epoch, self.n_epochs) | |||||
for batch_x, batch_y in data_iterator: | for batch_x, batch_y in data_iterator: | ||||
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) | |||||
indices = data_iterator.get_batch_indices() | indices = data_iterator.get_batch_indices() | ||||
# negative sampling; replace unknown; re-weight batch_y | # negative sampling; replace unknown; re-weight batch_y | ||||
self.callback_manager.before_batch(batch_x, batch_y, indices) | |||||
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) | |||||
self.callback_manager.on_batch_begin(batch_x, batch_y, indices) | |||||
prediction = self._data_forward(self.model, batch_x) | prediction = self._data_forward(self.model, batch_x) | ||||
# edit prediction | # edit prediction | ||||
self.callback_manager.before_loss(batch_y, prediction) | |||||
self.callback_manager.on_loss_begin(batch_y, prediction) | |||||
loss = self._compute_loss(prediction, batch_y) | loss = self._compute_loss(prediction, batch_y) | ||||
avg_loss += loss.item() | avg_loss += loss.item() | ||||
# Is loss NaN or inf? requires_grad = False | # Is loss NaN or inf? requires_grad = False | ||||
self.callback_manager.before_backward(loss, self.model) | |||||
self.callback_manager.on_backward_begin(loss, self.model) | |||||
self._grad_backward(loss) | self._grad_backward(loss) | ||||
# gradient clipping | |||||
self.callback_manager.after_backward(self.model) | |||||
self.callback_manager.on_backward_end(self.model) | |||||
self._update() | self._update() | ||||
# lr scheduler; lr_finder; one_cycle | |||||
self.callback_manager.after_step(self.optimizer) | |||||
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) | |||||
for name, param in self.model.named_parameters(): | |||||
if param.requires_grad: | |||||
self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step) | |||||
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step) | |||||
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step) | |||||
self.callback_manager.on_step_end(self.optimizer) | |||||
if (self.step+1) % self.print_every == 0: | if (self.step+1) % self.print_every == 0: | ||||
if self.use_tqdm: | if self.use_tqdm: | ||||
print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every) | print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every) | ||||
@@ -277,11 +270,10 @@ class Trainer(object): | |||||
pbar.set_postfix_str(print_output) | pbar.set_postfix_str(print_output) | ||||
avg_loss = 0 | avg_loss = 0 | ||||
self.step += 1 | self.step += 1 | ||||
# do nothing | |||||
self.callback_manager.after_batch() | |||||
self.callback_manager.on_batch_end() | |||||
if ((self.validate_every > 0 and self.step % self.validate_every == 0) or | if ((self.validate_every > 0 and self.step % self.validate_every == 0) or | ||||
(self.validate_every < 0 and self.step % len(data_iterator)) == 0) \ | |||||
(self.validate_every < 0 and self.step % len(data_iterator) == 0)) \ | |||||
and self.dev_data is not None: | and self.dev_data is not None: | ||||
eval_res = self._do_validation(epoch=epoch, step=self.step) | eval_res = self._do_validation(epoch=epoch, step=self.step) | ||||
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | ||||
@@ -289,35 +281,29 @@ class Trainer(object): | |||||
self.tester._format_eval_results(eval_res) | self.tester._format_eval_results(eval_res) | ||||
pbar.write(eval_str) | pbar.write(eval_str) | ||||
# if self.validate_every < 0 and self.dev_data: | |||||
# eval_res = self._do_validation(epoch=epoch, step=self.step) | |||||
# eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \ | |||||
# self.tester._format_eval_results(eval_res) | |||||
# pbar.write(eval_str) | |||||
if epoch != self.n_epochs: | |||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, | |||||
as_numpy=False) | |||||
# ================= mini-batch end ==================== # | |||||
# lr decay; early stopping | # lr decay; early stopping | ||||
self.callback_manager.after_epoch(epoch, self.n_epochs, self.optimizer) | |||||
self.callback_manager.on_epoch_end(epoch, self.n_epochs, self.optimizer) | |||||
# =============== epochs end =================== # | |||||
pbar.close() | pbar.close() | ||||
# ============ tqdm end ============== # | |||||
def _do_validation(self, epoch, step): | def _do_validation(self, epoch, step): | ||||
self.callback_manager.on_valid_begin() | |||||
res = self.tester.test() | res = self.tester.test() | ||||
for name, metric in res.items(): | |||||
for metric_key, metric_val in metric.items(): | |||||
self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val, | |||||
global_step=self.step) | |||||
if self._better_eval_result(res): | if self._better_eval_result(res): | ||||
if self.save_path is not None: | if self.save_path is not None: | ||||
self._save_model(self.model, | self._save_model(self.model, | ||||
"best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])) | "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])) | ||||
else: | else: | ||||
self._best_model_states = {name:param.cpu().clone() for name, param in self.model.named_parameters()} | |||||
self._best_model_states = {name: param.cpu().clone() for name, param in self.model.named_parameters()} | |||||
self.best_dev_perf = res | self.best_dev_perf = res | ||||
self.best_dev_epoch = epoch | self.best_dev_epoch = epoch | ||||
self.best_dev_step = step | self.best_dev_step = step | ||||
# get validation results; adjust optimizer | # get validation results; adjust optimizer | ||||
self.callback_manager.after_valid(res, self.metric_key, self.optimizer) | |||||
self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer) | |||||
return res | return res | ||||
def _mode(self, model, is_test=False): | def _mode(self, model, is_test=False): | ||||
@@ -365,12 +351,23 @@ class Trainer(object): | |||||
return self.losser(predict, truth) | return self.losser(predict, truth) | ||||
def _save_model(self, model, model_name, only_param=False): | def _save_model(self, model, model_name, only_param=False): | ||||
""" 存储不含有显卡信息的state_dict或model | |||||
:param model: | |||||
:param model_name: | |||||
:param only_param: | |||||
:return: | |||||
""" | |||||
if self.save_path is not None: | if self.save_path is not None: | ||||
model_name = os.path.join(self.save_path, model_name) | |||||
model_path = os.path.join(self.save_path, model_name) | |||||
if only_param: | if only_param: | ||||
torch.save(model.state_dict(), model_name) | |||||
state_dict = model.state_dict() | |||||
for key in state_dict: | |||||
state_dict[key] = state_dict[key].cpu() | |||||
torch.save(state_dict, model_path) | |||||
else: | else: | ||||
torch.save(model, model_name) | |||||
model.cpu() | |||||
torch.save(model, model_path) | |||||
model.cuda() | |||||
def _load_model(self, model, model_name, only_param=False): | def _load_model(self, model, model_name, only_param=False): | ||||
# 返回bool值指示是否成功reload模型 | # 返回bool值指示是否成功reload模型 | ||||
@@ -186,11 +186,12 @@ def _check_function_or_method(func): | |||||
raise TypeError(f"{type(func)} is not a method or function.") | raise TypeError(f"{type(func)} is not a method or function.") | ||||
def _move_dict_value_to_device(*args, device: torch.device): | |||||
def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False): | |||||
""" | """ | ||||
move data to model's device, element in *args should be dict. This is a inplace change. | move data to model's device, element in *args should be dict. This is a inplace change. | ||||
:param device: torch.device | :param device: torch.device | ||||
:param non_blocking: bool, 是否异步将数据转移到cpu, 需要tensor使用pin_memory() | |||||
:param args: | :param args: | ||||
:return: | :return: | ||||
""" | """ | ||||
@@ -201,7 +202,7 @@ def _move_dict_value_to_device(*args, device: torch.device): | |||||
if isinstance(arg, dict): | if isinstance(arg, dict): | ||||
for key, value in arg.items(): | for key, value in arg.items(): | ||||
if isinstance(value, torch.Tensor): | if isinstance(value, torch.Tensor): | ||||
arg[key] = value.to(device) | |||||
arg[key] = value.to(device, non_blocking=non_blocking) | |||||
else: | else: | ||||
raise TypeError("Only support `dict` type right now.") | raise TypeError("Only support `dict` type right now.") | ||||
@@ -11,18 +11,24 @@ class BaseLoader(object): | |||||
@staticmethod | @staticmethod | ||||
def load_lines(data_path): | def load_lines(data_path): | ||||
"""按行读取,舍弃每行两侧空白字符,返回list of str | |||||
""" | |||||
with open(data_path, "r", encoding="utf=8") as f: | with open(data_path, "r", encoding="utf=8") as f: | ||||
text = f.readlines() | text = f.readlines() | ||||
return [line.strip() for line in text] | return [line.strip() for line in text] | ||||
@classmethod | @classmethod | ||||
def load(cls, data_path): | def load(cls, data_path): | ||||
"""先按行读取,去除一行两侧空白,再提取每行的字符。返回list of list of str | |||||
""" | |||||
with open(data_path, "r", encoding="utf-8") as f: | with open(data_path, "r", encoding="utf-8") as f: | ||||
text = f.readlines() | text = f.readlines() | ||||
return [[word for word in sent.strip()] for sent in text] | return [[word for word in sent.strip()] for sent in text] | ||||
@classmethod | @classmethod | ||||
def load_with_cache(cls, data_path, cache_path): | def load_with_cache(cls, data_path, cache_path): | ||||
"""缓存版的load | |||||
""" | |||||
if os.path.isfile(cache_path) and os.path.getmtime(data_path) < os.path.getmtime(cache_path): | if os.path.isfile(cache_path) and os.path.getmtime(data_path) < os.path.getmtime(cache_path): | ||||
with open(cache_path, 'rb') as f: | with open(cache_path, 'rb') as f: | ||||
return pickle.load(f) | return pickle.load(f) | ||||
@@ -11,7 +11,6 @@ class ConfigLoader(BaseLoader): | |||||
:param str data_path: path to the config | :param str data_path: path to the config | ||||
""" | """ | ||||
def __init__(self, data_path=None): | def __init__(self, data_path=None): | ||||
super(ConfigLoader, self).__init__() | super(ConfigLoader, self).__init__() | ||||
if data_path is not None: | if data_path is not None: | ||||
@@ -30,7 +29,7 @@ class ConfigLoader(BaseLoader): | |||||
Example:: | Example:: | ||||
test_args = ConfigSection() | test_args = ConfigSection() | ||||
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||||
ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||||
""" | """ | ||||
assert isinstance(sections, dict) | assert isinstance(sections, dict) | ||||
@@ -202,8 +201,6 @@ class ConfigSaver(object): | |||||
continue | continue | ||||
if '=' not in line: | if '=' not in line: | ||||
# log = create_logger(__name__, './config_saver.log') | |||||
# log.error("can NOT load config file [%s]" % self.file_path) | |||||
raise RuntimeError("can NOT load config file {}".__format__(self.file_path)) | raise RuntimeError("can NOT load config file {}".__format__(self.file_path)) | ||||
key = line.split('=', maxsplit=1)[0].strip() | key = line.split('=', maxsplit=1)[0].strip() | ||||
@@ -263,10 +260,6 @@ class ConfigSaver(object): | |||||
change_file = True | change_file = True | ||||
break | break | ||||
if section_file[k] != section[k]: | if section_file[k] != section[k]: | ||||
# logger = create_logger(__name__, "./config_loader.log") | |||||
# logger.warning("section [%s] in config file [%s] has been changed" % ( | |||||
# section_name, self.file_path | |||||
# )) | |||||
change_file = True | change_file = True | ||||
break | break | ||||
if not change_file: | if not change_file: | ||||
@@ -90,6 +90,7 @@ class NativeDataSetLoader(DataSetLoader): | |||||
"""A simple example of DataSetLoader | """A simple example of DataSetLoader | ||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
super(NativeDataSetLoader, self).__init__() | super(NativeDataSetLoader, self).__init__() | ||||
@@ -107,6 +108,7 @@ class RawDataSetLoader(DataSetLoader): | |||||
"""A simple example of raw data reader | """A simple example of raw data reader | ||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
super(RawDataSetLoader, self).__init__() | super(RawDataSetLoader, self).__init__() | ||||
@@ -124,8 +126,8 @@ class RawDataSetLoader(DataSetLoader): | |||||
DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata') | DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata') | ||||
class POSDataSetLoader(DataSetLoader): | |||||
"""Dataset Loader for a POS Tag dataset. | |||||
class DummyPOSReader(DataSetLoader): | |||||
"""A simple reader for a dummy POS tagging dataset. | |||||
In these datasets, each line are divided by "\t". The first Col is the vocabulary and the second | In these datasets, each line are divided by "\t". The first Col is the vocabulary and the second | ||||
Col is the label. Different sentence are divided by an empty line. | Col is the label. Different sentence are divided by an empty line. | ||||
@@ -142,8 +144,9 @@ class POSDataSetLoader(DataSetLoader): | |||||
In this example, there are two sentences "Tom and Jerry ." and "Hello world !". Each word has its own label. | In this example, there are two sentences "Tom and Jerry ." and "Hello world !". Each word has its own label. | ||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
super(POSDataSetLoader, self).__init__() | |||||
super(DummyPOSReader, self).__init__() | |||||
def load(self, data_path): | def load(self, data_path): | ||||
""" | """ | ||||
@@ -191,16 +194,14 @@ class POSDataSetLoader(DataSetLoader): | |||||
return convert_seq2seq_dataset(data) | return convert_seq2seq_dataset(data) | ||||
DataLoaderRegister.set_reader(POSDataSetLoader, 'read_pos') | |||||
DataLoaderRegister.set_reader(DummyPOSReader, 'read_pos') | |||||
class TokenizeDataSetLoader(DataSetLoader): | |||||
class DummyCWSReader(DataSetLoader): | |||||
"""Load pku dataset for Chinese word segmentation. | |||||
""" | """ | ||||
Data set loader for tokenization data sets | |||||
""" | |||||
def __init__(self): | def __init__(self): | ||||
super(TokenizeDataSetLoader, self).__init__() | |||||
super(DummyCWSReader, self).__init__() | |||||
def load(self, data_path, max_seq_len=32): | def load(self, data_path, max_seq_len=32): | ||||
"""Load pku dataset for Chinese word segmentation. | """Load pku dataset for Chinese word segmentation. | ||||
@@ -253,11 +254,11 @@ class TokenizeDataSetLoader(DataSetLoader): | |||||
return convert_seq2seq_dataset(data) | return convert_seq2seq_dataset(data) | ||||
class ClassDataSetLoader(DataSetLoader): | |||||
class DummyClassificationReader(DataSetLoader): | |||||
"""Loader for a dummy classification data set""" | """Loader for a dummy classification data set""" | ||||
def __init__(self): | def __init__(self): | ||||
super(ClassDataSetLoader, self).__init__() | |||||
super(DummyClassificationReader, self).__init__() | |||||
def load(self, data_path): | def load(self, data_path): | ||||
assert os.path.exists(data_path) | assert os.path.exists(data_path) | ||||
@@ -268,7 +269,7 @@ class ClassDataSetLoader(DataSetLoader): | |||||
@staticmethod | @staticmethod | ||||
def parse(lines): | def parse(lines): | ||||
""" | |||||
"""每行第一个token是标签,其余是字/词;由空格分隔。 | |||||
:param lines: lines from dataset | :param lines: lines from dataset | ||||
:return: list(list(list())): the three level of lists are words, sentence, and dataset | :return: list(list(list())): the three level of lists are words, sentence, and dataset | ||||
@@ -324,16 +325,11 @@ class ConllLoader(DataSetLoader): | |||||
pass | pass | ||||
class LMDataSetLoader(DataSetLoader): | |||||
"""Language Model Dataset Loader | |||||
This loader produces data for language model training in a supervised way. | |||||
That means it has X and Y. | |||||
class DummyLMReader(DataSetLoader): | |||||
"""A Dummy Language Model Dataset Reader | |||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
super(LMDataSetLoader, self).__init__() | |||||
super(DummyLMReader, self).__init__() | |||||
def load(self, data_path): | def load(self, data_path): | ||||
if not os.path.exists(data_path): | if not os.path.exists(data_path): | ||||
@@ -361,19 +357,25 @@ class LMDataSetLoader(DataSetLoader): | |||||
class PeopleDailyCorpusLoader(DataSetLoader): | class PeopleDailyCorpusLoader(DataSetLoader): | ||||
"""人民日报数据集 | |||||
""" | """ | ||||
People Daily Corpus: Chinese word segmentation, POS tag, NER | |||||
""" | |||||
def __init__(self): | def __init__(self): | ||||
super(PeopleDailyCorpusLoader, self).__init__() | super(PeopleDailyCorpusLoader, self).__init__() | ||||
self.pos = True | |||||
self.ner = True | |||||
def load(self, data_path): | |||||
def load(self, data_path, pos=True, ner=True): | |||||
""" | |||||
:param str data_path: 数据路径 | |||||
:param bool pos: 是否使用词性标签 | |||||
:param bool ner: 是否使用命名实体标签 | |||||
:return: a DataSet object | |||||
""" | |||||
self.pos, self.ner = pos, ner | |||||
with open(data_path, "r", encoding="utf-8") as f: | with open(data_path, "r", encoding="utf-8") as f: | ||||
sents = f.readlines() | sents = f.readlines() | ||||
pos_tag_examples = [] | |||||
ner_examples = [] | |||||
examples = [] | |||||
for sent in sents: | for sent in sents: | ||||
if len(sent) <= 2: | if len(sent) <= 2: | ||||
continue | continue | ||||
@@ -407,40 +409,44 @@ class PeopleDailyCorpusLoader(DataSetLoader): | |||||
sent_ner.append(ner_tag) | sent_ner.append(ner_tag) | ||||
sent_pos_tag.append(pos) | sent_pos_tag.append(pos) | ||||
sent_words.append(token) | sent_words.append(token) | ||||
pos_tag_examples.append([sent_words, sent_pos_tag]) | |||||
ner_examples.append([sent_words, sent_ner]) | |||||
# List[List[List[str], List[str]]] | |||||
# ner_examples not used | |||||
return self.convert(pos_tag_examples) | |||||
example = [sent_words] | |||||
if self.pos is True: | |||||
example.append(sent_pos_tag) | |||||
if self.ner is True: | |||||
example.append(sent_ner) | |||||
examples.append(example) | |||||
return self.convert(examples) | |||||
def convert(self, data): | def convert(self, data): | ||||
data_set = DataSet() | data_set = DataSet() | ||||
for item in data: | for item in data: | ||||
sent_words, sent_pos_tag = item[0], item[1] | |||||
data_set.append(Instance(words=sent_words, tags=sent_pos_tag)) | |||||
data_set.apply(lambda ins: len(ins), new_field_name="seq_len") | |||||
data_set.set_target("tags") | |||||
data_set.set_input("sent_words") | |||||
data_set.set_input("seq_len") | |||||
sent_words = item[0] | |||||
if self.pos is True and self.ner is True: | |||||
instance = Instance(words=sent_words, pos_tags=item[1], ner=item[2]) | |||||
elif self.pos is True: | |||||
instance = Instance(words=sent_words, pos_tags=item[1]) | |||||
elif self.ner is True: | |||||
instance = Instance(words=sent_words, ner=item[1]) | |||||
else: | |||||
instance = Instance(words=sent_words) | |||||
data_set.append(instance) | |||||
data_set.apply(lambda ins: len(ins["words"]), new_field_name="seq_len") | |||||
return data_set | return data_set | ||||
class Conll2003Loader(DataSetLoader): | class Conll2003Loader(DataSetLoader): | ||||
"""Self-defined loader of conll2003 dataset | |||||
"""Loader for conll2003 dataset | |||||
More information about the given dataset cound be found on | More information about the given dataset cound be found on | ||||
https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data | https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data | ||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
super(Conll2003Loader, self).__init__() | super(Conll2003Loader, self).__init__() | ||||
def load(self, dataset_path): | def load(self, dataset_path): | ||||
with open(dataset_path, "r", encoding="utf-8") as f: | with open(dataset_path, "r", encoding="utf-8") as f: | ||||
lines = f.readlines() | lines = f.readlines() | ||||
##Parse the dataset line by line | |||||
parsed_data = [] | parsed_data = [] | ||||
sentence = [] | sentence = [] | ||||
tokens = [] | tokens = [] | ||||
@@ -467,21 +473,20 @@ class Conll2003Loader(DataSetLoader): | |||||
lambda labels: labels[1], sample[1])) | lambda labels: labels[1], sample[1])) | ||||
label2_list = list(map( | label2_list = list(map( | ||||
lambda labels: labels[2], sample[1])) | lambda labels: labels[2], sample[1])) | ||||
dataset.append(Instance(token_list=sample[0], | |||||
label0_list=label0_list, | |||||
label1_list=label1_list, | |||||
label2_list=label2_list)) | |||||
dataset.append(Instance(tokens=sample[0], | |||||
pos=label0_list, | |||||
chucks=label1_list, | |||||
ner=label2_list)) | |||||
return dataset | return dataset | ||||
class SNLIDataSetLoader(DataSetLoader): | |||||
class SNLIDataSetReader(DataSetLoader): | |||||
"""A data set loader for SNLI data set. | """A data set loader for SNLI data set. | ||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
super(SNLIDataSetLoader, self).__init__() | |||||
super(SNLIDataSetReader, self).__init__() | |||||
def load(self, path_list): | def load(self, path_list): | ||||
""" | """ | ||||
@@ -540,3 +545,298 @@ class SNLIDataSetLoader(DataSetLoader): | |||||
data_set.set_input("premise", "hypothesis", "premise_len", "hypothesis_len") | data_set.set_input("premise", "hypothesis", "premise_len", "hypothesis_len") | ||||
data_set.set_target("truth") | data_set.set_target("truth") | ||||
return data_set | return data_set | ||||
class ConllCWSReader(object): | |||||
def __init__(self): | |||||
pass | |||||
def load(self, path, cut_long_sent=False): | |||||
""" | |||||
返回的DataSet只包含raw_sentence这个field,内容为str。 | |||||
假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 | |||||
:: | |||||
1 编者按 编者按 NN O 11 nmod:topic | |||||
2 : : PU O 11 punct | |||||
3 7月 7月 NT DATE 4 compound:nn | |||||
4 12日 12日 NT DATE 11 nmod:tmod | |||||
5 , , PU O 11 punct | |||||
1 这 这 DT O 3 det | |||||
2 款 款 M O 1 mark:clf | |||||
3 飞行 飞行 NN O 8 nsubj | |||||
4 从 从 P O 5 case | |||||
5 外型 外型 NN O 8 nmod:prep | |||||
""" | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.strip().split()) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
ds = DataSet() | |||||
for sample in datalist: | |||||
# print(sample) | |||||
res = self.get_char_lst(sample) | |||||
if res is None: | |||||
continue | |||||
line = ' '.join(res) | |||||
if cut_long_sent: | |||||
sents = cut_long_sentence(line) | |||||
else: | |||||
sents = [line] | |||||
for raw_sentence in sents: | |||||
ds.append(Instance(raw_sentence=raw_sentence)) | |||||
return ds | |||||
def get_char_lst(self, sample): | |||||
if len(sample) == 0: | |||||
return None | |||||
text = [] | |||||
for w in sample: | |||||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | |||||
if t3 == '_': | |||||
return None | |||||
text.append(t1) | |||||
return text | |||||
class NaiveCWSReader(DataSetLoader): | |||||
""" | |||||
这个reader假设了分词数据集为以下形式, 即已经用空格分割好内容了 | |||||
例如:: | |||||
这是 fastNLP , 一个 非常 good 的 包 . | |||||
或者,即每个part后面还有一个pos tag | |||||
例如:: | |||||
也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY | |||||
""" | |||||
def __init__(self, in_word_splitter=None): | |||||
super(NaiveCWSReader, self).__init__() | |||||
self.in_word_splitter = in_word_splitter | |||||
def load(self, filepath, in_word_splitter=None, cut_long_sent=False): | |||||
""" | |||||
允许使用的情况有(默认以\t或空格作为seg) | |||||
这是 fastNLP , 一个 非常 good 的 包 . | |||||
和 | |||||
也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY | |||||
如果splitter不为None则认为是第二种情况, 且我们会按splitter分割"也/D", 然后取第一部分. 例如"也/D".split('/')[0] | |||||
:param filepath: | |||||
:param in_word_splitter: | |||||
:param cut_long_sent: | |||||
:return: | |||||
""" | |||||
if in_word_splitter == None: | |||||
in_word_splitter = self.in_word_splitter | |||||
dataset = DataSet() | |||||
with open(filepath, 'r') as f: | |||||
for line in f: | |||||
line = line.strip() | |||||
if len(line.replace(' ', '')) == 0: # 不能接受空行 | |||||
continue | |||||
if not in_word_splitter is None: | |||||
words = [] | |||||
for part in line.split(): | |||||
word = part.split(in_word_splitter)[0] | |||||
words.append(word) | |||||
line = ' '.join(words) | |||||
if cut_long_sent: | |||||
sents = cut_long_sentence(line) | |||||
else: | |||||
sents = [line] | |||||
for sent in sents: | |||||
instance = Instance(raw_sentence=sent) | |||||
dataset.append(instance) | |||||
return dataset | |||||
def cut_long_sentence(sent, max_sample_length=200): | |||||
""" | |||||
将长于max_sample_length的sentence截成多段,只会在有空格的地方发生截断。所以截取的句子可能长于或者短于max_sample_length | |||||
:param sent: str. | |||||
:param max_sample_length: int. | |||||
:return: list of str. | |||||
""" | |||||
sent_no_space = sent.replace(' ', '') | |||||
cutted_sentence = [] | |||||
if len(sent_no_space) > max_sample_length: | |||||
parts = sent.strip().split() | |||||
new_line = '' | |||||
length = 0 | |||||
for part in parts: | |||||
length += len(part) | |||||
new_line += part + ' ' | |||||
if length > max_sample_length: | |||||
new_line = new_line[:-1] | |||||
cutted_sentence.append(new_line) | |||||
length = 0 | |||||
new_line = '' | |||||
if new_line != '': | |||||
cutted_sentence.append(new_line[:-1]) | |||||
else: | |||||
cutted_sentence.append(sent) | |||||
return cutted_sentence | |||||
class ZhConllPOSReader(object): | |||||
"""读取中文Conll格式。返回“字级别”的标签,使用BMES记号扩展原来的词级别标签。 | |||||
""" | |||||
def __init__(self): | |||||
pass | |||||
def load(self, path): | |||||
""" | |||||
返回的DataSet, 包含以下的field | |||||
words:list of str, | |||||
tag: list of str, 被加入了BMES tag, 比如原来的序列为['VP', 'NN', 'NN', ..],会被认为是["S-VP", "B-NN", "M-NN",..] | |||||
假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 | |||||
:: | |||||
1 编者按 编者按 NN O 11 nmod:topic | |||||
2 : : PU O 11 punct | |||||
3 7月 7月 NT DATE 4 compound:nn | |||||
4 12日 12日 NT DATE 11 nmod:tmod | |||||
5 , , PU O 11 punct | |||||
1 这 这 DT O 3 det | |||||
2 款 款 M O 1 mark:clf | |||||
3 飞行 飞行 NN O 8 nsubj | |||||
4 从 从 P O 5 case | |||||
5 外型 外型 NN O 8 nmod:prep | |||||
""" | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
ds = DataSet() | |||||
for sample in datalist: | |||||
# print(sample) | |||||
res = self.get_one(sample) | |||||
if res is None: | |||||
continue | |||||
char_seq = [] | |||||
pos_seq = [] | |||||
for word, tag in zip(res[0], res[1]): | |||||
char_seq.extend(list(word)) | |||||
if len(word) == 1: | |||||
pos_seq.append('S-{}'.format(tag)) | |||||
elif len(word) > 1: | |||||
pos_seq.append('B-{}'.format(tag)) | |||||
for _ in range(len(word) - 2): | |||||
pos_seq.append('M-{}'.format(tag)) | |||||
pos_seq.append('E-{}'.format(tag)) | |||||
else: | |||||
raise ValueError("Zero length of word detected.") | |||||
ds.append(Instance(words=char_seq, | |||||
tag=pos_seq)) | |||||
return ds | |||||
def get_one(self, sample): | |||||
if len(sample) == 0: | |||||
return None | |||||
text = [] | |||||
pos_tags = [] | |||||
for w in sample: | |||||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | |||||
if t3 == '_': | |||||
return None | |||||
text.append(t1) | |||||
pos_tags.append(t2) | |||||
return text, pos_tags | |||||
class ConllxDataLoader(object): | |||||
"""返回“词级别”的标签信息,包括词、词性、(句法)头依赖、(句法)边标签。跟``ZhConllPOSReader``完全不同。 | |||||
""" | |||||
def load(self, path): | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
data = [self.get_one(sample) for sample in datalist] | |||||
data_list = list(filter(lambda x: x is not None, data)) | |||||
ds = DataSet() | |||||
for example in data_list: | |||||
ds.append(Instance(words=example[0], | |||||
pos_tags=example[1], | |||||
heads=example[2], | |||||
labels=example[3])) | |||||
return ds | |||||
def get_one(self, sample): | |||||
sample = list(map(list, zip(*sample))) | |||||
if len(sample) == 0: | |||||
return None | |||||
for w in sample[7]: | |||||
if w == '_': | |||||
print('Error Sample {}'.format(sample)) | |||||
return None | |||||
# return word_seq, pos_seq, head_seq, head_tag_seq | |||||
return sample[1], sample[3], list(map(int, sample[6])), sample[7] | |||||
def add_seg_tag(data): | |||||
""" | |||||
:param data: list of ([word], [pos], [heads], [head_tags]) | |||||
:return: list of ([word], [pos]) | |||||
""" | |||||
_processed = [] | |||||
for word_list, pos_list, _, _ in data: | |||||
new_sample = [] | |||||
for word, pos in zip(word_list, pos_list): | |||||
if len(word) == 1: | |||||
new_sample.append((word, 'S-' + pos)) | |||||
else: | |||||
new_sample.append((word[0], 'B-' + pos)) | |||||
for c in word[1:-1]: | |||||
new_sample.append((c, 'M-' + pos)) | |||||
new_sample.append((word[-1], 'E-' + pos)) | |||||
_processed.append(list(map(list, zip(*new_sample)))) | |||||
return _processed |
@@ -101,9 +101,12 @@ class EmbedLoader(BaseLoader): | |||||
""" | """ | ||||
if vocab is None: | if vocab is None: | ||||
raise RuntimeError("You must provide a vocabulary.") | raise RuntimeError("You must provide a vocabulary.") | ||||
embedding_matrix = np.zeros(shape=(len(vocab), emb_dim)) | |||||
embedding_matrix = np.zeros(shape=(len(vocab), emb_dim), dtype=np.float32) | |||||
hit_flags = np.zeros(shape=(len(vocab),), dtype=int) | hit_flags = np.zeros(shape=(len(vocab),), dtype=int) | ||||
with open(emb_file, "r", encoding="utf-8") as f: | with open(emb_file, "r", encoding="utf-8") as f: | ||||
startline = f.readline() | |||||
if len(startline.split()) > 2: | |||||
f.seek(0) | |||||
for line in f: | for line in f: | ||||
word, vector = EmbedLoader.parse_glove_line(line) | word, vector = EmbedLoader.parse_glove_line(line) | ||||
if word in vocab: | if word in vocab: | ||||
@@ -0,0 +1,362 @@ | |||||
""" | |||||
bert.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0. | |||||
""" | |||||
import copy | |||||
import json | |||||
import math | |||||
import os | |||||
import torch | |||||
from torch import nn | |||||
CONFIG_FILE = 'bert_config.json' | |||||
MODEL_WEIGHTS = 'pytorch_model.bin' | |||||
def gelu(x): | |||||
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) | |||||
def swish(x): | |||||
return x * torch.sigmoid(x) | |||||
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} | |||||
class BertLayerNorm(nn.Module): | |||||
def __init__(self, hidden_size, eps=1e-12): | |||||
super(BertLayerNorm, self).__init__() | |||||
self.weight = nn.Parameter(torch.ones(hidden_size)) | |||||
self.bias = nn.Parameter(torch.zeros(hidden_size)) | |||||
self.variance_epsilon = eps | |||||
def forward(self, x): | |||||
u = x.mean(-1, keepdim=True) | |||||
s = (x - u).pow(2).mean(-1, keepdim=True) | |||||
x = (x - u) / torch.sqrt(s + self.variance_epsilon) | |||||
return self.weight * x + self.bias | |||||
class BertEmbeddings(nn.Module): | |||||
def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, hidden_dropout_prob): | |||||
super(BertEmbeddings, self).__init__() | |||||
self.word_embeddings = nn.Embedding(vocab_size, hidden_size) | |||||
self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) | |||||
self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) | |||||
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load | |||||
# any TensorFlow checkpoint file | |||||
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) | |||||
self.dropout = nn.Dropout(hidden_dropout_prob) | |||||
def forward(self, input_ids, token_type_ids=None): | |||||
seq_length = input_ids.size(1) | |||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) | |||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids) | |||||
if token_type_ids is None: | |||||
token_type_ids = torch.zeros_like(input_ids) | |||||
words_embeddings = self.word_embeddings(input_ids) | |||||
position_embeddings = self.position_embeddings(position_ids) | |||||
token_type_embeddings = self.token_type_embeddings(token_type_ids) | |||||
embeddings = words_embeddings + position_embeddings + token_type_embeddings | |||||
embeddings = self.LayerNorm(embeddings) | |||||
embeddings = self.dropout(embeddings) | |||||
return embeddings | |||||
class BertSelfAttention(nn.Module): | |||||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob): | |||||
super(BertSelfAttention, self).__init__() | |||||
if hidden_size % num_attention_heads != 0: | |||||
raise ValueError( | |||||
"The hidden size (%d) is not a multiple of the number of attention " | |||||
"heads (%d)" % (hidden_size, num_attention_heads)) | |||||
self.num_attention_heads = num_attention_heads | |||||
self.attention_head_size = int(hidden_size / num_attention_heads) | |||||
self.all_head_size = self.num_attention_heads * self.attention_head_size | |||||
self.query = nn.Linear(hidden_size, self.all_head_size) | |||||
self.key = nn.Linear(hidden_size, self.all_head_size) | |||||
self.value = nn.Linear(hidden_size, self.all_head_size) | |||||
self.dropout = nn.Dropout(attention_probs_dropout_prob) | |||||
def transpose_for_scores(self, x): | |||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) | |||||
x = x.view(*new_x_shape) | |||||
return x.permute(0, 2, 1, 3) | |||||
def forward(self, hidden_states, attention_mask): | |||||
mixed_query_layer = self.query(hidden_states) | |||||
mixed_key_layer = self.key(hidden_states) | |||||
mixed_value_layer = self.value(hidden_states) | |||||
query_layer = self.transpose_for_scores(mixed_query_layer) | |||||
key_layer = self.transpose_for_scores(mixed_key_layer) | |||||
value_layer = self.transpose_for_scores(mixed_value_layer) | |||||
# Take the dot product between "query" and "key" to get the raw attention scores. | |||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) | |||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size) | |||||
# Apply the attention mask is (precomputed for all layers in BertModel forward() function) | |||||
attention_scores = attention_scores + attention_mask | |||||
# Normalize the attention scores to probabilities. | |||||
attention_probs = nn.Softmax(dim=-1)(attention_scores) | |||||
# This is actually dropping out entire tokens to attend to, which might | |||||
# seem a bit unusual, but is taken from the original Transformer paper. | |||||
attention_probs = self.dropout(attention_probs) | |||||
context_layer = torch.matmul(attention_probs, value_layer) | |||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous() | |||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) | |||||
context_layer = context_layer.view(*new_context_layer_shape) | |||||
return context_layer | |||||
class BertSelfOutput(nn.Module): | |||||
def __init__(self, hidden_size, hidden_dropout_prob): | |||||
super(BertSelfOutput, self).__init__() | |||||
self.dense = nn.Linear(hidden_size, hidden_size) | |||||
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) | |||||
self.dropout = nn.Dropout(hidden_dropout_prob) | |||||
def forward(self, hidden_states, input_tensor): | |||||
hidden_states = self.dense(hidden_states) | |||||
hidden_states = self.dropout(hidden_states) | |||||
hidden_states = self.LayerNorm(hidden_states + input_tensor) | |||||
return hidden_states | |||||
class BertAttention(nn.Module): | |||||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob): | |||||
super(BertAttention, self).__init__() | |||||
self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob) | |||||
self.output = BertSelfOutput(hidden_size, hidden_dropout_prob) | |||||
def forward(self, input_tensor, attention_mask): | |||||
self_output = self.self(input_tensor, attention_mask) | |||||
attention_output = self.output(self_output, input_tensor) | |||||
return attention_output | |||||
class BertIntermediate(nn.Module): | |||||
def __init__(self, hidden_size, intermediate_size, hidden_act): | |||||
super(BertIntermediate, self).__init__() | |||||
self.dense = nn.Linear(hidden_size, intermediate_size) | |||||
self.intermediate_act_fn = ACT2FN[hidden_act] \ | |||||
if isinstance(hidden_act, str) else hidden_act | |||||
def forward(self, hidden_states): | |||||
hidden_states = self.dense(hidden_states) | |||||
hidden_states = self.intermediate_act_fn(hidden_states) | |||||
return hidden_states | |||||
class BertOutput(nn.Module): | |||||
def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob): | |||||
super(BertOutput, self).__init__() | |||||
self.dense = nn.Linear(intermediate_size, hidden_size) | |||||
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) | |||||
self.dropout = nn.Dropout(hidden_dropout_prob) | |||||
def forward(self, hidden_states, input_tensor): | |||||
hidden_states = self.dense(hidden_states) | |||||
hidden_states = self.dropout(hidden_states) | |||||
hidden_states = self.LayerNorm(hidden_states + input_tensor) | |||||
return hidden_states | |||||
class BertLayer(nn.Module): | |||||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, | |||||
intermediate_size, hidden_act): | |||||
super(BertLayer, self).__init__() | |||||
self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob, | |||||
hidden_dropout_prob) | |||||
self.intermediate = BertIntermediate(hidden_size, intermediate_size, hidden_act) | |||||
self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob) | |||||
def forward(self, hidden_states, attention_mask): | |||||
attention_output = self.attention(hidden_states, attention_mask) | |||||
intermediate_output = self.intermediate(attention_output) | |||||
layer_output = self.output(intermediate_output, attention_output) | |||||
return layer_output | |||||
class BertEncoder(nn.Module): | |||||
def __init__(self, num_hidden_layers, hidden_size, num_attention_heads, attention_probs_dropout_prob, | |||||
hidden_dropout_prob, | |||||
intermediate_size, hidden_act): | |||||
super(BertEncoder, self).__init__() | |||||
layer = BertLayer(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, | |||||
intermediate_size, hidden_act) | |||||
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_hidden_layers)]) | |||||
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): | |||||
all_encoder_layers = [] | |||||
for layer_module in self.layer: | |||||
hidden_states = layer_module(hidden_states, attention_mask) | |||||
if output_all_encoded_layers: | |||||
all_encoder_layers.append(hidden_states) | |||||
if not output_all_encoded_layers: | |||||
all_encoder_layers.append(hidden_states) | |||||
return all_encoder_layers | |||||
class BertPooler(nn.Module): | |||||
def __init__(self, hidden_size): | |||||
super(BertPooler, self).__init__() | |||||
self.dense = nn.Linear(hidden_size, hidden_size) | |||||
self.activation = nn.Tanh() | |||||
def forward(self, hidden_states): | |||||
# We "pool" the model by simply taking the hidden state corresponding | |||||
# to the first token. | |||||
first_token_tensor = hidden_states[:, 0] | |||||
pooled_output = self.dense(first_token_tensor) | |||||
pooled_output = self.activation(pooled_output) | |||||
return pooled_output | |||||
class BertModel(nn.Module): | |||||
"""Bidirectional Embedding Representations from Transformers. | |||||
If you want to use pre-trained weights, please download from the following sources provided by pytorch-pretrained-BERT. | |||||
sources:: | |||||
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", | |||||
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", | |||||
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", | |||||
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", | |||||
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", | |||||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", | |||||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", | |||||
Construct a BERT model with pre-trained weights:: | |||||
model = BertModel.from_pretrained("path/to/weights/directory") | |||||
""" | |||||
def __init__(self, vocab_size, | |||||
hidden_size=768, | |||||
num_hidden_layers=12, | |||||
num_attention_heads=12, | |||||
intermediate_size=3072, | |||||
hidden_act="gelu", | |||||
hidden_dropout_prob=0.1, | |||||
attention_probs_dropout_prob=0.1, | |||||
max_position_embeddings=512, | |||||
type_vocab_size=2, | |||||
initializer_range=0.02, **kwargs): | |||||
super(BertModel, self).__init__() | |||||
self.embeddings = BertEmbeddings(vocab_size, hidden_size, max_position_embeddings, | |||||
type_vocab_size, hidden_dropout_prob) | |||||
self.encoder = BertEncoder(num_hidden_layers, hidden_size, num_attention_heads, | |||||
attention_probs_dropout_prob, hidden_dropout_prob, intermediate_size, | |||||
hidden_act) | |||||
self.pooler = BertPooler(hidden_size) | |||||
self.initializer_range = initializer_range | |||||
self.apply(self.init_bert_weights) | |||||
def init_bert_weights(self, module): | |||||
if isinstance(module, (nn.Linear, nn.Embedding)): | |||||
# Slightly different from the TF version which uses truncated_normal for initialization | |||||
# cf https://github.com/pytorch/pytorch/pull/5617 | |||||
module.weight.data.normal_(mean=0.0, std=self.initializer_range) | |||||
elif isinstance(module, BertLayerNorm): | |||||
module.bias.data.zero_() | |||||
module.weight.data.fill_(1.0) | |||||
if isinstance(module, nn.Linear) and module.bias is not None: | |||||
module.bias.data.zero_() | |||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True): | |||||
if attention_mask is None: | |||||
attention_mask = torch.ones_like(input_ids) | |||||
if token_type_ids is None: | |||||
token_type_ids = torch.zeros_like(input_ids) | |||||
# We create a 3D attention mask from a 2D tensor mask. | |||||
# Sizes are [batch_size, 1, 1, to_seq_length] | |||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] | |||||
# this attention mask is more simple than the triangular masking of causal attention | |||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here. | |||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) | |||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for | |||||
# masked positions, this operation will create a tensor which is 0.0 for | |||||
# positions we want to attend and -10000.0 for masked positions. | |||||
# Since we are adding it to the raw scores before the softmax, this is | |||||
# effectively the same as removing these entirely. | |||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility | |||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 | |||||
embedding_output = self.embeddings(input_ids, token_type_ids) | |||||
encoded_layers = self.encoder(embedding_output, | |||||
extended_attention_mask, | |||||
output_all_encoded_layers=output_all_encoded_layers) | |||||
sequence_output = encoded_layers[-1] | |||||
pooled_output = self.pooler(sequence_output) | |||||
if not output_all_encoded_layers: | |||||
encoded_layers = encoded_layers[-1] | |||||
return encoded_layers, pooled_output | |||||
@classmethod | |||||
def from_pretrained(cls, pretrained_model_dir, state_dict=None, *inputs, **kwargs): | |||||
# Load config | |||||
config_file = os.path.join(pretrained_model_dir, CONFIG_FILE) | |||||
config = json.load(open(config_file, "r")) | |||||
# config = BertConfig.from_json_file(config_file) | |||||
# logger.info("Model config {}".format(config)) | |||||
# Instantiate model. | |||||
model = cls(*inputs, **config, **kwargs) | |||||
if state_dict is None: | |||||
weights_path = os.path.join(pretrained_model_dir, MODEL_WEIGHTS) | |||||
state_dict = torch.load(weights_path) | |||||
old_keys = [] | |||||
new_keys = [] | |||||
for key in state_dict.keys(): | |||||
new_key = None | |||||
if 'gamma' in key: | |||||
new_key = key.replace('gamma', 'weight') | |||||
if 'beta' in key: | |||||
new_key = key.replace('beta', 'bias') | |||||
if new_key: | |||||
old_keys.append(key) | |||||
new_keys.append(new_key) | |||||
for old_key, new_key in zip(old_keys, new_keys): | |||||
state_dict[new_key] = state_dict.pop(old_key) | |||||
missing_keys = [] | |||||
unexpected_keys = [] | |||||
error_msgs = [] | |||||
# copy state_dict so _load_from_state_dict can modify it | |||||
metadata = getattr(state_dict, '_metadata', None) | |||||
state_dict = state_dict.copy() | |||||
if metadata is not None: | |||||
state_dict._metadata = metadata | |||||
def load(module, prefix=''): | |||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) | |||||
module._load_from_state_dict( | |||||
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) | |||||
for name, child in module._modules.items(): | |||||
if child is not None: | |||||
load(child, prefix + name + '.') | |||||
load(model, prefix='' if hasattr(model, 'bert') else 'bert.') | |||||
if len(missing_keys) > 0: | |||||
print("Weights of {} not initialized from pretrained model: {}".format( | |||||
model.__class__.__name__, missing_keys)) | |||||
if len(unexpected_keys) > 0: | |||||
print("Weights from pretrained model not used in {}: {}".format( | |||||
model.__class__.__name__, unexpected_keys)) | |||||
return model |
@@ -1,17 +1,20 @@ | |||||
import copy | |||||
from collections import defaultdict | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
from collections import defaultdict | |||||
from torch import nn | from torch import nn | ||||
from torch.nn import functional as F | from torch.nn import functional as F | ||||
from fastNLP.modules.utils import initial_parameter | |||||
from fastNLP.modules.encoder.variational_rnn import VarLSTM | |||||
from fastNLP.modules.dropout import TimestepDropout | |||||
from fastNLP.models.base_model import BaseModel | |||||
from fastNLP.modules.utils import seq_mask | |||||
from fastNLP.core.losses import LossFunc | from fastNLP.core.losses import LossFunc | ||||
from fastNLP.core.metrics import MetricBase | from fastNLP.core.metrics import MetricBase | ||||
from fastNLP.core.utils import seq_lens_to_masks | from fastNLP.core.utils import seq_lens_to_masks | ||||
from fastNLP.models.base_model import BaseModel | |||||
from fastNLP.modules.dropout import TimestepDropout | |||||
from fastNLP.modules.encoder.transformer import TransformerEncoder | |||||
from fastNLP.modules.encoder.variational_rnn import VarLSTM | |||||
from fastNLP.modules.utils import initial_parameter | |||||
from fastNLP.modules.utils import seq_mask | |||||
def mst(scores): | def mst(scores): | ||||
""" | """ | ||||
@@ -197,53 +200,64 @@ class BiaffineParser(GraphParser): | |||||
pos_vocab_size, | pos_vocab_size, | ||||
pos_emb_dim, | pos_emb_dim, | ||||
num_label, | num_label, | ||||
word_hid_dim=100, | |||||
pos_hid_dim=100, | |||||
rnn_layers=1, | rnn_layers=1, | ||||
rnn_hidden_size=200, | rnn_hidden_size=200, | ||||
arc_mlp_size=100, | arc_mlp_size=100, | ||||
label_mlp_size=100, | label_mlp_size=100, | ||||
dropout=0.3, | dropout=0.3, | ||||
use_var_lstm=False, | |||||
encoder='lstm', | |||||
use_greedy_infer=False): | use_greedy_infer=False): | ||||
super(BiaffineParser, self).__init__() | super(BiaffineParser, self).__init__() | ||||
rnn_out_size = 2 * rnn_hidden_size | rnn_out_size = 2 * rnn_hidden_size | ||||
word_hid_dim = pos_hid_dim = rnn_hidden_size | |||||
self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim) | self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim) | ||||
self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim) | self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim) | ||||
self.word_fc = nn.Linear(word_emb_dim, word_hid_dim) | self.word_fc = nn.Linear(word_emb_dim, word_hid_dim) | ||||
self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim) | self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim) | ||||
self.word_norm = nn.LayerNorm(word_hid_dim) | self.word_norm = nn.LayerNorm(word_hid_dim) | ||||
self.pos_norm = nn.LayerNorm(pos_hid_dim) | self.pos_norm = nn.LayerNorm(pos_hid_dim) | ||||
self.use_var_lstm = use_var_lstm | |||||
if use_var_lstm: | |||||
self.lstm = VarLSTM(input_size=word_hid_dim + pos_hid_dim, | |||||
hidden_size=rnn_hidden_size, | |||||
num_layers=rnn_layers, | |||||
bias=True, | |||||
batch_first=True, | |||||
input_dropout=dropout, | |||||
hidden_dropout=dropout, | |||||
bidirectional=True) | |||||
self.encoder_name = encoder | |||||
self.max_len = 512 | |||||
if encoder == 'var-lstm': | |||||
self.encoder = VarLSTM(input_size=word_hid_dim + pos_hid_dim, | |||||
hidden_size=rnn_hidden_size, | |||||
num_layers=rnn_layers, | |||||
bias=True, | |||||
batch_first=True, | |||||
input_dropout=dropout, | |||||
hidden_dropout=dropout, | |||||
bidirectional=True) | |||||
elif encoder == 'lstm': | |||||
self.encoder = nn.LSTM(input_size=word_hid_dim + pos_hid_dim, | |||||
hidden_size=rnn_hidden_size, | |||||
num_layers=rnn_layers, | |||||
bias=True, | |||||
batch_first=True, | |||||
dropout=dropout, | |||||
bidirectional=True) | |||||
elif encoder == 'transformer': | |||||
n_head = 16 | |||||
d_k = d_v = int(rnn_out_size / n_head) | |||||
if (d_k * n_head) != rnn_out_size: | |||||
raise ValueError('unsupported rnn_out_size: {} for transformer'.format(rnn_out_size)) | |||||
self.position_emb = nn.Embedding(num_embeddings=self.max_len, | |||||
embedding_dim=rnn_out_size,) | |||||
self.encoder = TransformerEncoder(num_layers=rnn_layers, | |||||
model_size=rnn_out_size, | |||||
inner_size=1024, | |||||
key_size=d_k, | |||||
value_size=d_v, | |||||
num_head=n_head, | |||||
dropout=dropout,) | |||||
else: | else: | ||||
self.lstm = nn.LSTM(input_size=word_hid_dim + pos_hid_dim, | |||||
hidden_size=rnn_hidden_size, | |||||
num_layers=rnn_layers, | |||||
bias=True, | |||||
batch_first=True, | |||||
dropout=dropout, | |||||
bidirectional=True) | |||||
self.arc_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size), | |||||
nn.LayerNorm(arc_mlp_size), | |||||
raise ValueError('unsupported encoder type: {}'.format(encoder)) | |||||
self.mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size * 2 + label_mlp_size * 2), | |||||
nn.ELU(), | nn.ELU(), | ||||
TimestepDropout(p=dropout),) | TimestepDropout(p=dropout),) | ||||
self.arc_dep_mlp = copy.deepcopy(self.arc_head_mlp) | |||||
self.label_head_mlp = nn.Sequential(nn.Linear(rnn_out_size, label_mlp_size), | |||||
nn.LayerNorm(label_mlp_size), | |||||
nn.ELU(), | |||||
TimestepDropout(p=dropout),) | |||||
self.label_dep_mlp = copy.deepcopy(self.label_head_mlp) | |||||
self.arc_mlp_size = arc_mlp_size | |||||
self.label_mlp_size = label_mlp_size | |||||
self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) | self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) | ||||
self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) | self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) | ||||
self.use_greedy_infer = use_greedy_infer | self.use_greedy_infer = use_greedy_infer | ||||
@@ -286,24 +300,27 @@ class BiaffineParser(GraphParser): | |||||
word, pos = self.word_fc(word), self.pos_fc(pos) | word, pos = self.word_fc(word), self.pos_fc(pos) | ||||
word, pos = self.word_norm(word), self.pos_norm(pos) | word, pos = self.word_norm(word), self.pos_norm(pos) | ||||
x = torch.cat([word, pos], dim=2) # -> [N,L,C] | x = torch.cat([word, pos], dim=2) # -> [N,L,C] | ||||
del word, pos | |||||
# lstm, extract features | |||||
sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True) | |||||
x = x[sort_idx] | |||||
x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True) | |||||
feat, _ = self.lstm(x) # -> [N,L,C] | |||||
feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True) | |||||
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) | |||||
feat = feat[unsort_idx] | |||||
# encoder, extract features | |||||
if self.encoder_name.endswith('lstm'): | |||||
sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True) | |||||
x = x[sort_idx] | |||||
x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True) | |||||
feat, _ = self.encoder(x) # -> [N,L,C] | |||||
feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True) | |||||
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) | |||||
feat = feat[unsort_idx] | |||||
else: | |||||
seq_range = torch.arange(seq_len, dtype=torch.long, device=x.device)[None,:] | |||||
x = x + self.position_emb(seq_range) | |||||
feat = self.encoder(x, mask.float()) | |||||
# for arc biaffine | # for arc biaffine | ||||
# mlp, reduce dim | # mlp, reduce dim | ||||
arc_dep = self.arc_dep_mlp(feat) | |||||
arc_head = self.arc_head_mlp(feat) | |||||
label_dep = self.label_dep_mlp(feat) | |||||
label_head = self.label_head_mlp(feat) | |||||
del feat | |||||
feat = self.mlp(feat) | |||||
arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size | |||||
arc_dep, arc_head = feat[:,:,:arc_sz], feat[:,:,arc_sz:2*arc_sz] | |||||
label_dep, label_head = feat[:,:,2*arc_sz:2*arc_sz+label_sz], feat[:,:,2*arc_sz+label_sz:] | |||||
# biaffine arc classifier | # biaffine arc classifier | ||||
arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] | arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] | ||||
@@ -349,7 +366,7 @@ class BiaffineParser(GraphParser): | |||||
batch_size, seq_len, _ = arc_pred.shape | batch_size, seq_len, _ = arc_pred.shape | ||||
flip_mask = (mask == 0) | flip_mask = (mask == 0) | ||||
_arc_pred = arc_pred.clone() | _arc_pred = arc_pred.clone() | ||||
_arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf) | |||||
_arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf')) | |||||
arc_logits = F.log_softmax(_arc_pred, dim=2) | arc_logits = F.log_softmax(_arc_pred, dim=2) | ||||
label_logits = F.log_softmax(label_pred, dim=2) | label_logits = F.log_softmax(label_pred, dim=2) | ||||
batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1) | batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1) | ||||
@@ -357,12 +374,11 @@ class BiaffineParser(GraphParser): | |||||
arc_loss = arc_logits[batch_index, child_index, arc_true] | arc_loss = arc_logits[batch_index, child_index, arc_true] | ||||
label_loss = label_logits[batch_index, child_index, label_true] | label_loss = label_logits[batch_index, child_index, label_true] | ||||
arc_loss = arc_loss[:, 1:] | |||||
label_loss = label_loss[:, 1:] | |||||
float_mask = mask[:, 1:].float() | |||||
arc_nll = -(arc_loss*float_mask).mean() | |||||
label_nll = -(label_loss*float_mask).mean() | |||||
byte_mask = flip_mask.byte() | |||||
arc_loss.masked_fill_(byte_mask, 0) | |||||
label_loss.masked_fill_(byte_mask, 0) | |||||
arc_nll = -arc_loss.mean() | |||||
label_nll = -label_loss.mean() | |||||
return arc_nll + label_nll | return arc_nll + label_nll | ||||
def predict(self, word_seq, pos_seq, seq_lens): | def predict(self, word_seq, pos_seq, seq_lens): | ||||
@@ -4,6 +4,7 @@ import torch | |||||
import torch.nn.functional as F | import torch.nn.functional as F | ||||
from torch import nn | from torch import nn | ||||
from fastNLP.modules.dropout import TimestepDropout | |||||
from fastNLP.modules.utils import mask_softmax | from fastNLP.modules.utils import mask_softmax | ||||
@@ -23,46 +24,89 @@ class Attention(torch.nn.Module): | |||||
class DotAtte(nn.Module): | class DotAtte(nn.Module): | ||||
def __init__(self, key_size, value_size): | |||||
def __init__(self, key_size, value_size, dropout=0.1): | |||||
super(DotAtte, self).__init__() | super(DotAtte, self).__init__() | ||||
self.key_size = key_size | self.key_size = key_size | ||||
self.value_size = value_size | self.value_size = value_size | ||||
self.scale = math.sqrt(key_size) | self.scale = math.sqrt(key_size) | ||||
self.drop = nn.Dropout(dropout) | |||||
self.softmax = nn.Softmax(dim=2) | |||||
def forward(self, Q, K, V, seq_mask=None): | |||||
def forward(self, Q, K, V, mask_out=None): | |||||
""" | """ | ||||
:param Q: [batch, seq_len, key_size] | :param Q: [batch, seq_len, key_size] | ||||
:param K: [batch, seq_len, key_size] | :param K: [batch, seq_len, key_size] | ||||
:param V: [batch, seq_len, value_size] | :param V: [batch, seq_len, value_size] | ||||
:param seq_mask: [batch, seq_len] | |||||
:param mask_out: [batch, seq_len] | |||||
""" | """ | ||||
output = torch.matmul(Q, K.transpose(1, 2)) / self.scale | output = torch.matmul(Q, K.transpose(1, 2)) / self.scale | ||||
if seq_mask is not None: | |||||
output.masked_fill_(seq_mask.lt(1), -float('inf')) | |||||
output = nn.functional.softmax(output, dim=2) | |||||
if mask_out is not None: | |||||
output.masked_fill_(mask_out, -float('inf')) | |||||
output = self.softmax(output) | |||||
output = self.drop(output) | |||||
return torch.matmul(output, V) | return torch.matmul(output, V) | ||||
class MultiHeadAtte(nn.Module): | class MultiHeadAtte(nn.Module): | ||||
def __init__(self, input_size, output_size, key_size, value_size, num_atte): | |||||
def __init__(self, input_size, key_size, value_size, num_head, dropout=0.1): | |||||
""" | |||||
:param input_size: int, 输入维度的大小。同时也是输出维度的大小。 | |||||
:param key_size: int, 每个head的维度大小。 | |||||
:param value_size: int,每个head中value的维度。 | |||||
:param num_head: int,head的数量。 | |||||
:param dropout: float。 | |||||
""" | |||||
super(MultiHeadAtte, self).__init__() | super(MultiHeadAtte, self).__init__() | ||||
self.in_linear = nn.ModuleList() | |||||
for i in range(num_atte * 3): | |||||
out_feat = key_size if (i % 3) != 2 else value_size | |||||
self.in_linear.append(nn.Linear(input_size, out_feat)) | |||||
self.attes = nn.ModuleList([DotAtte(key_size, value_size) for _ in range(num_atte)]) | |||||
self.out_linear = nn.Linear(value_size * num_atte, output_size) | |||||
def forward(self, Q, K, V, seq_mask=None): | |||||
heads = [] | |||||
for i in range(len(self.attes)): | |||||
j = i * 3 | |||||
qi, ki, vi = self.in_linear[j](Q), self.in_linear[j+1](K), self.in_linear[j+2](V) | |||||
headi = self.attes[i](qi, ki, vi, seq_mask) | |||||
heads.append(headi) | |||||
output = torch.cat(heads, dim=2) | |||||
return self.out_linear(output) | |||||
self.input_size = input_size | |||||
self.key_size = key_size | |||||
self.value_size = value_size | |||||
self.num_head = num_head | |||||
in_size = key_size * num_head | |||||
self.q_in = nn.Linear(input_size, in_size) | |||||
self.k_in = nn.Linear(input_size, in_size) | |||||
self.v_in = nn.Linear(input_size, in_size) | |||||
self.attention = DotAtte(key_size=key_size, value_size=value_size) | |||||
self.out = nn.Linear(value_size * num_head, input_size) | |||||
self.drop = TimestepDropout(dropout) | |||||
self.reset_parameters() | |||||
def reset_parameters(self): | |||||
sqrt = math.sqrt | |||||
nn.init.normal_(self.q_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.key_size))) | |||||
nn.init.normal_(self.k_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.key_size))) | |||||
nn.init.normal_(self.v_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.value_size))) | |||||
nn.init.xavier_normal_(self.out.weight) | |||||
def forward(self, Q, K, V, atte_mask_out=None): | |||||
""" | |||||
:param Q: [batch, seq_len, model_size] | |||||
:param K: [batch, seq_len, model_size] | |||||
:param V: [batch, seq_len, model_size] | |||||
:param seq_mask: [batch, seq_len] | |||||
""" | |||||
batch, seq_len, _ = Q.size() | |||||
d_k, d_v, n_head = self.key_size, self.value_size, self.num_head | |||||
# input linear | |||||
q = self.q_in(Q).view(batch, seq_len, n_head, d_k) | |||||
k = self.k_in(K).view(batch, seq_len, n_head, d_k) | |||||
v = self.v_in(V).view(batch, seq_len, n_head, d_k) | |||||
# transpose q, k and v to do batch attention | |||||
q = q.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_k) | |||||
k = k.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_k) | |||||
v = v.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_v) | |||||
if atte_mask_out is not None: | |||||
atte_mask_out = atte_mask_out.repeat(n_head, 1, 1) | |||||
atte = self.attention(q, k, v, atte_mask_out).view(n_head, batch, seq_len, d_v) | |||||
# concat all heads, do output linear | |||||
atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, seq_len, -1) | |||||
output = self.drop(self.out(atte)) | |||||
return output | |||||
class Bi_Attention(nn.Module): | class Bi_Attention(nn.Module): | ||||
@@ -1,29 +1,56 @@ | |||||
from torch import nn | from torch import nn | ||||
from ..aggregator.attention import MultiHeadAtte | from ..aggregator.attention import MultiHeadAtte | ||||
from ..other_modules import LayerNormalization | |||||
from ..dropout import TimestepDropout | |||||
class TransformerEncoder(nn.Module): | class TransformerEncoder(nn.Module): | ||||
class SubLayer(nn.Module): | class SubLayer(nn.Module): | ||||
def __init__(self, input_size, output_size, key_size, value_size, num_atte): | |||||
def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1): | |||||
""" | |||||
:param model_size: int, 输入维度的大小。同时也是输出维度的大小。 | |||||
:param inner_size: int, FFN层的hidden大小 | |||||
:param key_size: int, 每个head的维度大小。 | |||||
:param value_size: int,每个head中value的维度。 | |||||
:param num_head: int,head的数量。 | |||||
:param dropout: float。 | |||||
""" | |||||
super(TransformerEncoder.SubLayer, self).__init__() | super(TransformerEncoder.SubLayer, self).__init__() | ||||
self.atte = MultiHeadAtte(input_size, output_size, key_size, value_size, num_atte) | |||||
self.norm1 = LayerNormalization(output_size) | |||||
self.ffn = nn.Sequential(nn.Linear(output_size, output_size), | |||||
self.atte = MultiHeadAtte(model_size, key_size, value_size, num_head, dropout) | |||||
self.norm1 = nn.LayerNorm(model_size) | |||||
self.ffn = nn.Sequential(nn.Linear(model_size, inner_size), | |||||
nn.ReLU(), | nn.ReLU(), | ||||
nn.Linear(output_size, output_size)) | |||||
self.norm2 = LayerNormalization(output_size) | |||||
nn.Linear(inner_size, model_size), | |||||
TimestepDropout(dropout),) | |||||
self.norm2 = nn.LayerNorm(model_size) | |||||
def forward(self, input, seq_mask=None, atte_mask_out=None): | |||||
""" | |||||
def forward(self, input, seq_mask): | |||||
attention = self.atte(input) | |||||
:param input: [batch, seq_len, model_size] | |||||
:param seq_mask: [batch, seq_len] | |||||
:return: [batch, seq_len, model_size] | |||||
""" | |||||
attention = self.atte(input, input, input, atte_mask_out) | |||||
norm_atte = self.norm1(attention + input) | norm_atte = self.norm1(attention + input) | ||||
attention *= seq_mask | |||||
output = self.ffn(norm_atte) | output = self.ffn(norm_atte) | ||||
return self.norm2(output + norm_atte) | |||||
output = self.norm2(output + norm_atte) | |||||
output *= seq_mask | |||||
return output | |||||
def __init__(self, num_layers, **kargs): | def __init__(self, num_layers, **kargs): | ||||
super(TransformerEncoder, self).__init__() | super(TransformerEncoder, self).__init__() | ||||
self.layers = nn.Sequential(*[self.SubLayer(**kargs) for _ in range(num_layers)]) | |||||
self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)]) | |||||
def forward(self, x, seq_mask=None): | def forward(self, x, seq_mask=None): | ||||
return self.layers(x, seq_mask) | |||||
output = x | |||||
if seq_mask is None: | |||||
atte_mask_out = None | |||||
else: | |||||
atte_mask_out = (seq_mask < 1)[:,None,:] | |||||
seq_mask = seq_mask[:,:,None] | |||||
for layer in self.layers: | |||||
output = layer(output, seq_mask, atte_mask_out) | |||||
return output |
@@ -1,8 +1,9 @@ | |||||
[train] | [train] | ||||
n_epochs = 40 | |||||
n_epochs = 20 | |||||
batch_size = 32 | batch_size = 32 | ||||
use_cuda = true | use_cuda = true | ||||
validate_every = 500 | |||||
use_tqdm=true | |||||
validate_every = 1000 | |||||
use_golden_train=true | use_golden_train=true | ||||
[test] | [test] | ||||
@@ -16,20 +17,18 @@ use_cuda = true | |||||
[model] | [model] | ||||
word_vocab_size = -1 | word_vocab_size = -1 | ||||
word_emb_dim = 100 | |||||
word_emb_dim = 300 | |||||
pos_vocab_size = -1 | pos_vocab_size = -1 | ||||
pos_emb_dim = 100 | pos_emb_dim = 100 | ||||
word_hid_dim = 100 | |||||
pos_hid_dim = 100 | |||||
rnn_layers = 3 | rnn_layers = 3 | ||||
rnn_hidden_size = 400 | |||||
rnn_hidden_size = 256 | |||||
arc_mlp_size = 500 | arc_mlp_size = 500 | ||||
label_mlp_size = 100 | label_mlp_size = 100 | ||||
num_label = -1 | num_label = -1 | ||||
dropout = 0.33 | |||||
use_var_lstm=true | |||||
dropout = 0.3 | |||||
encoder="var-lstm" | |||||
use_greedy_infer=false | use_greedy_infer=false | ||||
[optim] | [optim] | ||||
lr = 3e-4 | |||||
lr = 2e-3 | |||||
;weight_decay = 3e-5 | ;weight_decay = 3e-5 |
@@ -5,7 +5,7 @@ sys.path.extend(['/home/yfshao/workdir/dev_fastnlp']) | |||||
import torch | import torch | ||||
import argparse | import argparse | ||||
from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag | |||||
from fastNLP.io.dataset_loader import ConllxDataLoader, add_seg_tag | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
@@ -4,25 +4,23 @@ import sys | |||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | ||||
import fastNLP | import fastNLP | ||||
import torch | |||||
from fastNLP.core.trainer import Trainer | from fastNLP.core.trainer import Trainer | ||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.api.pipeline import Pipeline | from fastNLP.api.pipeline import Pipeline | ||||
from fastNLP.models.biaffine_parser import BiaffineParser, ParserMetric, ParserLoss | from fastNLP.models.biaffine_parser import BiaffineParser, ParserMetric, ParserLoss | ||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.tester import Tester | from fastNLP.core.tester import Tester | ||||
from fastNLP.io.config_io import ConfigLoader, ConfigSection | from fastNLP.io.config_io import ConfigLoader, ConfigSection | ||||
from fastNLP.io.model_io import ModelLoader | from fastNLP.io.model_io import ModelLoader | ||||
from fastNLP.io.embed_loader import EmbedLoader | |||||
from fastNLP.io.model_io import ModelSaver | |||||
from reproduction.Biaffine_parser.util import ConllxDataLoader, MyDataloader | |||||
from fastNLP.io.dataset_loader import ConllxDataLoader | |||||
from fastNLP.api.processor import * | from fastNLP.api.processor import * | ||||
from fastNLP.io.embed_loader import EmbedLoader | |||||
from fastNLP.core.callback import Callback | |||||
BOS = '<BOS>' | BOS = '<BOS>' | ||||
EOS = '<EOS>' | EOS = '<EOS>' | ||||
UNK = '<UNK>' | UNK = '<UNK>' | ||||
PAD = '<PAD>' | |||||
NUM = '<NUM>' | NUM = '<NUM>' | ||||
ENG = '<ENG>' | ENG = '<ENG>' | ||||
@@ -33,11 +31,11 @@ if len(os.path.dirname(__file__)) != 0: | |||||
def convert(data): | def convert(data): | ||||
dataset = DataSet() | dataset = DataSet() | ||||
for sample in data: | for sample in data: | ||||
word_seq = [BOS] + sample[0] | |||||
pos_seq = [BOS] + sample[1] | |||||
heads = [0] + list(map(int, sample[2])) | |||||
head_tags = [BOS] + sample[3] | |||||
dataset.append(Instance(words=word_seq, | |||||
word_seq = [BOS] + sample['words'] | |||||
pos_seq = [BOS] + sample['pos_tags'] | |||||
heads = [0] + sample['heads'] | |||||
head_tags = [BOS] + sample['labels'] | |||||
dataset.append(Instance(raw_words=word_seq, | |||||
pos=pos_seq, | pos=pos_seq, | ||||
gold_heads=heads, | gold_heads=heads, | ||||
arc_true=heads, | arc_true=heads, | ||||
@@ -50,24 +48,11 @@ def load(path): | |||||
return convert(data) | return convert(data) | ||||
# datadir = "/mnt/c/Me/Dev/release-2.2-st-train-dev-data/ud-treebanks-v2.2/UD_English-EWT" | |||||
# datadir = "/home/yfshao/UD_English-EWT" | |||||
# train_data_name = "en_ewt-ud-train.conllu" | |||||
# dev_data_name = "en_ewt-ud-dev.conllu" | |||||
# emb_file_name = '/home/yfshao/glove.6B.100d.txt' | |||||
# loader = ConlluDataLoader() | |||||
# datadir = '/home/yfshao/workdir/parser-data/' | |||||
# train_data_name = "train_ctb5.txt" | |||||
# dev_data_name = "dev_ctb5.txt" | |||||
# test_data_name = "test_ctb5.txt" | |||||
datadir = "/home/yfshao/workdir/ctb7.0/" | |||||
datadir = "/remote-home/yfshao/workdir/ctb9.0/" | |||||
train_data_name = "train.conllx" | train_data_name = "train.conllx" | ||||
dev_data_name = "dev.conllx" | dev_data_name = "dev.conllx" | ||||
test_data_name = "test.conllx" | test_data_name = "test.conllx" | ||||
# emb_file_name = "/home/yfshao/workdir/parser-data/word_OOVthr_30_100v.txt" | |||||
emb_file_name = "/home/yfshao/workdir/word_vector/cc.zh.300.vec" | |||||
emb_file_name = "/remote-home/yfshao/workdir/word_vector/cc.zh.300.vec" | |||||
cfgfile = './cfg.cfg' | cfgfile = './cfg.cfg' | ||||
processed_datadir = './save' | processed_datadir = './save' | ||||
@@ -113,27 +98,23 @@ def update_v(vocab, data, field): | |||||
data.apply(lambda x: vocab.add_word_lst(x[field]), new_field_name=None) | data.apply(lambda x: vocab.add_word_lst(x[field]), new_field_name=None) | ||||
print('load raw data and preprocess') | |||||
# use pretrain embedding | # use pretrain embedding | ||||
word_v = Vocabulary() | |||||
word_v.unknown_label = UNK | |||||
pos_v = Vocabulary() | |||||
word_v = Vocabulary(unknown=UNK, padding=PAD) | |||||
pos_v = Vocabulary(unknown=None, padding=PAD) | |||||
tag_v = Vocabulary(unknown=None, padding=None) | tag_v = Vocabulary(unknown=None, padding=None) | ||||
train_data = load(os.path.join(datadir, train_data_name)) | train_data = load(os.path.join(datadir, train_data_name)) | ||||
dev_data = load(os.path.join(datadir, dev_data_name)) | dev_data = load(os.path.join(datadir, dev_data_name)) | ||||
test_data = load(os.path.join(datadir, test_data_name)) | test_data = load(os.path.join(datadir, test_data_name)) | ||||
print(train_data[0]) | |||||
num_p = Num2TagProcessor('words', 'words') | |||||
print('load raw data and preprocess') | |||||
num_p = Num2TagProcessor(tag=NUM, field_name='raw_words', new_added_field_name='words') | |||||
for ds in (train_data, dev_data, test_data): | for ds in (train_data, dev_data, test_data): | ||||
num_p(ds) | num_p(ds) | ||||
update_v(word_v, train_data, 'words') | update_v(word_v, train_data, 'words') | ||||
update_v(pos_v, train_data, 'pos') | update_v(pos_v, train_data, 'pos') | ||||
update_v(tag_v, train_data, 'tags') | update_v(tag_v, train_data, 'tags') | ||||
print('vocab build success {}, {}, {}'.format(len(word_v), len(pos_v), len(tag_v))) | print('vocab build success {}, {}, {}'.format(len(word_v), len(pos_v), len(tag_v))) | ||||
# embed, _ = EmbedLoader.fast_load_embedding(model_args['word_emb_dim'], emb_file_name, word_v) | |||||
# print(embed.size()) | |||||
# Model | # Model | ||||
model_args['word_vocab_size'] = len(word_v) | model_args['word_vocab_size'] = len(word_v) | ||||
@@ -141,7 +122,7 @@ model_args['pos_vocab_size'] = len(pos_v) | |||||
model_args['num_label'] = len(tag_v) | model_args['num_label'] = len(tag_v) | ||||
model = BiaffineParser(**model_args.data) | model = BiaffineParser(**model_args.data) | ||||
model.reset_parameters() | |||||
print(model) | |||||
word_idxp = IndexerProcessor(word_v, 'words', 'word_seq') | word_idxp = IndexerProcessor(word_v, 'words', 'word_seq') | ||||
pos_idxp = IndexerProcessor(pos_v, 'pos', 'pos_seq') | pos_idxp = IndexerProcessor(pos_v, 'pos', 'pos_seq') | ||||
@@ -164,7 +145,6 @@ for ds in (train_data, dev_data, test_data): | |||||
if train_args['use_golden_train']: | if train_args['use_golden_train']: | ||||
train_data.set_input('gold_heads', flag=True) | train_data.set_input('gold_heads', flag=True) | ||||
train_args.data.pop('use_golden_train') | train_args.data.pop('use_golden_train') | ||||
ignore_label = pos_v['punct'] | |||||
print(test_data[0]) | print(test_data[0]) | ||||
print('train len {}'.format(len(train_data))) | print('train len {}'.format(len(train_data))) | ||||
@@ -172,44 +152,62 @@ print('dev len {}'.format(len(dev_data))) | |||||
print('test len {}'.format(len(test_data))) | print('test len {}'.format(len(test_data))) | ||||
def train(path): | def train(path): | ||||
# test saving pipeline | # test saving pipeline | ||||
save_pipe(path) | save_pipe(path) | ||||
embed = EmbedLoader.fast_load_embedding(model_args['word_emb_dim'], emb_file_name, word_v) | |||||
embed = torch.tensor(embed, dtype=torch.float32) | |||||
# Trainer | |||||
trainer = Trainer(model=model, train_data=train_data, dev_data=dev_data, | |||||
loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS', | |||||
**train_args.data, | |||||
optimizer=fastNLP.Adam(**optim_args.data), | |||||
save_path=path) | |||||
# model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False) | |||||
# embed = EmbedLoader.fast_load_embedding(emb_dim=model_args['word_emb_dim'], emb_file=emb_file_name, vocab=word_v) | |||||
# embed = torch.tensor(embed, dtype=torch.float32) | |||||
# model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=True) | |||||
model.word_embedding.padding_idx = word_v.padding_idx | model.word_embedding.padding_idx = word_v.padding_idx | ||||
model.word_embedding.weight.data[word_v.padding_idx].fill_(0) | model.word_embedding.weight.data[word_v.padding_idx].fill_(0) | ||||
model.pos_embedding.padding_idx = pos_v.padding_idx | model.pos_embedding.padding_idx = pos_v.padding_idx | ||||
model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0) | model.pos_embedding.weight.data[pos_v.padding_idx].fill_(0) | ||||
# try: | |||||
# ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
# print('model parameter loaded!') | |||||
# except Exception as _: | |||||
# print("No saved model. Continue.") | |||||
# pass | |||||
class MyCallback(Callback): | |||||
def on_step_end(self, optimizer): | |||||
step = self.trainer.step | |||||
# learning rate decay | |||||
if step > 0 and step % 1000 == 0: | |||||
for pg in optimizer.param_groups: | |||||
pg['lr'] *= 0.93 | |||||
print('decay lr to {}'.format([pg['lr'] for pg in optimizer.param_groups])) | |||||
if step == 3000: | |||||
# start training embedding | |||||
print('start training embedding at {}'.format(step)) | |||||
model = self.trainer.model | |||||
for m in model.modules(): | |||||
if isinstance(m, torch.nn.Embedding): | |||||
m.weight.requires_grad = True | |||||
# Start training | |||||
trainer.train() | |||||
print("Training finished!") | |||||
# Trainer | |||||
trainer = Trainer(model=model, train_data=train_data, dev_data=dev_data, | |||||
loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS', | |||||
**train_args.data, | |||||
optimizer=fastNLP.Adam(**optim_args.data), | |||||
save_path=path, | |||||
callbacks=[MyCallback()]) | |||||
# save pipeline | |||||
save_pipe(path) | |||||
print('pipe saved') | |||||
# Start training | |||||
try: | |||||
trainer.train() | |||||
print("Training finished!") | |||||
finally: | |||||
# save pipeline | |||||
save_pipe(path) | |||||
print('pipe saved') | |||||
def save_pipe(path): | def save_pipe(path): | ||||
pipe = Pipeline(processors=[num_p, word_idxp, pos_idxp, seq_p, set_input_p]) | pipe = Pipeline(processors=[num_p, word_idxp, pos_idxp, seq_p, set_input_p]) | ||||
pipe.add_processor(ModelProcessor(model=model, batch_size=32)) | pipe.add_processor(ModelProcessor(model=model, batch_size=32)) | ||||
pipe.add_processor(label_toword_p) | pipe.add_processor(label_toword_p) | ||||
torch.save(pipe, os.path.join(path, 'pipe.pkl')) | |||||
os.makedirs(path, exist_ok=True) | |||||
torch.save({'pipeline': pipe, | |||||
'names':['num word_idx pos_idx seq set_input model tag_to_word'.split()], | |||||
}, os.path.join(path, 'pipe.pkl')) | |||||
def test(path): | def test(path): | ||||
@@ -234,16 +232,11 @@ def test(path): | |||||
print("Testing Test data") | print("Testing Test data") | ||||
tester.test(model, test_data) | tester.test(model, test_data) | ||||
def build_pipe(parser_pipe_path): | |||||
parser_pipe = torch.load(parser_pipe_path) | |||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
import argparse | import argparse | ||||
parser = argparse.ArgumentParser(description='Run a chinese word segmentation model') | parser = argparse.ArgumentParser(description='Run a chinese word segmentation model') | ||||
parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer', 'save']) | |||||
parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer']) | |||||
parser.add_argument('--path', type=str, default='') | parser.add_argument('--path', type=str, default='') | ||||
# parser.add_argument('--dst', type=str, default='') | # parser.add_argument('--dst', type=str, default='') | ||||
args = parser.parse_args() | args = parser.parse_args() | ||||
@@ -253,12 +246,6 @@ if __name__ == "__main__": | |||||
test(args.path) | test(args.path) | ||||
elif args.mode == 'infer': | elif args.mode == 'infer': | ||||
pass | pass | ||||
# elif args.mode == 'save': | |||||
# print(f'save model from {args.path} to {args.dst}') | |||||
# save_model(args.path, args.dst) | |||||
# load_path = os.path.dirname(args.dst) | |||||
# print(f'save pipeline in {load_path}') | |||||
# build(load_path) | |||||
else: | else: | ||||
print('no mode specified for model!') | print('no mode specified for model!') | ||||
parser.print_help() | parser.print_help() |
@@ -1,34 +1,3 @@ | |||||
class ConllxDataLoader(object): | |||||
def load(self, path): | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
data = [self.get_one(sample) for sample in datalist] | |||||
return list(filter(lambda x: x is not None, data)) | |||||
def get_one(self, sample): | |||||
sample = list(map(list, zip(*sample))) | |||||
if len(sample) == 0: | |||||
return None | |||||
for w in sample[7]: | |||||
if w == '_': | |||||
print('Error Sample {}'.format(sample)) | |||||
return None | |||||
# return word_seq, pos_seq, head_seq, head_tag_seq | |||||
return sample[1], sample[3], list(map(int, sample[6])), sample[7] | |||||
class MyDataloader: | class MyDataloader: | ||||
def load(self, data_path): | def load(self, data_path): | ||||
with open(data_path, "r", encoding="utf-8") as f: | with open(data_path, "r", encoding="utf-8") as f: | ||||
@@ -56,23 +25,3 @@ class MyDataloader: | |||||
return data | return data | ||||
def add_seg_tag(data): | |||||
""" | |||||
:param data: list of ([word], [pos], [heads], [head_tags]) | |||||
:return: list of ([word], [pos]) | |||||
""" | |||||
_processed = [] | |||||
for word_list, pos_list, _, _ in data: | |||||
new_sample = [] | |||||
for word, pos in zip(word_list, pos_list): | |||||
if len(word) == 1: | |||||
new_sample.append((word, 'S-' + pos)) | |||||
else: | |||||
new_sample.append((word[0], 'B-' + pos)) | |||||
for c in word[1:-1]: | |||||
new_sample.append((c, 'M-' + pos)) | |||||
new_sample.append((word[-1], 'E-' + pos)) | |||||
_processed.append(list(map(list, zip(*new_sample)))) | |||||
return _processed |
@@ -0,0 +1,3 @@ | |||||
@@ -1,11 +1,11 @@ | |||||
from torch import nn | |||||
import torch | import torch | ||||
import torch.nn.functional as F | |||||
from torch import nn | |||||
from fastNLP.modules.decoder.MLP import MLP | |||||
from fastNLP.models.base_model import BaseModel | from fastNLP.models.base_model import BaseModel | ||||
from reproduction.chinese_word_segment.utils import seq_lens_to_mask | |||||
from fastNLP.modules.decoder.MLP import MLP | |||||
from reproduction.Chinese_word_segmentation.utils import seq_lens_to_mask | |||||
class CWSBiLSTMEncoder(BaseModel): | class CWSBiLSTMEncoder(BaseModel): | ||||
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, | def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, |
@@ -0,0 +1,125 @@ | |||||
""" | |||||
使用transformer作为分词的encoder端 | |||||
""" | |||||
from torch import nn | |||||
import torch | |||||
from fastNLP.modules.encoder.transformer import TransformerEncoder | |||||
from fastNLP.modules.decoder.CRF import ConditionalRandomField,seq_len_to_byte_mask | |||||
from fastNLP.modules.decoder.CRF import allowed_transitions | |||||
class TransformerCWS(nn.Module): | |||||
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, | |||||
hidden_size=200, embed_drop_p=0.3, num_layers=1, num_heads=8, tag_size=4): | |||||
super().__init__() | |||||
self.embedding = nn.Embedding(vocab_num, embed_dim) | |||||
input_size = embed_dim | |||||
if bigram_vocab_num: | |||||
self.bigram_embedding = nn.Embedding(bigram_vocab_num, bigram_embed_dim) | |||||
input_size += num_bigram_per_char*bigram_embed_dim | |||||
self.drop = nn.Dropout(embed_drop_p, inplace=True) | |||||
self.fc1 = nn.Linear(input_size, hidden_size) | |||||
value_size = hidden_size//num_heads | |||||
self.transformer = TransformerEncoder(num_layers, model_size=hidden_size, inner_size=hidden_size, | |||||
key_size=value_size, | |||||
value_size=value_size, num_head=num_heads) | |||||
self.fc2 = nn.Linear(hidden_size, tag_size) | |||||
allowed_trans = allowed_transitions({0:'b', 1:'m', 2:'e', 3:'s'}, encoding_type='bmes') | |||||
self.crf = ConditionalRandomField(num_tags=tag_size, include_start_end_trans=False, | |||||
allowed_transitions=allowed_trans) | |||||
def forward(self, chars, target, seq_lens, bigrams=None): | |||||
masks = seq_len_to_byte_mask(seq_lens).float() | |||||
x = self.embedding(chars) | |||||
batch_size = x.size(0) | |||||
length = x.size(1) | |||||
if hasattr(self, 'bigram_embedding'): | |||||
bigrams = self.bigram_embedding(bigrams) # batch_size x seq_lens x per_char x embed_size | |||||
x = torch.cat([x, bigrams.view(batch_size, length, -1)], dim=-1) | |||||
self.drop(x) | |||||
x = self.fc1(x) | |||||
feats = self.transformer(x, masks) | |||||
feats = self.fc2(feats) | |||||
losses = self.crf(feats, target, masks.float()) | |||||
pred_dict = {} | |||||
pred_dict['seq_lens'] = seq_lens | |||||
pred_dict['loss'] = torch.mean(losses) | |||||
return pred_dict | |||||
def predict(self, chars, seq_lens, bigrams=None): | |||||
masks = seq_len_to_byte_mask(seq_lens).float() | |||||
x = self.embedding(chars) | |||||
batch_size = x.size(0) | |||||
length = x.size(1) | |||||
if hasattr(self, 'bigram_embedding'): | |||||
bigrams = self.bigram_embedding(bigrams) # batch_size x seq_lens x per_char x embed_size | |||||
x = torch.cat([x, bigrams.view(batch_size, length, -1)], dim=-1) | |||||
self.drop(x) | |||||
x = self.fc1(x) | |||||
feats = self.transformer(x, masks) | |||||
feats = self.fc2(feats) | |||||
probs = self.crf.viterbi_decode(feats, masks, get_score=False) | |||||
return {'pred': probs, 'seq_lens':seq_lens} | |||||
class NoamOpt(torch.optim.Optimizer): | |||||
"Optim wrapper that implements rate." | |||||
def __init__(self, model_size, factor, warmup, optimizer): | |||||
super().__init__([torch.nn.Parameter(torch.ones(1))], {}) | |||||
self.optimizer = optimizer | |||||
self._step = 0 | |||||
self.warmup = warmup | |||||
self.factor = factor | |||||
self.model_size = model_size | |||||
self._rate = 0 | |||||
def step(self, **kwargs): | |||||
"Update parameters and rate" | |||||
self._step += 1 | |||||
rate = self.rate() | |||||
for p in self.optimizer.param_groups: | |||||
p['lr'] = rate | |||||
self._rate = rate | |||||
self.optimizer.step() | |||||
def rate(self, step=None): | |||||
"Implement `lrate` above" | |||||
if step is None: | |||||
step = self._step | |||||
return self.factor * \ | |||||
(self.model_size ** (-0.5) * | |||||
min(step ** (-0.5), step * self.warmup ** (-1.5))) | |||||
if __name__ == '__main__': | |||||
transformer = TransformerCWS(10, embed_dim=100, bigram_vocab_num=10, bigram_embed_dim=100, num_bigram_per_char=8, | |||||
hidden_size=200, embed_drop_p=0.3, num_layers=1, num_heads=8, tag_size=4) | |||||
chars = torch.randint(10, size=(4, 7)).long() | |||||
bigrams = torch.randint(10, size=(4, 56)).long() | |||||
seq_lens = torch.ones(4).long()*7 | |||||
target = torch.randint(4, size=(4, 7)) | |||||
print(transformer(chars, target, seq_lens, bigrams)) | |||||
optimizer = torch.optim.Adam(transformer.parameters()) | |||||
opt = NoamOpt(10 ,1, 400, optimizer) |
@@ -4,7 +4,7 @@ import re | |||||
from fastNLP.api.processor import Processor | from fastNLP.api.processor import Processor | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
from reproduction.chinese_word_segment.process.span_converter import SpanConverter | |||||
from reproduction.Chinese_word_segmentation.process.span_converter import SpanConverter | |||||
_SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>' | _SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>' | ||||
@@ -226,109 +226,6 @@ class Pre2Post2BigramProcessor(BigramProcessor): | |||||
return bigrams | return bigrams | ||||
# 这里需要建立vocabulary了,但是遇到了以下的问题 | |||||
# (1) 如果使用Processor的方式的话,但是在这种情况返回的不是dataset。所以建立vocabulary的工作用另外的方式实现,不借用 | |||||
# Processor了 | |||||
# TODO 如何将建立vocab和index这两步统一了? | |||||
class VocabIndexerProcessor(Processor): | |||||
""" | |||||
根据DataSet创建Vocabulary,并将其用数字index。新生成的index的field会被放在new_added_filed_name, 如果没有提供 | |||||
new_added_field_name, 则覆盖原有的field_name. | |||||
""" | |||||
def __init__(self, field_name, new_added_filed_name=None, min_freq=1, max_size=None, | |||||
verbose=0, is_input=True): | |||||
""" | |||||
:param field_name: 从哪个field_name创建词表,以及对哪个field_name进行index操作 | |||||
:param new_added_filed_name: index时,生成的index field的名称,如果不传入,则覆盖field_name. | |||||
:param min_freq: 创建的Vocabulary允许的单词最少出现次数. | |||||
:param max_size: 创建的Vocabulary允许的最大的单词数量 | |||||
:param verbose: 0, 不输出任何信息;1,输出信息 | |||||
:param bool is_input: | |||||
""" | |||||
super(VocabIndexerProcessor, self).__init__(field_name, new_added_filed_name) | |||||
self.min_freq = min_freq | |||||
self.max_size = max_size | |||||
self.verbose =verbose | |||||
self.is_input = is_input | |||||
def construct_vocab(self, *datasets): | |||||
""" | |||||
使用传入的DataSet创建vocabulary | |||||
:param datasets: DataSet类型的数据,用于构建vocabulary | |||||
:return: | |||||
""" | |||||
self.vocab = Vocabulary(min_freq=self.min_freq, max_size=self.max_size) | |||||
for dataset in datasets: | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
dataset.apply(lambda ins: self.vocab.update(ins[self.field_name])) | |||||
self.vocab.build_vocab() | |||||
if self.verbose: | |||||
print("Vocabulary Constructed, has {} items.".format(len(self.vocab))) | |||||
def process(self, *datasets, only_index_dataset=None): | |||||
""" | |||||
若还未建立Vocabulary,则使用dataset中的DataSet建立vocabulary;若已经有了vocabulary则使用已有的vocabulary。得到vocabulary | |||||
后,则会index datasets与only_index_dataset。 | |||||
:param datasets: DataSet类型的数据 | |||||
:param only_index_dataset: DataSet, or list of DataSet. 该参数中的内容只会被用于index,不会被用于生成vocabulary。 | |||||
:return: | |||||
""" | |||||
if len(datasets)==0 and not hasattr(self,'vocab'): | |||||
raise RuntimeError("You have to construct vocabulary first. Or you have to pass datasets to construct it.") | |||||
if not hasattr(self, 'vocab'): | |||||
self.construct_vocab(*datasets) | |||||
else: | |||||
if self.verbose: | |||||
print("Using constructed vocabulary with {} items.".format(len(self.vocab))) | |||||
to_index_datasets = [] | |||||
if len(datasets)!=0: | |||||
for dataset in datasets: | |||||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||||
to_index_datasets.append(dataset) | |||||
if not (only_index_dataset is None): | |||||
if isinstance(only_index_dataset, list): | |||||
for dataset in only_index_dataset: | |||||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||||
to_index_datasets.append(dataset) | |||||
elif isinstance(only_index_dataset, DataSet): | |||||
to_index_datasets.append(only_index_dataset) | |||||
else: | |||||
raise TypeError('Only DataSet or list of DataSet is allowed, not {}.'.format(type(only_index_dataset))) | |||||
for dataset in to_index_datasets: | |||||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||||
dataset.apply(lambda ins: [self.vocab.to_index(token) for token in ins[self.field_name]], | |||||
new_field_name=self.new_added_field_name, is_input=self.is_input) | |||||
# 只返回一个,infer时为了跟其他processor保持一致 | |||||
if len(to_index_datasets) == 1: | |||||
return to_index_datasets[0] | |||||
def set_vocab(self, vocab): | |||||
assert isinstance(vocab, Vocabulary), "Only fastNLP.core.Vocabulary is allowed, not {}.".format(type(vocab)) | |||||
self.vocab = vocab | |||||
def delete_vocab(self): | |||||
del self.vocab | |||||
def get_vocab_size(self): | |||||
return len(self.vocab) | |||||
def set_verbose(self, verbose): | |||||
""" | |||||
设置processor verbose状态。 | |||||
:param verbose: int, 0,不输出任何信息;1,输出vocab 信息。 | |||||
:return: | |||||
""" | |||||
self.verbose = verbose | |||||
class VocabProcessor(Processor): | class VocabProcessor(Processor): | ||||
def __init__(self, field_name, min_freq=1, max_size=None): | def __init__(self, field_name, min_freq=1, max_size=None): | ||||
@@ -4,7 +4,7 @@ from fastNLP.core.trainer import ClassificationTrainer | |||||
from fastNLP.core.utils import ClassPreprocess as Preprocess | from fastNLP.core.utils import ClassPreprocess as Preprocess | ||||
from fastNLP.io.config_io import ConfigLoader | from fastNLP.io.config_io import ConfigLoader | ||||
from fastNLP.io.config_io import ConfigSection | from fastNLP.io.config_io import ConfigSection | ||||
from fastNLP.io.dataset_loader import ClassDataSetLoader as Dataset_loader | |||||
from fastNLP.io.dataset_loader import DummyClassificationReader as Dataset_loader | |||||
from fastNLP.models.base_model import BaseModel | from fastNLP.models.base_model import BaseModel | ||||
from fastNLP.modules.aggregator.self_attention import SelfAttention | from fastNLP.modules.aggregator.self_attention import SelfAttention | ||||
from fastNLP.modules.decoder.MLP import MLP | from fastNLP.modules.decoder.MLP import MLP | ||||
@@ -0,0 +1,29 @@ | |||||
from fastNLP.io.dataset_loader import ZhConllPOSReader | |||||
def cut_long_sentence(sent, max_sample_length=200): | |||||
sent_no_space = sent.replace(' ', '') | |||||
cutted_sentence = [] | |||||
if len(sent_no_space) > max_sample_length: | |||||
parts = sent.strip().split() | |||||
new_line = '' | |||||
length = 0 | |||||
for part in parts: | |||||
length += len(part) | |||||
new_line += part + ' ' | |||||
if length > max_sample_length: | |||||
new_line = new_line[:-1] | |||||
cutted_sentence.append(new_line) | |||||
length = 0 | |||||
new_line = '' | |||||
if new_line != '': | |||||
cutted_sentence.append(new_line[:-1]) | |||||
else: | |||||
cutted_sentence.append(sent) | |||||
return cutted_sentence | |||||
if __name__ == '__main__': | |||||
reader = ZhConllPOSReader() | |||||
d = reader.load('/home/hyan/train.conllx') | |||||
print(d) |
@@ -10,7 +10,7 @@ eval_sort_key = 'accuracy' | |||||
[model] | [model] | ||||
rnn_hidden_units = 300 | rnn_hidden_units = 300 | ||||
word_emb_dim = 100 | |||||
word_emb_dim = 300 | |||||
dropout = 0.5 | dropout = 0.5 | ||||
use_crf = true | use_crf = true | ||||
print_every_step = 10 | print_every_step = 10 |
@@ -0,0 +1,163 @@ | |||||
import argparse | |||||
import os | |||||
import pickle | |||||
import sys | |||||
import torch | |||||
# in order to run fastNLP without installation | |||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||||
from fastNLP.api.pipeline import Pipeline | |||||
from fastNLP.api.processor import SeqLenProcessor, VocabIndexerProcessor, SetInputProcessor, IndexerProcessor | |||||
from fastNLP.core.metrics import SpanFPreRecMetric | |||||
from fastNLP.core.trainer import Trainer | |||||
from fastNLP.io.config_io import ConfigLoader, ConfigSection | |||||
from fastNLP.models.sequence_modeling import AdvSeqLabel | |||||
from fastNLP.io.dataset_loader import ConllxDataLoader | |||||
from fastNLP.api.processor import ModelProcessor, Index2WordProcessor | |||||
cfgfile = './pos_tag.cfg' | |||||
pickle_path = "save" | |||||
def load_tencent_embed(embed_path, word2id): | |||||
hit = 0 | |||||
with open(embed_path, "rb") as f: | |||||
embed_dict = pickle.load(f) | |||||
embedding_tensor = torch.randn(len(word2id), 200) | |||||
for key in word2id: | |||||
if key in embed_dict: | |||||
embedding_tensor[word2id[key]] = torch.Tensor(embed_dict[key]) | |||||
hit += 1 | |||||
print("vocab_size={} hit={} hit/vocab_size={}".format(len(word2id), hit, hit / len(word2id))) | |||||
return embedding_tensor | |||||
def train(train_data_path, dev_data_path, checkpoint=None, save=None): | |||||
# load config | |||||
train_param = ConfigSection() | |||||
model_param = ConfigSection() | |||||
ConfigLoader().load_config(cfgfile, {"train": train_param, "model": model_param}) | |||||
print("config loaded") | |||||
# Data Loader | |||||
print("loading training set...") | |||||
dataset = ConllxDataLoader().load(train_data_path, return_dataset=True) | |||||
print("loading dev set...") | |||||
dev_data = ConllxDataLoader().load(dev_data_path, return_dataset=True) | |||||
print(dataset) | |||||
print("================= dataset ready =====================") | |||||
dataset.rename_field("tag", "truth") | |||||
dev_data.rename_field("tag", "truth") | |||||
vocab_proc = VocabIndexerProcessor("words", new_added_filed_name="word_seq") | |||||
tag_proc = VocabIndexerProcessor("truth", is_input=True) | |||||
seq_len_proc = SeqLenProcessor(field_name="word_seq", new_added_field_name="word_seq_origin_len", is_input=True) | |||||
set_input_proc = SetInputProcessor("word_seq", "word_seq_origin_len") | |||||
vocab_proc(dataset) | |||||
tag_proc(dataset) | |||||
seq_len_proc(dataset) | |||||
# index dev set | |||||
word_vocab, tag_vocab = vocab_proc.vocab, tag_proc.vocab | |||||
dev_data.apply(lambda ins: [word_vocab.to_index(w) for w in ins["words"]], new_field_name="word_seq") | |||||
dev_data.apply(lambda ins: [tag_vocab.to_index(w) for w in ins["truth"]], new_field_name="truth") | |||||
dev_data.apply(lambda ins: len(ins["word_seq"]), new_field_name="word_seq_origin_len") | |||||
# set input & target | |||||
dataset.set_input("word_seq", "word_seq_origin_len", "truth") | |||||
dev_data.set_input("word_seq", "word_seq_origin_len", "truth") | |||||
dataset.set_target("truth", "word_seq_origin_len") | |||||
dev_data.set_target("truth", "word_seq_origin_len") | |||||
# dataset.set_is_target(tag_ids=True) | |||||
model_param["vocab_size"] = vocab_proc.get_vocab_size() | |||||
model_param["num_classes"] = tag_proc.get_vocab_size() | |||||
print("vocab_size={} num_classes={}".format(model_param["vocab_size"], model_param["num_classes"])) | |||||
# define a model | |||||
if checkpoint is None: | |||||
# pre_trained = load_tencent_embed("/home/zyfeng/data/char_tencent_embedding.pkl", vocab_proc.vocab.word2idx) | |||||
pre_trained = None | |||||
model = AdvSeqLabel(model_param, id2words=None, emb=pre_trained) | |||||
print(model) | |||||
else: | |||||
model = torch.load(checkpoint) | |||||
# call trainer to train | |||||
trainer = Trainer(dataset, model, loss=None, metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict", | |||||
target="truth", | |||||
seq_lens="word_seq_origin_len"), | |||||
dev_data=dev_data, metric_key="f", | |||||
use_tqdm=True, use_cuda=True, print_every=10, n_epochs=20, save_path=save) | |||||
trainer.train(load_best_model=True) | |||||
# save model & pipeline | |||||
model_proc = ModelProcessor(model, seq_len_field_name="word_seq_origin_len") | |||||
id2tag = Index2WordProcessor(tag_proc.vocab, "predict", "tag") | |||||
pp = Pipeline([vocab_proc, seq_len_proc, set_input_proc, model_proc, id2tag]) | |||||
save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab} | |||||
torch.save(save_dict, os.path.join(save, "model_pp.pkl")) | |||||
print("pipeline saved") | |||||
def run_test(test_path): | |||||
test_data = ConllxDataLoader().load(test_path, return_dataset=True) | |||||
with open("model_pp_0117.pkl", "rb") as f: | |||||
save_dict = torch.load(f) | |||||
tag_vocab = save_dict["tag_vocab"] | |||||
pipeline = save_dict["pipeline"] | |||||
index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False) | |||||
pipeline.pipeline = [index_tag] + pipeline.pipeline | |||||
pipeline(test_data) | |||||
test_data.set_target("truth") | |||||
prediction = test_data.field_arrays["predict"].content | |||||
truth = test_data.field_arrays["truth"].content | |||||
seq_len = test_data.field_arrays["word_seq_origin_len"].content | |||||
# padding by hand | |||||
max_length = max([len(seq) for seq in prediction]) | |||||
for idx in range(len(prediction)): | |||||
prediction[idx] = list(prediction[idx]) + ([0] * (max_length - len(prediction[idx]))) | |||||
truth[idx] = list(truth[idx]) + ([0] * (max_length - len(truth[idx]))) | |||||
evaluator = SpanFPreRecMetric(tag_vocab=tag_vocab, pred="predict", target="truth", | |||||
seq_lens="word_seq_origin_len") | |||||
evaluator({"predict": torch.Tensor(prediction), "word_seq_origin_len": torch.Tensor(seq_len)}, | |||||
{"truth": torch.Tensor(truth)}) | |||||
test_result = evaluator.get_metric() | |||||
f1 = round(test_result['f'] * 100, 2) | |||||
pre = round(test_result['pre'] * 100, 2) | |||||
rec = round(test_result['rec'] * 100, 2) | |||||
return {"F1": f1, "precision": pre, "recall": rec} | |||||
if __name__ == "__main__": | |||||
parser = argparse.ArgumentParser() | |||||
parser.add_argument("--train", type=str, help="training conll file", default="/home/zyfeng/data/sample.conllx") | |||||
parser.add_argument("--dev", type=str, help="dev conll file", default="/home/zyfeng/data/sample.conllx") | |||||
parser.add_argument("--test", type=str, help="test conll file", default=None) | |||||
parser.add_argument("--save", type=str, help="path to save", default=None) | |||||
parser.add_argument("-c", "--restart", action="store_true", help="whether to continue training") | |||||
parser.add_argument("-cp", "--checkpoint", type=str, help="checkpoint of the trained model") | |||||
args = parser.parse_args() | |||||
if args.test is not None: | |||||
print(run_test(args.test)) | |||||
else: | |||||
if args.restart is True: | |||||
# 继续训练 python train_pos_tag.py -c -cp ./save/best_model.pkl | |||||
if args.checkpoint is None: | |||||
raise RuntimeError("Please provide the checkpoint. -cp ") | |||||
train(args.train, args.dev, args.checkpoint, save=args.save) | |||||
else: | |||||
# 一次训练 python train_pos_tag.py | |||||
train(args.train, args.dev, save=args.save) |
@@ -1,197 +0,0 @@ | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.instance import Instance | |||||
from fastNLP.io.dataset_loader import DataSetLoader | |||||
def cut_long_sentence(sent, max_sample_length=200): | |||||
""" | |||||
将长于max_sample_length的sentence截成多段,只会在有空格的地方发生截断。所以截取的句子可能长于或者短于max_sample_length | |||||
:param sent: str. | |||||
:param max_sample_length: int. | |||||
:return: list of str. | |||||
""" | |||||
sent_no_space = sent.replace(' ', '') | |||||
cutted_sentence = [] | |||||
if len(sent_no_space) > max_sample_length: | |||||
parts = sent.strip().split() | |||||
new_line = '' | |||||
length = 0 | |||||
for part in parts: | |||||
length += len(part) | |||||
new_line += part + ' ' | |||||
if length > max_sample_length: | |||||
new_line = new_line[:-1] | |||||
cutted_sentence.append(new_line) | |||||
length = 0 | |||||
new_line = '' | |||||
if new_line != '': | |||||
cutted_sentence.append(new_line[:-1]) | |||||
else: | |||||
cutted_sentence.append(sent) | |||||
return cutted_sentence | |||||
class NaiveCWSReader(DataSetLoader): | |||||
""" | |||||
这个reader假设了分词数据集为以下形式, 即已经用空格分割好内容了 | |||||
这是 fastNLP , 一个 非常 good 的 包 . | |||||
或者,即每个part后面还有一个pos tag | |||||
也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY | |||||
""" | |||||
def __init__(self, in_word_splitter=None): | |||||
super().__init__() | |||||
self.in_word_splitter = in_word_splitter | |||||
def load(self, filepath, in_word_splitter=None, cut_long_sent=False): | |||||
""" | |||||
允许使用的情况有(默认以\t或空格作为seg) | |||||
这是 fastNLP , 一个 非常 good 的 包 . | |||||
和 | |||||
也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY | |||||
如果splitter不为None则认为是第二种情况, 且我们会按splitter分割"也/D", 然后取第一部分. 例如"也/D".split('/')[0] | |||||
:param filepath: | |||||
:param in_word_splitter: | |||||
:return: | |||||
""" | |||||
if in_word_splitter == None: | |||||
in_word_splitter = self.in_word_splitter | |||||
dataset = DataSet() | |||||
with open(filepath, 'r') as f: | |||||
for line in f: | |||||
line = line.strip() | |||||
if len(line.replace(' ', ''))==0: # 不能接受空行 | |||||
continue | |||||
if not in_word_splitter is None: | |||||
words = [] | |||||
for part in line.split(): | |||||
word = part.split(in_word_splitter)[0] | |||||
words.append(word) | |||||
line = ' '.join(words) | |||||
if cut_long_sent: | |||||
sents = cut_long_sentence(line) | |||||
else: | |||||
sents = [line] | |||||
for sent in sents: | |||||
instance = Instance(raw_sentence=sent) | |||||
dataset.append(instance) | |||||
return dataset | |||||
class POSCWSReader(DataSetLoader): | |||||
""" | |||||
支持读取以下的情况, 即每一行是一个词, 用空行作为两句话的界限. | |||||
迈 N | |||||
向 N | |||||
充 N | |||||
... | |||||
泽 I-PER | |||||
民 I-PER | |||||
( N | |||||
一 N | |||||
九 N | |||||
... | |||||
:param filepath: | |||||
:return: | |||||
""" | |||||
def __init__(self, in_word_splitter=None): | |||||
super().__init__() | |||||
self.in_word_splitter = in_word_splitter | |||||
def load(self, filepath, in_word_splitter=None, cut_long_sent=False): | |||||
if in_word_splitter is None: | |||||
in_word_splitter = self.in_word_splitter | |||||
dataset = DataSet() | |||||
with open(filepath, 'r') as f: | |||||
words = [] | |||||
for line in f: | |||||
line = line.strip() | |||||
if len(line) == 0: # new line | |||||
if len(words)==0: # 不能接受空行 | |||||
continue | |||||
line = ' '.join(words) | |||||
if cut_long_sent: | |||||
sents = cut_long_sentence(line) | |||||
else: | |||||
sents = [line] | |||||
for sent in sents: | |||||
instance = Instance(raw_sentence=sent) | |||||
dataset.append(instance) | |||||
words = [] | |||||
else: | |||||
line = line.split()[0] | |||||
if in_word_splitter is None: | |||||
words.append(line) | |||||
else: | |||||
words.append(line.split(in_word_splitter)[0]) | |||||
return dataset | |||||
class ConllCWSReader(object): | |||||
def __init__(self): | |||||
pass | |||||
def load(self, path, cut_long_sent=False): | |||||
""" | |||||
返回的DataSet只包含raw_sentence这个field,内容为str。 | |||||
假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 | |||||
1 编者按 编者按 NN O 11 nmod:topic | |||||
2 : : PU O 11 punct | |||||
3 7月 7月 NT DATE 4 compound:nn | |||||
4 12日 12日 NT DATE 11 nmod:tmod | |||||
5 , , PU O 11 punct | |||||
1 这 这 DT O 3 det | |||||
2 款 款 M O 1 mark:clf | |||||
3 飞行 飞行 NN O 8 nsubj | |||||
4 从 从 P O 5 case | |||||
5 外型 外型 NN O 8 nmod:prep | |||||
""" | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
ds = DataSet() | |||||
for sample in datalist: | |||||
# print(sample) | |||||
res = self.get_char_lst(sample) | |||||
if res is None: | |||||
continue | |||||
line = ' '.join(res) | |||||
if cut_long_sent: | |||||
sents = cut_long_sentence(line) | |||||
else: | |||||
sents = [line] | |||||
for raw_sentence in sents: | |||||
ds.append(Instance(raw_sentence=raw_sentence)) | |||||
return ds | |||||
def get_char_lst(self, sample): | |||||
if len(sample)==0: | |||||
return None | |||||
text = [] | |||||
for w in sample: | |||||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | |||||
if t3 == '_': | |||||
return None | |||||
text.append(t1) | |||||
return text | |||||
@@ -1,151 +0,0 @@ | |||||
import os | |||||
import sys | |||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||||
from fastNLP.io.config_io import ConfigLoader, ConfigSection | |||||
from fastNLP.core.trainer import SeqLabelTrainer | |||||
from fastNLP.io.dataset_loader import BaseLoader, TokenizeDataSetLoader | |||||
from fastNLP.core.utils import load_pickle | |||||
from fastNLP.io.model_io import ModelLoader, ModelSaver | |||||
from fastNLP.core.tester import SeqLabelTester | |||||
from fastNLP.models.sequence_modeling import AdvSeqLabel | |||||
from fastNLP.core.predictor import SeqLabelInfer | |||||
from fastNLP.core.utils import save_pickle | |||||
from fastNLP.core.metrics import SeqLabelEvaluator | |||||
# not in the file's dir | |||||
if len(os.path.dirname(__file__)) != 0: | |||||
os.chdir(os.path.dirname(__file__)) | |||||
datadir = "/home/zyfeng/data/" | |||||
cfgfile = './cws.cfg' | |||||
cws_data_path = os.path.join(datadir, "pku_training.utf8") | |||||
pickle_path = "save" | |||||
data_infer_path = os.path.join(datadir, "infer.utf8") | |||||
def infer(): | |||||
# Config Loader | |||||
test_args = ConfigSection() | |||||
ConfigLoader().load_config(cfgfile, {"POS_test": test_args}) | |||||
# fetch dictionary size and number of labels from pickle files | |||||
word2index = load_pickle(pickle_path, "word2id.pkl") | |||||
test_args["vocab_size"] = len(word2index) | |||||
index2label = load_pickle(pickle_path, "label2id.pkl") | |||||
test_args["num_classes"] = len(index2label) | |||||
# Define the same model | |||||
model = AdvSeqLabel(test_args) | |||||
try: | |||||
ModelLoader.load_pytorch(model, "./save/trained_model.pkl") | |||||
print('model loaded!') | |||||
except Exception as e: | |||||
print('cannot load model!') | |||||
raise | |||||
# Data Loader | |||||
infer_data = SeqLabelDataSet(load_func=BaseLoader.load_lines) | |||||
infer_data.load(data_infer_path, vocabs={"word_vocab": word2index}, infer=True) | |||||
print('data loaded') | |||||
# Inference interface | |||||
infer = SeqLabelInfer(pickle_path) | |||||
results = infer.predict(model, infer_data) | |||||
print(results) | |||||
print("Inference finished!") | |||||
def train(): | |||||
# Config Loader | |||||
train_args = ConfigSection() | |||||
test_args = ConfigSection() | |||||
ConfigLoader().load_config(cfgfile, {"train": train_args, "test": test_args}) | |||||
print("loading data set...") | |||||
data = SeqLabelDataSet(load_func=TokenizeDataSetLoader.load) | |||||
data.load(cws_data_path) | |||||
data_train, data_dev = data.split(ratio=0.3) | |||||
train_args["vocab_size"] = len(data.word_vocab) | |||||
train_args["num_classes"] = len(data.label_vocab) | |||||
print("vocab size={}, num_classes={}".format(len(data.word_vocab), len(data.label_vocab))) | |||||
change_field_is_target(data_dev, "truth", True) | |||||
save_pickle(data_dev, "./save/", "data_dev.pkl") | |||||
save_pickle(data.word_vocab, "./save/", "word2id.pkl") | |||||
save_pickle(data.label_vocab, "./save/", "label2id.pkl") | |||||
# Trainer | |||||
trainer = SeqLabelTrainer(epochs=train_args["epochs"], batch_size=train_args["batch_size"], | |||||
validate=train_args["validate"], | |||||
use_cuda=train_args["use_cuda"], pickle_path=train_args["pickle_path"], | |||||
save_best_dev=True, print_every_step=10, model_name="trained_model.pkl", | |||||
evaluator=SeqLabelEvaluator()) | |||||
# Model | |||||
model = AdvSeqLabel(train_args) | |||||
try: | |||||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
print('model parameter loaded!') | |||||
except Exception as e: | |||||
print("No saved model. Continue.") | |||||
pass | |||||
# Start training | |||||
trainer.train(model, data_train, data_dev) | |||||
print("Training finished!") | |||||
# Saver | |||||
saver = ModelSaver("./save/trained_model.pkl") | |||||
saver.save_pytorch(model) | |||||
print("Model saved!") | |||||
def predict(): | |||||
# Config Loader | |||||
test_args = ConfigSection() | |||||
ConfigLoader().load_config(cfgfile, {"POS_test": test_args}) | |||||
# fetch dictionary size and number of labels from pickle files | |||||
word2index = load_pickle(pickle_path, "word2id.pkl") | |||||
test_args["vocab_size"] = len(word2index) | |||||
index2label = load_pickle(pickle_path, "label2id.pkl") | |||||
test_args["num_classes"] = len(index2label) | |||||
# load dev data | |||||
dev_data = load_pickle(pickle_path, "data_dev.pkl") | |||||
# Define the same model | |||||
model = AdvSeqLabel(test_args) | |||||
# Dump trained parameters into the model | |||||
ModelLoader.load_pytorch(model, "./save/trained_model.pkl") | |||||
print("model loaded!") | |||||
# Tester | |||||
test_args["evaluator"] = SeqLabelEvaluator() | |||||
tester = SeqLabelTester(**test_args.data) | |||||
# Start testing | |||||
tester.test(model, dev_data) | |||||
if __name__ == "__main__": | |||||
import argparse | |||||
parser = argparse.ArgumentParser(description='Run a chinese word segmentation model') | |||||
parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer']) | |||||
args = parser.parse_args() | |||||
if args.mode == 'train': | |||||
train() | |||||
elif args.mode == 'test': | |||||
predict() | |||||
elif args.mode == 'infer': | |||||
infer() | |||||
else: | |||||
print('no mode specified for model!') | |||||
parser.print_help() |
@@ -1,153 +0,0 @@ | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.instance import Instance | |||||
def cut_long_sentence(sent, max_sample_length=200): | |||||
sent_no_space = sent.replace(' ', '') | |||||
cutted_sentence = [] | |||||
if len(sent_no_space) > max_sample_length: | |||||
parts = sent.strip().split() | |||||
new_line = '' | |||||
length = 0 | |||||
for part in parts: | |||||
length += len(part) | |||||
new_line += part + ' ' | |||||
if length > max_sample_length: | |||||
new_line = new_line[:-1] | |||||
cutted_sentence.append(new_line) | |||||
length = 0 | |||||
new_line = '' | |||||
if new_line != '': | |||||
cutted_sentence.append(new_line[:-1]) | |||||
else: | |||||
cutted_sentence.append(sent) | |||||
return cutted_sentence | |||||
class ConllPOSReader(object): | |||||
# 返回的Dataset包含words(list of list, 里层的list是character), tag两个field(list of str, str是标有BIO的tag)。 | |||||
def __init__(self): | |||||
pass | |||||
def load(self, path): | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
ds = DataSet() | |||||
for sample in datalist: | |||||
# print(sample) | |||||
res = self.get_one(sample) | |||||
if res is None: | |||||
continue | |||||
char_seq = [] | |||||
pos_seq = [] | |||||
for word, tag in zip(res[0], res[1]): | |||||
if len(word)==1: | |||||
char_seq.append(word) | |||||
pos_seq.append('S-{}'.format(tag)) | |||||
elif len(word)>1: | |||||
pos_seq.append('B-{}'.format(tag)) | |||||
for _ in range(len(word)-2): | |||||
pos_seq.append('M-{}'.format(tag)) | |||||
pos_seq.append('E-{}'.format(tag)) | |||||
char_seq.extend(list(word)) | |||||
else: | |||||
raise ValueError("Zero length of word detected.") | |||||
ds.append(Instance(words=char_seq, | |||||
tag=pos_seq)) | |||||
return ds | |||||
class ZhConllPOSReader(object): | |||||
# 中文colln格式reader | |||||
def __init__(self): | |||||
pass | |||||
def load(self, path): | |||||
""" | |||||
返回的DataSet, 包含以下的field | |||||
words:list of str, | |||||
tag: list of str, 被加入了BMES tag, 比如原来的序列为['VP', 'NN', 'NN', ..],会被认为是["S-VP", "B-NN", "M-NN",..] | |||||
假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 | |||||
1 编者按 编者按 NN O 11 nmod:topic | |||||
2 : : PU O 11 punct | |||||
3 7月 7月 NT DATE 4 compound:nn | |||||
4 12日 12日 NT DATE 11 nmod:tmod | |||||
5 , , PU O 11 punct | |||||
1 这 这 DT O 3 det | |||||
2 款 款 M O 1 mark:clf | |||||
3 飞行 飞行 NN O 8 nsubj | |||||
4 从 从 P O 5 case | |||||
5 外型 外型 NN O 8 nmod:prep | |||||
""" | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
ds = DataSet() | |||||
for sample in datalist: | |||||
# print(sample) | |||||
res = self.get_one(sample) | |||||
if res is None: | |||||
continue | |||||
char_seq = [] | |||||
pos_seq = [] | |||||
for word, tag in zip(res[0], res[1]): | |||||
char_seq.extend(list(word)) | |||||
if len(word)==1: | |||||
pos_seq.append('S-{}'.format(tag)) | |||||
elif len(word)>1: | |||||
pos_seq.append('B-{}'.format(tag)) | |||||
for _ in range(len(word)-2): | |||||
pos_seq.append('M-{}'.format(tag)) | |||||
pos_seq.append('E-{}'.format(tag)) | |||||
else: | |||||
raise ValueError("Zero length of word detected.") | |||||
ds.append(Instance(words=char_seq, | |||||
tag=pos_seq)) | |||||
return ds | |||||
def get_one(self, sample): | |||||
if len(sample)==0: | |||||
return None | |||||
text = [] | |||||
pos_tags = [] | |||||
for w in sample: | |||||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | |||||
if t3 == '_': | |||||
return None | |||||
text.append(t1) | |||||
pos_tags.append(t2) | |||||
return text, pos_tags | |||||
if __name__ == '__main__': | |||||
reader = ZhConllPOSReader() | |||||
d = reader.load('/home/hyan/train.conllx') | |||||
print(d) |
@@ -1,113 +0,0 @@ | |||||
import argparse | |||||
import os | |||||
import pickle | |||||
import sys | |||||
import torch | |||||
# in order to run fastNLP without installation | |||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||||
from fastNLP.api.pipeline import Pipeline | |||||
from fastNLP.api.processor import SeqLenProcessor | |||||
from fastNLP.core.metrics import SpanFPreRecMetric | |||||
from fastNLP.core.trainer import Trainer | |||||
from fastNLP.io.config_io import ConfigLoader, ConfigSection | |||||
from fastNLP.models.sequence_modeling import AdvSeqLabel | |||||
from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor | |||||
from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader | |||||
from fastNLP.api.processor import ModelProcessor, Index2WordProcessor | |||||
cfgfile = './pos_tag.cfg' | |||||
pickle_path = "save" | |||||
def load_tencent_embed(embed_path, word2id): | |||||
hit = 0 | |||||
with open(embed_path, "rb") as f: | |||||
embed_dict = pickle.load(f) | |||||
embedding_tensor = torch.randn(len(word2id), 200) | |||||
for key in word2id: | |||||
if key in embed_dict: | |||||
embedding_tensor[word2id[key]] = torch.Tensor(embed_dict[key]) | |||||
hit += 1 | |||||
print("vocab_size={} hit={} hit/vocab_size={}".format(len(word2id), hit, hit / len(word2id))) | |||||
return embedding_tensor | |||||
def train(checkpoint=None): | |||||
# load config | |||||
train_param = ConfigSection() | |||||
model_param = ConfigSection() | |||||
ConfigLoader().load_config(cfgfile, {"train": train_param, "model": model_param}) | |||||
print("config loaded") | |||||
# Data Loader | |||||
dataset = ZhConllPOSReader().load("/home/hyan/train.conllx") | |||||
print(dataset) | |||||
print("dataset transformed") | |||||
dataset.rename_field("tag", "truth") | |||||
vocab_proc = VocabIndexerProcessor("words", new_added_filed_name="word_seq") | |||||
tag_proc = VocabIndexerProcessor("truth") | |||||
seq_len_proc = SeqLenProcessor(field_name="word_seq", new_added_field_name="word_seq_origin_len", is_input=True) | |||||
vocab_proc(dataset) | |||||
tag_proc(dataset) | |||||
seq_len_proc(dataset) | |||||
dataset.set_input("word_seq", "word_seq_origin_len", "truth") | |||||
dataset.set_target("truth", "word_seq_origin_len") | |||||
print("processors defined") | |||||
# dataset.set_is_target(tag_ids=True) | |||||
model_param["vocab_size"] = vocab_proc.get_vocab_size() | |||||
model_param["num_classes"] = tag_proc.get_vocab_size() | |||||
print("vocab_size={} num_classes={}".format(model_param["vocab_size"], model_param["num_classes"])) | |||||
# define a model | |||||
if checkpoint is None: | |||||
# pre_trained = load_tencent_embed("/home/zyfeng/data/char_tencent_embedding.pkl", vocab_proc.vocab.word2idx) | |||||
pre_trained = None | |||||
model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word, emb=pre_trained) | |||||
print(model) | |||||
else: | |||||
model = torch.load(checkpoint) | |||||
# call trainer to train | |||||
trainer = Trainer(dataset, model, loss=None, metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict", | |||||
target="truth", | |||||
seq_lens="word_seq_origin_len"), | |||||
dev_data=dataset, metric_key="f", | |||||
use_tqdm=True, use_cuda=True, print_every=5, n_epochs=6, save_path="./save") | |||||
trainer.train(load_best_model=True) | |||||
# save model & pipeline | |||||
model_proc = ModelProcessor(model, seq_len_field_name="word_seq_origin_len") | |||||
id2tag = Index2WordProcessor(tag_proc.vocab, "predict", "tag") | |||||
pp = Pipeline([vocab_proc, seq_len_proc, model_proc, id2tag]) | |||||
save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab} | |||||
torch.save(save_dict, "model_pp.pkl") | |||||
print("pipeline saved") | |||||
torch.save(model, "./save/best_model.pkl") | |||||
if __name__ == "__main__": | |||||
parser = argparse.ArgumentParser() | |||||
parser.add_argument("-c", "--restart", action="store_true", help="whether to continue training") | |||||
parser.add_argument("-cp", "--checkpoint", type=str, help="checkpoint of the trained model") | |||||
args = parser.parse_args() | |||||
if args.restart is True: | |||||
# 继续训练 python train_pos_tag.py -c -cp ./save/best_model.pkl | |||||
if args.checkpoint is None: | |||||
raise RuntimeError("Please provide the checkpoint. -cp ") | |||||
train(args.checkpoint) | |||||
else: | |||||
# 一次训练 python train_pos_tag.py | |||||
train() |
@@ -1,9 +1,12 @@ | |||||
import random | import random | ||||
import unittest | import unittest | ||||
from fastNLP import Vocabulary | |||||
import numpy as np | |||||
from fastNLP import Vocabulary, Instance | |||||
from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor, PreAppendProcessor, SliceProcessor, Num2TagProcessor, \ | from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor, PreAppendProcessor, SliceProcessor, Num2TagProcessor, \ | ||||
IndexerProcessor, VocabProcessor, SeqLenProcessor | |||||
IndexerProcessor, VocabProcessor, SeqLenProcessor, ModelProcessor, Index2WordProcessor, SetTargetProcessor, \ | |||||
SetInputProcessor, VocabIndexerProcessor | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
@@ -53,3 +56,46 @@ class TestProcessor(unittest.TestCase): | |||||
ds = proc(ds) | ds = proc(ds) | ||||
for data in ds.field_arrays["len"].content: | for data in ds.field_arrays["len"].content: | ||||
self.assertEqual(data, 30) | self.assertEqual(data, 30) | ||||
def test_ModelProcessor(self): | |||||
from fastNLP.models.cnn_text_classification import CNNText | |||||
model = CNNText(100, 100, 5) | |||||
ins_list = [] | |||||
for _ in range(64): | |||||
seq_len = np.random.randint(5, 30) | |||||
ins_list.append(Instance(word_seq=[np.random.randint(0, 100) for _ in range(seq_len)], seq_lens=seq_len)) | |||||
data_set = DataSet(ins_list) | |||||
data_set.set_input("word_seq", "seq_lens") | |||||
proc = ModelProcessor(model) | |||||
data_set = proc(data_set) | |||||
self.assertTrue("pred" in data_set) | |||||
def test_Index2WordProcessor(self): | |||||
vocab = Vocabulary() | |||||
vocab.add_word_lst(["a", "b", "c", "d", "e"]) | |||||
proc = Index2WordProcessor(vocab, "tag_id", "tag") | |||||
data_set = DataSet([Instance(tag_id=[np.random.randint(0, 7) for _ in range(32)])]) | |||||
data_set = proc(data_set) | |||||
self.assertTrue("tag" in data_set) | |||||
def test_SetTargetProcessor(self): | |||||
proc = SetTargetProcessor("a", "b", "c") | |||||
data_set = DataSet({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]}) | |||||
data_set = proc(data_set) | |||||
self.assertTrue(data_set["a"].is_target) | |||||
self.assertTrue(data_set["b"].is_target) | |||||
self.assertTrue(data_set["c"].is_target) | |||||
def test_SetInputProcessor(self): | |||||
proc = SetInputProcessor("a", "b", "c") | |||||
data_set = DataSet({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]}) | |||||
data_set = proc(data_set) | |||||
self.assertTrue(data_set["a"].is_input) | |||||
self.assertTrue(data_set["b"].is_input) | |||||
self.assertTrue(data_set["c"].is_input) | |||||
def test_VocabIndexerProcessor(self): | |||||
proc = VocabIndexerProcessor("word_seq", "word_ids") | |||||
data_set = DataSet([Instance(word_seq=["a", "b", "c", "d", "e"])]) | |||||
data_set = proc(data_set) | |||||
self.assertTrue("word_ids" in data_set) |
@@ -1,13 +1,44 @@ | |||||
import time | |||||
import unittest | import unittest | ||||
import numpy as np | import numpy as np | ||||
import torch | |||||
from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.dataset import construct_dataset | from fastNLP.core.dataset import construct_dataset | ||||
from fastNLP.core.instance import Instance | |||||
from fastNLP.core.sampler import SequentialSampler | from fastNLP.core.sampler import SequentialSampler | ||||
def generate_fake_dataset(num_samples=1000): | |||||
""" | |||||
产生的DataSet包含以下的field {'1':[], '2':[], '3': [], '4':[]} | |||||
:param num_samples: sample的数量 | |||||
:return: | |||||
""" | |||||
max_len = 50 | |||||
min_len = 10 | |||||
num_features = 4 | |||||
data_dict = {} | |||||
for i in range(num_features): | |||||
data = [] | |||||
lengths = np.random.randint(min_len, max_len, size=(num_samples)) | |||||
for length in lengths: | |||||
data.append(np.random.randint(100, size=length)) | |||||
data_dict[str(i)] = data | |||||
dataset = DataSet(data_dict) | |||||
for i in range(num_features): | |||||
if np.random.randint(2) == 0: | |||||
dataset.set_input(str(i)) | |||||
else: | |||||
dataset.set_target(str(i)) | |||||
return dataset | |||||
class TestCase1(unittest.TestCase): | class TestCase1(unittest.TestCase): | ||||
def test_simple(self): | def test_simple(self): | ||||
dataset = construct_dataset( | dataset = construct_dataset( | ||||
@@ -31,3 +62,116 @@ class TestCase1(unittest.TestCase): | |||||
self.assertEqual(len(y["y"]), 4) | self.assertEqual(len(y["y"]), 4) | ||||
self.assertListEqual(list(x["x"][-1]), [1, 2, 3, 4]) | self.assertListEqual(list(x["x"][-1]), [1, 2, 3, 4]) | ||||
self.assertListEqual(list(y["y"][-1]), [5, 6]) | self.assertListEqual(list(y["y"][-1]), [5, 6]) | ||||
def test_list_padding(self): | |||||
ds = DataSet({"x": [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10, | |||||
"y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10}) | |||||
ds.set_input("x") | |||||
ds.set_target("y") | |||||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||||
for x, y in iter: | |||||
self.assertEqual(x["x"].shape, (4, 4)) | |||||
self.assertEqual(y["y"].shape, (4, 4)) | |||||
def test_numpy_padding(self): | |||||
ds = DataSet({"x": np.array([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10), | |||||
"y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)}) | |||||
ds.set_input("x") | |||||
ds.set_target("y") | |||||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||||
for x, y in iter: | |||||
self.assertEqual(x["x"].shape, (4, 4)) | |||||
self.assertEqual(y["y"].shape, (4, 4)) | |||||
def test_list_to_tensor(self): | |||||
ds = DataSet({"x": [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10, | |||||
"y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10}) | |||||
ds.set_input("x") | |||||
ds.set_target("y") | |||||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||||
for x, y in iter: | |||||
self.assertTrue(isinstance(x["x"], torch.Tensor)) | |||||
self.assertEqual(tuple(x["x"].shape), (4, 4)) | |||||
self.assertTrue(isinstance(y["y"], torch.Tensor)) | |||||
self.assertEqual(tuple(y["y"].shape), (4, 4)) | |||||
def test_numpy_to_tensor(self): | |||||
ds = DataSet({"x": np.array([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10), | |||||
"y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)}) | |||||
ds.set_input("x") | |||||
ds.set_target("y") | |||||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||||
for x, y in iter: | |||||
self.assertTrue(isinstance(x["x"], torch.Tensor)) | |||||
self.assertEqual(tuple(x["x"].shape), (4, 4)) | |||||
self.assertTrue(isinstance(y["y"], torch.Tensor)) | |||||
self.assertEqual(tuple(y["y"].shape), (4, 4)) | |||||
def test_list_of_list_to_tensor(self): | |||||
ds = DataSet([Instance(x=[1, 2], y=[3, 4]) for _ in range(2)] + | |||||
[Instance(x=[1, 2, 3, 4], y=[3, 4, 5, 6]) for _ in range(2)]) | |||||
ds.set_input("x") | |||||
ds.set_target("y") | |||||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||||
for x, y in iter: | |||||
self.assertTrue(isinstance(x["x"], torch.Tensor)) | |||||
self.assertEqual(tuple(x["x"].shape), (4, 4)) | |||||
self.assertTrue(isinstance(y["y"], torch.Tensor)) | |||||
self.assertEqual(tuple(y["y"].shape), (4, 4)) | |||||
def test_list_of_numpy_to_tensor(self): | |||||
ds = DataSet([Instance(x=np.array([1, 2]), y=np.array([3, 4])) for _ in range(2)] + | |||||
[Instance(x=np.array([1, 2, 3, 4]), y=np.array([3, 4, 5, 6])) for _ in range(2)]) | |||||
ds.set_input("x") | |||||
ds.set_target("y") | |||||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||||
for x, y in iter: | |||||
print(x, y) | |||||
def test_sequential_batch(self): | |||||
batch_size = 32 | |||||
pause_seconds = 0.01 | |||||
num_samples = 1000 | |||||
dataset = generate_fake_dataset(num_samples) | |||||
batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler()) | |||||
for batch_x, batch_y in batch: | |||||
time.sleep(pause_seconds) | |||||
""" | |||||
def test_multi_workers_batch(self): | |||||
batch_size = 32 | |||||
pause_seconds = 0.01 | |||||
num_samples = 1000 | |||||
dataset = generate_fake_dataset(num_samples) | |||||
num_workers = 1 | |||||
batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), num_workers=num_workers) | |||||
for batch_x, batch_y in batch: | |||||
time.sleep(pause_seconds) | |||||
num_workers = 2 | |||||
batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), num_workers=num_workers) | |||||
end1 = time.time() | |||||
for batch_x, batch_y in batch: | |||||
time.sleep(pause_seconds) | |||||
""" | |||||
""" | |||||
def test_pin_memory(self): | |||||
batch_size = 32 | |||||
pause_seconds = 0.01 | |||||
num_samples = 1000 | |||||
dataset = generate_fake_dataset(num_samples) | |||||
batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), pin_memory=True) | |||||
# 这里发生OOM | |||||
# for batch_x, batch_y in batch: | |||||
# time.sleep(pause_seconds) | |||||
num_workers = 2 | |||||
batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), num_workers=num_workers, | |||||
pin_memory=True) | |||||
# 这里发生OOM | |||||
# for batch_x, batch_y in batch: | |||||
# time.sleep(pause_seconds) | |||||
""" |
@@ -1,40 +1,47 @@ | |||||
import unittest | import unittest | ||||
import numpy as np | import numpy as np | ||||
import torch | |||||
from fastNLP.core.callback import EchoCallback | |||||
from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \ | |||||
LRFinder, \ | |||||
TensorboardCallback | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.core.losses import BCELoss | from fastNLP.core.losses import BCELoss | ||||
from fastNLP.core.metrics import AccuracyMetric | |||||
from fastNLP.core.optimizer import SGD | from fastNLP.core.optimizer import SGD | ||||
from fastNLP.core.trainer import Trainer | from fastNLP.core.trainer import Trainer | ||||
from fastNLP.models.base_model import NaiveClassifier | from fastNLP.models.base_model import NaiveClassifier | ||||
class TestCallback(unittest.TestCase): | |||||
def test_case(self): | |||||
def prepare_fake_dataset(): | |||||
mean = np.array([-3, -3]) | |||||
cov = np.array([[1, 0], [0, 1]]) | |||||
class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||||
def prepare_env(): | |||||
def prepare_fake_dataset(): | |||||
mean = np.array([-3, -3]) | |||||
cov = np.array([[1, 0], [0, 1]]) | |||||
class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||||
mean = np.array([3, 3]) | |||||
cov = np.array([[1, 0], [0, 1]]) | |||||
class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||||
mean = np.array([3, 3]) | |||||
cov = np.array([[1, 0], [0, 1]]) | |||||
class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||||
data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | |||||
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | |||||
return data_set | |||||
data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | |||||
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | |||||
return data_set | |||||
data_set = prepare_fake_dataset() | |||||
data_set.set_input("x") | |||||
data_set.set_target("y") | |||||
data_set = prepare_fake_dataset() | |||||
data_set.set_input("x") | |||||
data_set.set_target("y") | |||||
model = NaiveClassifier(2, 1) | |||||
return data_set, model | |||||
model = NaiveClassifier(2, 1) | |||||
class TestCallback(unittest.TestCase): | |||||
def test_echo_callback(self): | |||||
data_set, model = prepare_env() | |||||
trainer = Trainer(data_set, model, | trainer = Trainer(data_set, model, | ||||
loss=BCELoss(pred="predict", target="y"), | loss=BCELoss(pred="predict", target="y"), | ||||
n_epochs=1, | |||||
n_epochs=2, | |||||
batch_size=32, | batch_size=32, | ||||
print_every=50, | print_every=50, | ||||
optimizer=SGD(lr=0.1), | optimizer=SGD(lr=0.1), | ||||
@@ -42,3 +49,90 @@ class TestCallback(unittest.TestCase): | |||||
use_tqdm=False, | use_tqdm=False, | ||||
callbacks=[EchoCallback()]) | callbacks=[EchoCallback()]) | ||||
trainer.train() | trainer.train() | ||||
def test_gradient_clip(self): | |||||
data_set, model = prepare_env() | |||||
trainer = Trainer(data_set, model, | |||||
loss=BCELoss(pred="predict", target="y"), | |||||
n_epochs=20, | |||||
batch_size=32, | |||||
print_every=50, | |||||
optimizer=SGD(lr=0.1), | |||||
check_code_level=2, | |||||
use_tqdm=False, | |||||
dev_data=data_set, | |||||
metrics=AccuracyMetric(pred="predict", target="y"), | |||||
callbacks=[GradientClipCallback(model.parameters(), clip_value=2)]) | |||||
trainer.train() | |||||
def test_early_stop(self): | |||||
data_set, model = prepare_env() | |||||
trainer = Trainer(data_set, model, | |||||
loss=BCELoss(pred="predict", target="y"), | |||||
n_epochs=20, | |||||
batch_size=32, | |||||
print_every=50, | |||||
optimizer=SGD(lr=0.01), | |||||
check_code_level=2, | |||||
use_tqdm=False, | |||||
dev_data=data_set, | |||||
metrics=AccuracyMetric(pred="predict", target="y"), | |||||
callbacks=[EarlyStopCallback(5)]) | |||||
trainer.train() | |||||
def test_lr_scheduler(self): | |||||
data_set, model = prepare_env() | |||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) | |||||
trainer = Trainer(data_set, model, | |||||
loss=BCELoss(pred="predict", target="y"), | |||||
n_epochs=5, | |||||
batch_size=32, | |||||
print_every=50, | |||||
optimizer=optimizer, | |||||
check_code_level=2, | |||||
use_tqdm=False, | |||||
dev_data=data_set, | |||||
metrics=AccuracyMetric(pred="predict", target="y"), | |||||
callbacks=[LRScheduler(torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1))]) | |||||
trainer.train() | |||||
def test_KeyBoardInterrupt(self): | |||||
data_set, model = prepare_env() | |||||
trainer = Trainer(data_set, model, | |||||
loss=BCELoss(pred="predict", target="y"), | |||||
n_epochs=5, | |||||
batch_size=32, | |||||
print_every=50, | |||||
optimizer=SGD(lr=0.1), | |||||
check_code_level=2, | |||||
use_tqdm=False, | |||||
callbacks=[ControlC(False)]) | |||||
trainer.train() | |||||
def test_LRFinder(self): | |||||
data_set, model = prepare_env() | |||||
trainer = Trainer(data_set, model, | |||||
loss=BCELoss(pred="predict", target="y"), | |||||
n_epochs=5, | |||||
batch_size=32, | |||||
print_every=50, | |||||
optimizer=SGD(lr=0.1), | |||||
check_code_level=2, | |||||
use_tqdm=False, | |||||
callbacks=[LRFinder(len(data_set) // 32)]) | |||||
trainer.train() | |||||
def test_TensorboardCallback(self): | |||||
data_set, model = prepare_env() | |||||
trainer = Trainer(data_set, model, | |||||
loss=BCELoss(pred="predict", target="y"), | |||||
n_epochs=5, | |||||
batch_size=32, | |||||
print_every=50, | |||||
optimizer=SGD(lr=0.1), | |||||
check_code_level=2, | |||||
use_tqdm=False, | |||||
dev_data=data_set, | |||||
metrics=AccuracyMetric(pred="predict", target="y"), | |||||
callbacks=[TensorboardCallback("loss", "metric")]) | |||||
trainer.train() |
@@ -6,15 +6,29 @@ from fastNLP.core.fieldarray import FieldArray | |||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
class TestDataSet(unittest.TestCase): | |||||
class TestDataSetInit(unittest.TestCase): | |||||
"""初始化DataSet的办法有以下几种: | |||||
1) 用dict: | |||||
1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) | |||||
1.2) 二维array DataSet({"x": np.array([[1, 2], [3, 4]])}) | |||||
1.3) 三维list DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]}) | |||||
2) 用list of Instance: | |||||
2.1) 一维list DataSet([Instance(x=[1, 2, 3, 4])]) | |||||
2.2) 一维array DataSet([Instance(x=np.array([1, 2, 3, 4]))]) | |||||
2.3) 二维list DataSet([Instance(x=[[1, 2], [3, 4]])]) | |||||
2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))]) | |||||
只接受纯list或者最外层ndarray | |||||
""" | |||||
def test_init_v1(self): | def test_init_v1(self): | ||||
# 一维list | |||||
ds = DataSet([Instance(x=[1, 2, 3, 4], y=[5, 6])] * 40) | ds = DataSet([Instance(x=[1, 2, 3, 4], y=[5, 6])] * 40) | ||||
self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays) | self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays) | ||||
self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40) | self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40) | ||||
self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40) | self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40) | ||||
def test_init_v2(self): | def test_init_v2(self): | ||||
# 用dict | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | ||||
self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays) | self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays) | ||||
self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40) | self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40) | ||||
@@ -28,6 +42,8 @@ class TestDataSet(unittest.TestCase): | |||||
with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
_ = DataSet(0.00001) | _ = DataSet(0.00001) | ||||
class TestDataSetMethods(unittest.TestCase): | |||||
def test_append(self): | def test_append(self): | ||||
dd = DataSet() | dd = DataSet() | ||||
for _ in range(3): | for _ in range(3): | ||||
@@ -5,8 +5,65 @@ import numpy as np | |||||
from fastNLP.core.fieldarray import FieldArray | from fastNLP.core.fieldarray import FieldArray | ||||
class TestFieldArrayInit(unittest.TestCase): | |||||
""" | |||||
1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray: | |||||
1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) | |||||
1.2) 二维array DataSet({"x": np.array([[1, 2], [3, 4]])}) | |||||
1.3) 三维list DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]}) | |||||
2) 如果DataSet使用list of Instance 初始化,那么在append中会先对第一个样本初始化FieldArray; | |||||
然后后面的样本使用FieldArray.append进行添加。 | |||||
2.1) 一维list DataSet([Instance(x=[1, 2, 3, 4])]) | |||||
2.2) 一维array DataSet([Instance(x=np.array([1, 2, 3, 4]))]) | |||||
2.3) 二维list DataSet([Instance(x=[[1, 2], [3, 4]])]) | |||||
2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))]) | |||||
""" | |||||
def test_init_v1(self): | |||||
# 二维list | |||||
fa = FieldArray("x", [[1, 2], [3, 4]] * 5, is_input=True) | |||||
def test_init_v2(self): | |||||
# 二维array | |||||
fa = FieldArray("x", np.array([[1, 2], [3, 4]] * 5), is_input=True) | |||||
def test_init_v3(self): | |||||
# 三维list | |||||
fa = FieldArray("x", [[[1, 2], [3, 4]], [[1, 2], [3, 4]]], is_input=True) | |||||
def test_init_v7(self): | |||||
# list of array | |||||
fa = FieldArray("x", [np.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])], is_input=True) | |||||
self.assertEqual(fa.pytype, int) | |||||
self.assertEqual(fa.dtype, np.int) | |||||
def test_init_v4(self): | |||||
# 一维list | |||||
val = [1, 2, 3, 4] | |||||
fa = FieldArray("x", [val], is_input=True) | |||||
fa.append(val) | |||||
def test_init_v5(self): | |||||
# 一维array | |||||
val = np.array([1, 2, 3, 4]) | |||||
fa = FieldArray("x", [val], is_input=True) | |||||
fa.append(val) | |||||
def test_init_v6(self): | |||||
# 二维array | |||||
val = [[1, 2], [3, 4]] | |||||
fa = FieldArray("x", [val], is_input=True) | |||||
fa.append(val) | |||||
def test_init_v7(self): | |||||
# 二维list | |||||
val = np.array([[1, 2], [3, 4]]) | |||||
fa = FieldArray("x", [val], is_input=True) | |||||
fa.append(val) | |||||
class TestFieldArray(unittest.TestCase): | class TestFieldArray(unittest.TestCase): | ||||
def test(self): | |||||
def test_main(self): | |||||
fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True) | fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True) | ||||
self.assertEqual(len(fa), 5) | self.assertEqual(len(fa), 5) | ||||
fa.append(6) | fa.append(6) | ||||
@@ -42,13 +99,13 @@ class TestFieldArray(unittest.TestCase): | |||||
self.assertEqual(fa.pytype, str) | self.assertEqual(fa.pytype, str) | ||||
def test_support_np_array(self): | def test_support_np_array(self): | ||||
fa = FieldArray("y", [np.array([1.1, 2.2, 3.3, 4.4, 5.5])], is_input=True) | |||||
self.assertEqual(fa.dtype, np.ndarray) | |||||
self.assertEqual(fa.pytype, np.ndarray) | |||||
fa = FieldArray("y", np.array([[1.1, 2.2, 3.3, 4.4, 5.5]]), is_input=True) | |||||
self.assertEqual(fa.dtype, np.float64) | |||||
self.assertEqual(fa.pytype, float) | |||||
fa.append(np.array([1.1, 2.2, 3.3, 4.4, 5.5])) | fa.append(np.array([1.1, 2.2, 3.3, 4.4, 5.5])) | ||||
self.assertEqual(fa.dtype, np.ndarray) | |||||
self.assertEqual(fa.pytype, np.ndarray) | |||||
self.assertEqual(fa.dtype, np.float64) | |||||
self.assertEqual(fa.pytype, float) | |||||
fa = FieldArray("my_field", np.random.rand(3, 5), is_input=True) | fa = FieldArray("my_field", np.random.rand(3, 5), is_input=True) | ||||
# in this case, pytype is actually a float. We do not care about it. | # in this case, pytype is actually a float. We do not care about it. | ||||
@@ -97,3 +154,65 @@ class TestFieldArray(unittest.TestCase): | |||||
fa.append([1.2, 2.3, 3.4, 4.5, 5.6]) | fa.append([1.2, 2.3, 3.4, 4.5, 5.6]) | ||||
self.assertEqual(len(fa), 3) | self.assertEqual(len(fa), 3) | ||||
self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6]) | self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6]) | ||||
class TestPadder(unittest.TestCase): | |||||
def test01(self): | |||||
""" | |||||
测试AutoPadder能否正常工作 | |||||
:return: | |||||
""" | |||||
from fastNLP.core.fieldarray import AutoPadder | |||||
padder = AutoPadder() | |||||
content = ['This is a str', 'this is another str'] | |||||
self.assertListEqual(content, padder(content, None, np.str).tolist()) | |||||
content = [1, 2] | |||||
self.assertListEqual(content, padder(content, None, np.int64).tolist()) | |||||
content = [[1,2], [3], [4]] | |||||
self.assertListEqual([[1,2], [3, 0], [4, 0]], | |||||
padder(content, None, np.int64).tolist()) | |||||
content = [ | |||||
[[1, 2, 3], [4, 5], [7,8,9,10]], | |||||
[[1]] | |||||
] | |||||
self.assertListEqual(content, | |||||
padder(content, None, np.int64).tolist()) | |||||
def test02(self): | |||||
""" | |||||
测试EngChar2DPadder能不能正确使用 | |||||
:return: | |||||
""" | |||||
from fastNLP.core.fieldarray import EngChar2DPadder | |||||
padder = EngChar2DPadder(pad_length=0) | |||||
contents = [1, 2] | |||||
# 不能是1维 | |||||
with self.assertRaises(ValueError): | |||||
padder(contents, None, np.int64) | |||||
contents = [[1, 2]] | |||||
# 不能是2维 | |||||
with self.assertRaises(ValueError): | |||||
padder(contents, None, np.int64) | |||||
contents = [[[[1, 2]]]] | |||||
# 不能是3维以上 | |||||
with self.assertRaises(ValueError): | |||||
padder(contents, None, np.int64) | |||||
contents = [ | |||||
[[1, 2, 3], [4, 5], [7,8,9,10]], | |||||
[[1]] | |||||
] | |||||
self.assertListEqual([[[1, 2, 3, 0], [4, 5, 0, 0], [7, 8, 9, 10]], [[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]], | |||||
padder(contents, None, np.int64).tolist()) | |||||
padder = EngChar2DPadder(pad_length=5, pad_val=-100) | |||||
self.assertListEqual( | |||||
[[[1, 2, 3, -100, -100], [4, 5, -100, -100, -100], [7, 8, 9, 10, -100]], | |||||
[[1, -100, -100, -100, -100], [-100, -100, -100, -100, -100], [-100, -100, -100, -100, -100]]], | |||||
padder(contents, None, np.int64).tolist() | |||||
) |
@@ -1,4 +1,5 @@ | |||||
import unittest | import unittest | ||||
from collections import defaultdict | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
@@ -23,12 +24,26 @@ def prepare_fake_dataset(): | |||||
return data_set | return data_set | ||||
class LinearModel(torch.nn.Module): | |||||
def __init__(self): | |||||
super(LinearModel, self).__init__() | |||||
self.linear = Linear(2, 1) | |||||
def forward(self, x): | |||||
return {"predict": self.linear(x)} | |||||
class TestPredictor(unittest.TestCase): | class TestPredictor(unittest.TestCase): | ||||
def test(self): | |||||
predictor = Predictor() | |||||
model = Linear(2, 1) | |||||
def test_simple(self): | |||||
model = LinearModel() | |||||
predictor = Predictor(model) | |||||
data = prepare_fake_dataset() | data = prepare_fake_dataset() | ||||
data.set_input("x") | data.set_input("x") | ||||
ans = predictor.predict(model, data) | |||||
self.assertEqual(len(ans), 2000) | |||||
self.assertTrue(isinstance(ans[0], torch.Tensor)) | |||||
ans = predictor.predict(data) | |||||
self.assertTrue(isinstance(ans, defaultdict)) | |||||
self.assertTrue("predict" in ans) | |||||
self.assertTrue(isinstance(ans["predict"], list)) | |||||
def test_sequence(self): | |||||
# test sequence input/output | |||||
pass |
@@ -237,6 +237,32 @@ class TrainerTestGround(unittest.TestCase): | |||||
use_tqdm=False, | use_tqdm=False, | ||||
print_every=2) | print_every=2) | ||||
def test_case2(self): | |||||
# check metrics Wrong | |||||
data_set = prepare_fake_dataset2('x1', 'x2') | |||||
""" | |||||
def test_trainer_multiprocess(self): | |||||
dataset = prepare_fake_dataset2('x1', 'x2') | |||||
dataset.set_input('x1', 'x2', 'y', flag=True) | |||||
class Model(nn.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.fc = nn.Linear(5, 4) | |||||
def forward(self, x1, x2, y): | |||||
x1 = self.fc(x1) | |||||
x2 = self.fc(x2) | |||||
x = x1 + x2 | |||||
loss = F.cross_entropy(x, y) | |||||
return {'loss': loss} | |||||
model = Model() | |||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model, | |||||
use_tqdm=True, | |||||
print_every=2, | |||||
num_workers=2, | |||||
pin_memory=False, | |||||
timeout=0, | |||||
) | |||||
trainer.train() | |||||
""" |
@@ -1,2 +0,0 @@ | |||||
迈向充满希望的新世纪——一九九八年新年讲话 | |||||
(附图片1张) |
@@ -0,0 +1,100 @@ | |||||
1 上海 _ NR NR _ 3 nsubj _ _ | |||||
2 积极 _ AD AD _ 3 advmod _ _ | |||||
3 准备 _ VV VV _ 0 root _ _ | |||||
4 迎接 _ VV VV _ 3 ccomp _ _ | |||||
5 欧元 _ NN NN _ 6 nn _ _ | |||||
6 诞生 _ NN NN _ 4 dobj _ _ | |||||
1 新华社 _ NR NR _ 7 dep _ _ | |||||
2 上海 _ NR NR _ 7 dep _ _ | |||||
3 十二月 _ NT NT _ 7 dep _ _ | |||||
4 三十日 _ NT NT _ 7 dep _ _ | |||||
5 电 _ NN NN _ 7 dep _ _ | |||||
6 ( _ PU PU _ 7 punct _ _ | |||||
7 记者 _ NN NN _ 0 root _ _ | |||||
8 潘清 _ NR NR _ 7 dep _ _ | |||||
9 ) _ PU PU _ 7 punct _ _ | |||||
1 即将 _ AD AD _ 2 advmod _ _ | |||||
2 诞生 _ VV VV _ 4 rcmod _ _ | |||||
3 的 _ DEC DEC _ 2 cpm _ _ | |||||
4 欧元 _ NN NN _ 6 nsubj _ _ | |||||
5 , _ PU PU _ 6 punct _ _ | |||||
6 引起 _ VV VV _ 0 root _ _ | |||||
7 了 _ AS AS _ 6 asp _ _ | |||||
8 上海 _ NR NR _ 14 nn _ _ | |||||
9 这 _ DT DT _ 14 det _ _ | |||||
10 个 _ M M _ 9 clf _ _ | |||||
11 中国 _ NR NR _ 13 nn _ _ | |||||
12 金融 _ NN NN _ 13 nn _ _ | |||||
13 中心 _ NN NN _ 14 nn _ _ | |||||
14 城市 _ NN NN _ 16 assmod _ _ | |||||
15 的 _ DEG DEG _ 14 assm _ _ | |||||
16 关注 _ NN NN _ 6 dobj _ _ | |||||
17 。 _ PU PU _ 6 punct _ _ | |||||
1 上海 _ NR NR _ 2 nn _ _ | |||||
2 银行界 _ NN NN _ 4 nsubj _ _ | |||||
3 纷纷 _ AD AD _ 4 advmod _ _ | |||||
4 推出 _ VV VV _ 0 root _ _ | |||||
5 了 _ AS AS _ 4 asp _ _ | |||||
6 与 _ P P _ 8 prep _ _ | |||||
7 之 _ PN PN _ 6 pobj _ _ | |||||
8 相关 _ VA VA _ 15 rcmod _ _ | |||||
9 的 _ DEC DEC _ 8 cpm _ _ | |||||
10 外汇 _ NN NN _ 15 nn _ _ | |||||
11 业务 _ NN NN _ 15 nn _ _ | |||||
12 品种 _ NN NN _ 15 conj _ _ | |||||
13 和 _ CC CC _ 15 cc _ _ | |||||
14 服务 _ NN NN _ 15 nn _ _ | |||||
15 举措 _ NN NN _ 4 dobj _ _ | |||||
16 , _ PU PU _ 4 punct _ _ | |||||
17 积极 _ AD AD _ 18 advmod _ _ | |||||
18 准备 _ VV VV _ 4 dep _ _ | |||||
19 启动 _ VV VV _ 18 ccomp _ _ | |||||
20 欧元 _ NN NN _ 21 nn _ _ | |||||
21 业务 _ NN NN _ 19 dobj _ _ | |||||
22 。 _ PU PU _ 4 punct _ _ | |||||
1 一些 _ CD CD _ 8 nummod _ _ | |||||
2 热衷于 _ VV VV _ 8 rcmod _ _ | |||||
3 个人 _ NN NN _ 5 nn _ _ | |||||
4 外汇 _ NN NN _ 5 nn _ _ | |||||
5 交易 _ NN NN _ 2 dobj _ _ | |||||
6 的 _ DEC DEC _ 2 cpm _ _ | |||||
7 上海 _ NR NR _ 8 nn _ _ | |||||
8 市民 _ NN NN _ 13 nsubj _ _ | |||||
9 , _ PU PU _ 13 punct _ _ | |||||
10 也 _ AD AD _ 13 advmod _ _ | |||||
11 对 _ P P _ 13 prep _ _ | |||||
12 欧元 _ NN NN _ 11 pobj _ _ | |||||
13 表示 _ VV VV _ 0 root _ _ | |||||
14 出 _ VV VV _ 13 rcomp _ _ | |||||
15 极 _ AD AD _ 16 advmod _ _ | |||||
16 大 _ VA VA _ 18 rcmod _ _ | |||||
17 的 _ DEC DEC _ 16 cpm _ _ | |||||
18 兴趣 _ NN NN _ 13 dobj _ _ | |||||
19 。 _ PU PU _ 13 punct _ _ | |||||
1 继 _ P P _ 38 prep _ _ | |||||
2 上海 _ NR NR _ 6 nn _ _ | |||||
3 大众 _ NR NR _ 6 nn _ _ | |||||
4 汽车 _ NN NN _ 6 nn _ _ | |||||
5 有限 _ JJ JJ _ 6 amod _ _ | |||||
6 公司 _ NN NN _ 13 nsubj _ _ | |||||
7 十八日 _ NT NT _ 13 tmod _ _ | |||||
8 在 _ P P _ 13 prep _ _ | |||||
9 中国 _ NR NR _ 10 nn _ _ | |||||
10 银行 _ NN NN _ 12 nn _ _ | |||||
11 上海 _ NR NR _ 12 nn _ _ | |||||
12 分行 _ NN NN _ 8 pobj _ _ | |||||
13 开立 _ VV VV _ 19 lccomp _ _ | |||||
14 上海 _ NR NR _ 16 dep _ _ | |||||
15 第一 _ OD OD _ 16 ordmod _ _ | |||||
16 个 _ M M _ 18 clf _ _ | |||||
17 欧元 _ NN NN _ 18 nn _ _ | |||||
18 帐户 _ NN NN _ 13 dobj _ _ | |||||
19 后 _ LC LC _ 1 plmod _ _ | |||||
20 , _ PU PU _ 38 punct _ _ | |||||
21 工商 _ NN NN _ 28 nn _ _ | |||||
22 银行 _ NN NN _ 28 conj _ _ |
@@ -1,24 +1,27 @@ | |||||
import unittest | import unittest | ||||
from fastNLP.io.dataset_loader import Conll2003Loader | |||||
from fastNLP.io.dataset_loader import Conll2003Loader, PeopleDailyCorpusLoader, ConllCWSReader, \ | |||||
ZhConllPOSReader, ConllxDataLoader | |||||
class TestDatasetLoader(unittest.TestCase): | class TestDatasetLoader(unittest.TestCase): | ||||
def test_case_1(self): | |||||
''' | |||||
def test_Conll2003Loader(self): | |||||
""" | |||||
Test the the loader of Conll2003 dataset | Test the the loader of Conll2003 dataset | ||||
''' | |||||
""" | |||||
dataset_path = "test/data_for_tests/conll_2003_example.txt" | dataset_path = "test/data_for_tests/conll_2003_example.txt" | ||||
loader = Conll2003Loader() | loader = Conll2003Loader() | ||||
dataset_2003 = loader.load(dataset_path) | dataset_2003 = loader.load(dataset_path) | ||||
for item in dataset_2003: | |||||
len0 = len(item["label0_list"]) | |||||
len1 = len(item["label1_list"]) | |||||
len2 = len(item["label2_list"]) | |||||
lentoken = len(item["token_list"]) | |||||
self.assertNotEqual(len0, 0) | |||||
self.assertEqual(len0, len1) | |||||
self.assertEqual(len1, len2) | |||||
def test_PeopleDailyCorpusLoader(self): | |||||
data_set = PeopleDailyCorpusLoader().load("test/data_for_tests/people_daily_raw.txt") | |||||
def test_ConllCWSReader(self): | |||||
dataset = ConllCWSReader().load("test/data_for_tests/conll_example.txt") | |||||
def test_ZhConllPOSReader(self): | |||||
dataset = ZhConllPOSReader().load("test/data_for_tests/zh_sample.conllx") | |||||
def test_ConllxDataLoader(self): | |||||
dataset = ConllxDataLoader().load("test/data_for_tests/zh_sample.conllx") |
@@ -0,0 +1,21 @@ | |||||
import unittest | |||||
import torch | |||||
from fastNLP.models.bert import BertModel | |||||
class TestBert(unittest.TestCase): | |||||
def test_bert_1(self): | |||||
# model = BertModel.from_pretrained("/home/zyfeng/data/bert-base-chinese") | |||||
model = BertModel(vocab_size=32000, hidden_size=768, | |||||
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) | |||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | |||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | |||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) | |||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) | |||||
for layer in all_encoder_layers: | |||||
self.assertEqual(tuple(layer.shape), (2, 3, 768)) | |||||
self.assertEqual(tuple(pooled_output.shape), (2, 768)) |
@@ -1,8 +1,8 @@ | |||||
from fastNLP.models.biaffine_parser import BiaffineParser, ParserLoss, ParserMetric | |||||
import fastNLP | |||||
import unittest | import unittest | ||||
import fastNLP | |||||
from fastNLP.models.biaffine_parser import BiaffineParser, ParserLoss, ParserMetric | |||||
data_file = """ | data_file = """ | ||||
1 The _ DET DT _ 3 det _ _ | 1 The _ DET DT _ 3 det _ _ | ||||
2 new _ ADJ JJ _ 3 amod _ _ | 2 new _ ADJ JJ _ 3 amod _ _ | ||||
@@ -41,6 +41,7 @@ data_file = """ | |||||
""" | """ | ||||
def init_data(): | def init_data(): | ||||
ds = fastNLP.DataSet() | ds = fastNLP.DataSet() | ||||
v = {'word_seq': fastNLP.Vocabulary(), | v = {'word_seq': fastNLP.Vocabulary(), | ||||
@@ -60,28 +61,31 @@ def init_data(): | |||||
data.append(line) | data.append(line) | ||||
for name in ['word_seq', 'pos_seq', 'label_true']: | for name in ['word_seq', 'pos_seq', 'label_true']: | ||||
ds.apply(lambda x: ['<st>']+list(x[name]), new_field_name=name) | |||||
ds.apply(lambda x: ['<st>'] + list(x[name]), new_field_name=name) | |||||
ds.apply(lambda x: v[name].add_word_lst(x[name])) | ds.apply(lambda x: v[name].add_word_lst(x[name])) | ||||
for name in ['word_seq', 'pos_seq', 'label_true']: | for name in ['word_seq', 'pos_seq', 'label_true']: | ||||
ds.apply(lambda x: [v[name].to_index(w) for w in x[name]], new_field_name=name) | ds.apply(lambda x: [v[name].to_index(w) for w in x[name]], new_field_name=name) | ||||
ds.apply(lambda x: [0]+list(map(int, x['arc_true'])), new_field_name='arc_true') | |||||
ds.apply(lambda x: [0] + list(map(int, x['arc_true'])), new_field_name='arc_true') | |||||
ds.apply(lambda x: len(x['word_seq']), new_field_name='seq_lens') | ds.apply(lambda x: len(x['word_seq']), new_field_name='seq_lens') | ||||
ds.set_input('word_seq', 'pos_seq', 'seq_lens', flag=True) | ds.set_input('word_seq', 'pos_seq', 'seq_lens', flag=True) | ||||
ds.set_target('arc_true', 'label_true', 'seq_lens', flag=True) | ds.set_target('arc_true', 'label_true', 'seq_lens', flag=True) | ||||
return ds, v['word_seq'], v['pos_seq'], v['label_true'] | return ds, v['word_seq'], v['pos_seq'], v['label_true'] | ||||
class TestBiaffineParser(unittest.TestCase): | class TestBiaffineParser(unittest.TestCase): | ||||
def test_train(self): | def test_train(self): | ||||
ds, v1, v2, v3 = init_data() | ds, v1, v2, v3 = init_data() | ||||
model = BiaffineParser(word_vocab_size=len(v1), word_emb_dim=30, | model = BiaffineParser(word_vocab_size=len(v1), word_emb_dim=30, | ||||
pos_vocab_size=len(v2), pos_emb_dim=30, | pos_vocab_size=len(v2), pos_emb_dim=30, | ||||
num_label=len(v3), use_var_lstm=True) | |||||
num_label=len(v3), encoder='var-lstm') | |||||
trainer = fastNLP.Trainer(model=model, train_data=ds, dev_data=ds, | trainer = fastNLP.Trainer(model=model, train_data=ds, dev_data=ds, | ||||
loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS', | loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS', | ||||
batch_size=1, validate_every=10, | |||||
n_epochs=10, use_cuda=False, use_tqdm=False) | n_epochs=10, use_cuda=False, use_tqdm=False) | ||||
trainer.train(load_best_model=False) | trainer.train(load_best_model=False) | ||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
unittest.main() | |||||
unittest.main() |
@@ -1,91 +0,0 @@ | |||||
import unittest | |||||
from fastNLP import DataSet | |||||
from fastNLP import Instance | |||||
from fastNLP import Tester | |||||
from fastNLP import Vocabulary | |||||
from fastNLP.core.losses import CrossEntropyLoss | |||||
from fastNLP.core.metrics import AccuracyMetric | |||||
from fastNLP.models import CNNText | |||||
class TestTutorial(unittest.TestCase): | |||||
def test_tutorial(self): | |||||
# 从csv读取数据到DataSet | |||||
sample_path = "test/data_for_tests/tutorial_sample_dataset.csv" | |||||
dataset = DataSet.read_csv(sample_path, headers=('raw_sentence', 'label'), | |||||
sep='\t') | |||||
print(len(dataset)) | |||||
print(dataset[0]) | |||||
dataset.append(Instance(raw_sentence='fake data', label='0')) | |||||
dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence') | |||||
# label转int | |||||
dataset.apply(lambda x: int(x['label']), new_field_name='label') | |||||
# 使用空格分割句子 | |||||
def split_sent(ins): | |||||
return ins['raw_sentence'].split() | |||||
dataset.apply(split_sent, new_field_name='words') | |||||
# 增加长度信息 | |||||
dataset.apply(lambda x: len(x['words']), new_field_name='seq_len') | |||||
print(len(dataset)) | |||||
print(dataset[0]) | |||||
# DataSet.drop(func)筛除数据 | |||||
dataset.drop(lambda x: x['seq_len'] <= 3) | |||||
print(len(dataset)) | |||||
# 设置DataSet中,哪些field要转为tensor | |||||
# set target,loss或evaluate中的golden,计算loss,模型评估时使用 | |||||
dataset.set_target("label") | |||||
# set input,模型forward时使用 | |||||
dataset.set_input("words") | |||||
# 分出测试集、训练集 | |||||
test_data, train_data = dataset.split(0.5) | |||||
print(len(test_data)) | |||||
print(len(train_data)) | |||||
# 构建词表, Vocabulary.add(word) | |||||
vocab = Vocabulary(min_freq=2) | |||||
train_data.apply(lambda x: [vocab.add(word) for word in x['words']]) | |||||
vocab.build_vocab() | |||||
# index句子, Vocabulary.to_index(word) | |||||
train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words') | |||||
test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words') | |||||
print(test_data[0]) | |||||
model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) | |||||
from fastNLP import Trainer | |||||
from copy import deepcopy | |||||
# 更改DataSet中对应field的名称,要以模型的forward等参数名一致 | |||||
train_data.rename_field('words', 'word_seq') # input field 与 forward 参数一致 | |||||
train_data.rename_field('label', 'label_seq') | |||||
test_data.rename_field('words', 'word_seq') | |||||
test_data.rename_field('label', 'label_seq') | |||||
# 实例化Trainer,传入模型和数据,进行训练 | |||||
copy_model = deepcopy(model) | |||||
overfit_trainer = Trainer(train_data=test_data, model=copy_model, | |||||
loss=CrossEntropyLoss(pred="output", target="label_seq"), | |||||
metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10, batch_size=4, | |||||
dev_data=test_data, save_path="./save") | |||||
overfit_trainer.train() | |||||
trainer = Trainer(train_data=train_data, model=model, | |||||
loss=CrossEntropyLoss(pred="output", target="label_seq"), | |||||
metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10, batch_size=4, | |||||
dev_data=test_data, save_path="./save") | |||||
trainer.train() | |||||
print('Train finished!') | |||||
# 使用fastNLP的Tester测试脚本 | |||||
tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred="predict", target="label_seq"), | |||||
batch_size=4) | |||||
acc = tester.test() | |||||
print(acc) |
@@ -0,0 +1,432 @@ | |||||
import unittest | |||||
from fastNLP import DataSet | |||||
from fastNLP import Instance | |||||
from fastNLP import Vocabulary | |||||
from fastNLP.core.losses import CrossEntropyLoss | |||||
from fastNLP.core.metrics import AccuracyMetric | |||||
class TestTutorial(unittest.TestCase): | |||||
def test_fastnlp_10min_tutorial(self): | |||||
# 从csv读取数据到DataSet | |||||
sample_path = "tutorials/sample_data/tutorial_sample_dataset.csv" | |||||
dataset = DataSet.read_csv(sample_path, headers=('raw_sentence', 'label'), | |||||
sep='\t') | |||||
print(len(dataset)) | |||||
print(dataset[0]) | |||||
print(dataset[-3]) | |||||
dataset.append(Instance(raw_sentence='fake data', label='0')) | |||||
# 将所有数字转为小写 | |||||
dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence') | |||||
# label转int | |||||
dataset.apply(lambda x: int(x['label']), new_field_name='label') | |||||
# 使用空格分割句子 | |||||
def split_sent(ins): | |||||
return ins['raw_sentence'].split() | |||||
dataset.apply(split_sent, new_field_name='words') | |||||
# 增加长度信息 | |||||
dataset.apply(lambda x: len(x['words']), new_field_name='seq_len') | |||||
print(len(dataset)) | |||||
print(dataset[0]) | |||||
# DataSet.drop(func)筛除数据 | |||||
dataset.drop(lambda x: x['seq_len'] <= 3) | |||||
print(len(dataset)) | |||||
# 设置DataSet中,哪些field要转为tensor | |||||
# set target,loss或evaluate中的golden,计算loss,模型评估时使用 | |||||
dataset.set_target("label") | |||||
# set input,模型forward时使用 | |||||
dataset.set_input("words", "seq_len") | |||||
# 分出测试集、训练集 | |||||
test_data, train_data = dataset.split(0.5) | |||||
print(len(test_data)) | |||||
print(len(train_data)) | |||||
# 构建词表, Vocabulary.add(word) | |||||
vocab = Vocabulary(min_freq=2) | |||||
train_data.apply(lambda x: [vocab.add(word) for word in x['words']]) | |||||
vocab.build_vocab() | |||||
# index句子, Vocabulary.to_index(word) | |||||
train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words') | |||||
test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words') | |||||
print(test_data[0]) | |||||
# 如果你们需要做强化学习或者GAN之类的项目,你们也可以使用这些数据预处理的工具 | |||||
from fastNLP.core.batch import Batch | |||||
from fastNLP.core.sampler import RandomSampler | |||||
batch_iterator = Batch(dataset=train_data, batch_size=2, sampler=RandomSampler()) | |||||
for batch_x, batch_y in batch_iterator: | |||||
print("batch_x has: ", batch_x) | |||||
print("batch_y has: ", batch_y) | |||||
break | |||||
from fastNLP.models import CNNText | |||||
model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) | |||||
from fastNLP import Trainer | |||||
from copy import deepcopy | |||||
# 更改DataSet中对应field的名称,要以模型的forward等参数名一致 | |||||
train_data.rename_field('words', 'word_seq') # input field 与 forward 参数一致 | |||||
train_data.rename_field('label', 'label_seq') | |||||
test_data.rename_field('words', 'word_seq') | |||||
test_data.rename_field('label', 'label_seq') | |||||
loss = CrossEntropyLoss(pred="output", target="label_seq") | |||||
metric = AccuracyMetric(pred="predict", target="label_seq") | |||||
# 实例化Trainer,传入模型和数据,进行训练 | |||||
# 先在test_data拟合(确保模型的实现是正确的) | |||||
copy_model = deepcopy(model) | |||||
overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data, | |||||
loss=loss, | |||||
metrics=metric, | |||||
save_path=None, | |||||
batch_size=32, | |||||
n_epochs=5) | |||||
overfit_trainer.train() | |||||
# 用train_data训练,在test_data验证 | |||||
trainer = Trainer(model=model, train_data=train_data, dev_data=test_data, | |||||
loss=CrossEntropyLoss(pred="output", target="label_seq"), | |||||
metrics=AccuracyMetric(pred="predict", target="label_seq"), | |||||
save_path=None, | |||||
batch_size=32, | |||||
n_epochs=5) | |||||
trainer.train() | |||||
print('Train finished!') | |||||
# 调用Tester在test_data上评价效果 | |||||
from fastNLP import Tester | |||||
tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred="predict", target="label_seq"), | |||||
batch_size=4) | |||||
acc = tester.test() | |||||
print(acc) | |||||
def test_fastnlp_1min_tutorial(self): | |||||
# tutorials/fastnlp_1min_tutorial.ipynb | |||||
data_path = "tutorials/sample_data/tutorial_sample_dataset.csv" | |||||
ds = DataSet.read_csv(data_path, headers=('raw_sentence', 'label'), sep='\t') | |||||
print(ds[1]) | |||||
# 将所有数字转为小写 | |||||
ds.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence') | |||||
# label转int | |||||
ds.apply(lambda x: int(x['label']), new_field_name='label_seq', is_target=True) | |||||
def split_sent(ins): | |||||
return ins['raw_sentence'].split() | |||||
ds.apply(split_sent, new_field_name='words', is_input=True) | |||||
# 分割训练集/验证集 | |||||
train_data, dev_data = ds.split(0.3) | |||||
print("Train size: ", len(train_data)) | |||||
print("Test size: ", len(dev_data)) | |||||
from fastNLP import Vocabulary | |||||
vocab = Vocabulary(min_freq=2) | |||||
train_data.apply(lambda x: [vocab.add(word) for word in x['words']]) | |||||
# index句子, Vocabulary.to_index(word) | |||||
train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', | |||||
is_input=True) | |||||
dev_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', | |||||
is_input=True) | |||||
from fastNLP.models import CNNText | |||||
model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) | |||||
from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric | |||||
trainer = Trainer(model=model, | |||||
train_data=train_data, | |||||
dev_data=dev_data, | |||||
loss=CrossEntropyLoss(), | |||||
metrics=AccuracyMetric() | |||||
) | |||||
trainer.train() | |||||
print('Train finished!') | |||||
def test_fastnlp_advanced_tutorial(self): | |||||
import os | |||||
os.chdir("tutorials/fastnlp_advanced_tutorial") | |||||
from fastNLP import DataSet | |||||
from fastNLP import Instance | |||||
from fastNLP import Vocabulary | |||||
from fastNLP import Trainer | |||||
from fastNLP import Tester | |||||
# ### Instance | |||||
# Instance表示一个样本,由一个或者多个field(域、属性、特征)组成,每个field具有自己的名字以及值 | |||||
# 在初始化Instance的时候可以定义它包含的field,使用"field_name=field_value"的写法 | |||||
# In[2]: | |||||
# 组织一个Instance,这个Instance由premise、hypothesis、label三个field组成 | |||||
instance = Instance(premise='an premise example .', hypothesis='an hypothesis example.', label=1) | |||||
instance | |||||
# In[3]: | |||||
data_set = DataSet([instance] * 5) | |||||
data_set.append(instance) | |||||
data_set[-2:] | |||||
# In[4]: | |||||
# 如果某一个field的类型与dataset对应的field类型不一样仍可被加入dataset中 | |||||
instance2 = Instance(premise='the second premise example .', hypothesis='the second hypothesis example.', | |||||
label='1') | |||||
try: | |||||
data_set.append(instance2) | |||||
except: | |||||
pass | |||||
data_set[-2:] | |||||
# In[5]: | |||||
# 如果某一个field的名字不对,则该instance不能被append到dataset中 | |||||
instance3 = Instance(premises='the third premise example .', hypothesis='the third hypothesis example.', | |||||
label=1) | |||||
try: | |||||
data_set.append(instance3) | |||||
except: | |||||
print('cannot append instance') | |||||
pass | |||||
data_set[-2:] | |||||
# In[6]: | |||||
# 除了文本以外,还可以将tensor作为其中一个field的value | |||||
import torch | |||||
tensor_ins = Instance(image=torch.randn(5, 5), label=0) | |||||
ds = DataSet() | |||||
ds.append(tensor_ins) | |||||
ds | |||||
from fastNLP import DataSet | |||||
from fastNLP import Instance | |||||
# 从csv读取数据到DataSet | |||||
# 类csv文件,即每一行为一个example的文件,都可以使用这种方法进行数据读取 | |||||
dataset = DataSet.read_csv('tutorial_sample_dataset.csv', headers=('raw_sentence', 'label'), sep='\t') | |||||
# 查看DataSet的大小 | |||||
len(dataset) | |||||
# In[8]: | |||||
# 使用数字索引[k],获取第k个样本 | |||||
dataset[0] | |||||
# In[9]: | |||||
# 获取的样本是一个Instance | |||||
type(dataset[0]) | |||||
# In[10]: | |||||
# 使用数字索引[a: b],获取第a到第b个样本 | |||||
dataset[0: 3] | |||||
# In[11]: | |||||
# 索引也可以是负数 | |||||
dataset[-1] | |||||
data_path = ['premise', 'hypothesis', 'label'] | |||||
# 读入文件 | |||||
with open(data_path[0]) as f: | |||||
premise = f.readlines() | |||||
with open(data_path[1]) as f: | |||||
hypothesis = f.readlines() | |||||
with open(data_path[2]) as f: | |||||
label = f.readlines() | |||||
assert len(premise) == len(hypothesis) and len(hypothesis) == len(label) | |||||
# 组织DataSet | |||||
data_set = DataSet() | |||||
for p, h, l in zip(premise, hypothesis, label): | |||||
p = p.strip() # 将行末空格去除 | |||||
h = h.strip() # 将行末空格去除 | |||||
data_set.append(Instance(premise=p, hypothesis=h, truth=l)) | |||||
data_set[0] | |||||
# ### DataSet的其他操作 | |||||
# 在构建完毕DataSet后,仍然可以对DataSet的内容进行操作,函数接口为DataSet.apply() | |||||
# In[13]: | |||||
# 将premise域的所有文本转成小写 | |||||
data_set.apply(lambda x: x['premise'].lower(), new_field_name='premise') | |||||
data_set[-2:] | |||||
# In[14]: | |||||
# label转int | |||||
data_set.apply(lambda x: int(x['truth']), new_field_name='truth') | |||||
data_set[-2:] | |||||
# In[15]: | |||||
# 使用空格分割句子 | |||||
def split_sent(ins): | |||||
return ins['premise'].split() | |||||
data_set.apply(split_sent, new_field_name='premise') | |||||
data_set.apply(lambda x: x['hypothesis'].split(), new_field_name='hypothesis') | |||||
data_set[-2:] | |||||
# In[16]: | |||||
# 筛选数据 | |||||
origin_data_set_len = len(data_set) | |||||
data_set.drop(lambda x: len(x['premise']) <= 6) | |||||
origin_data_set_len, len(data_set) | |||||
# In[17]: | |||||
# 增加长度信息 | |||||
data_set.apply(lambda x: [1] * len(x['premise']), new_field_name='premise_len') | |||||
data_set.apply(lambda x: [1] * len(x['hypothesis']), new_field_name='hypothesis_len') | |||||
data_set[-1] | |||||
# In[18]: | |||||
# 设定特征域、标签域 | |||||
data_set.set_input("premise", "premise_len", "hypothesis", "hypothesis_len") | |||||
data_set.set_target("truth") | |||||
# In[19]: | |||||
# 重命名field | |||||
data_set.rename_field('truth', 'label') | |||||
data_set[-1] | |||||
# In[20]: | |||||
# 切分训练、验证集、测试集 | |||||
train_data, vad_data = data_set.split(0.5) | |||||
dev_data, test_data = vad_data.split(0.4) | |||||
len(train_data), len(dev_data), len(test_data) | |||||
# In[21]: | |||||
# 深拷贝一个数据集 | |||||
import copy | |||||
train_data_2, dev_data_2 = copy.deepcopy(train_data), copy.deepcopy(dev_data) | |||||
del copy | |||||
# 初始化词表,该词表最大的vocab_size为10000,词表中每个词出现的最低频率为2,'<unk>'表示未知词语,'<pad>'表示padding词语 | |||||
# Vocabulary默认初始化参数为max_size=None, min_freq=None, unknown='<unk>', padding='<pad>' | |||||
vocab = Vocabulary(max_size=10000, min_freq=2, unknown='<unk>', padding='<pad>') | |||||
# 构建词表 | |||||
train_data.apply(lambda x: [vocab.add(word) for word in x['premise']]) | |||||
train_data.apply(lambda x: [vocab.add(word) for word in x['hypothesis']]) | |||||
vocab.build_vocab() | |||||
# In[23]: | |||||
# 根据词表index句子 | |||||
train_data.apply(lambda x: [vocab.to_index(word) for word in x['premise']], new_field_name='premise') | |||||
train_data.apply(lambda x: [vocab.to_index(word) for word in x['hypothesis']], new_field_name='hypothesis') | |||||
dev_data.apply(lambda x: [vocab.to_index(word) for word in x['premise']], new_field_name='premise') | |||||
dev_data.apply(lambda x: [vocab.to_index(word) for word in x['hypothesis']], new_field_name='hypothesis') | |||||
test_data.apply(lambda x: [vocab.to_index(word) for word in x['premise']], new_field_name='premise') | |||||
test_data.apply(lambda x: [vocab.to_index(word) for word in x['hypothesis']], new_field_name='hypothesis') | |||||
train_data[-1], dev_data[-1], test_data[-1] | |||||
# 读入vocab文件 | |||||
with open('vocab.txt') as f: | |||||
lines = f.readlines() | |||||
vocabs = [] | |||||
for line in lines: | |||||
vocabs.append(line.strip()) | |||||
# 实例化Vocabulary | |||||
vocab_bert = Vocabulary(unknown=None, padding=None) | |||||
# 将vocabs列表加入Vocabulary | |||||
vocab_bert.add_word_lst(vocabs) | |||||
# 构建词表 | |||||
vocab_bert.build_vocab() | |||||
# 更新unknown与padding的token文本 | |||||
vocab_bert.unknown = '[UNK]' | |||||
vocab_bert.padding = '[PAD]' | |||||
# In[25]: | |||||
# 根据词表index句子 | |||||
train_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['premise']], new_field_name='premise') | |||||
train_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['hypothesis']], | |||||
new_field_name='hypothesis') | |||||
dev_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['premise']], new_field_name='premise') | |||||
dev_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['hypothesis']], new_field_name='hypothesis') | |||||
train_data_2[-1], dev_data_2[-1] | |||||
# step 1:加载模型参数(非必选) | |||||
from fastNLP.io.config_io import ConfigSection, ConfigLoader | |||||
args = ConfigSection() | |||||
ConfigLoader().load_config("./data/config", {"esim_model": args}) | |||||
args["vocab_size"] = len(vocab) | |||||
args.data | |||||
# In[27]: | |||||
# step 2:加载ESIM模型 | |||||
from fastNLP.models import ESIM | |||||
model = ESIM(**args.data) | |||||
model | |||||
# In[28]: | |||||
# 另一个例子:加载CNN文本分类模型 | |||||
from fastNLP.models import CNNText | |||||
cnn_text_model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) | |||||
cnn_text_model | |||||
from fastNLP import CrossEntropyLoss | |||||
from fastNLP import Adam | |||||
from fastNLP import AccuracyMetric | |||||
trainer = Trainer( | |||||
train_data=train_data, | |||||
model=model, | |||||
loss=CrossEntropyLoss(pred='pred', target='label'), | |||||
metrics=AccuracyMetric(), | |||||
n_epochs=3, | |||||
batch_size=16, | |||||
print_every=-1, | |||||
validate_every=-1, | |||||
dev_data=dev_data, | |||||
use_cuda=False, | |||||
optimizer=Adam(lr=1e-3, weight_decay=0), | |||||
check_code_level=-1, | |||||
metric_key='acc', | |||||
use_tqdm=False, | |||||
) | |||||
trainer.train() | |||||
tester = Tester( | |||||
data=test_data, | |||||
model=model, | |||||
metrics=AccuracyMetric(), | |||||
batch_size=args["batch_size"], | |||||
) | |||||
tester.test() | |||||
os.chdir("../..") |
@@ -0,0 +1,370 @@ | |||||
{ | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 1, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stderr", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"/Users/yh/miniconda2/envs/python3/lib/python3.6/site-packages/tqdm/autonotebook/__init__.py:14: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n", | |||||
" \" (e.g. in jupyter console)\", TqdmExperimentalWarning)\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"DataSet({'raw_sent': this is a bad idea . type=str,\n", | |||||
"'label': 0 type=int,\n", | |||||
"'word_str_lst': ['this', 'is', 'a', 'bad', 'idea', '.'] type=list,\n", | |||||
"'words': [4, 2, 5, 6, 7, 3] type=list},\n", | |||||
"{'raw_sent': it is great . type=str,\n", | |||||
"'label': 1 type=int,\n", | |||||
"'word_str_lst': ['it', 'is', 'great', '.'] type=list,\n", | |||||
"'words': [8, 2, 9, 3] type=list})" | |||||
] | |||||
}, | |||||
"execution_count": 1, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 假设有以下的DataSet, 这里只是为了举例所以只选择了两个sample\n", | |||||
"import sys\n", | |||||
"import os\n", | |||||
"sys.path.append('/Users/yh/Desktop/fastNLP/fastNLP')\n", | |||||
"\n", | |||||
"from fastNLP import DataSet\n", | |||||
"from fastNLP import Instance\n", | |||||
"from fastNLP import Vocabulary\n", | |||||
"\n", | |||||
"dataset = DataSet()\n", | |||||
"dataset.append(Instance(raw_sent='This is a bad idea .', label=0))\n", | |||||
"dataset.append(Instance(raw_sent='It is great .', label=1))\n", | |||||
"\n", | |||||
"# 按照fastNLP_10min_tutorial.ipynb的步骤,对数据进行一些处理。这里为了演示padding操作,把field的名称做了一些改变\n", | |||||
"dataset.apply(lambda x:x['raw_sent'].lower(), new_field_name='raw_sent')\n", | |||||
"dataset.apply(lambda x:x['raw_sent'].split(), new_field_name='word_str_lst')\n", | |||||
"\n", | |||||
"# 建立Vocabulary\n", | |||||
"word_vocab = Vocabulary()\n", | |||||
"dataset.apply(lambda x:word_vocab.update(x['word_str_lst']))\n", | |||||
"dataset.apply(lambda x:[word_vocab.to_index(word) for word in x['word_str_lst']], new_field_name='words')\n", | |||||
"\n", | |||||
"# 检查以下是否得到我们想要的结果了\n", | |||||
"dataset[:2]" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 2, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"batch_x has: {'word_str_lst': array([list(['this', 'is', 'a', 'bad', 'idea', '.']),\n", | |||||
" list(['it', 'is', 'great', '.'])], dtype=object), 'words': tensor([[4, 2, 5, 6, 7, 3],\n", | |||||
" [8, 2, 9, 3, 0, 0]])}\n", | |||||
"batch_y has: {'label': tensor([0, 1])}\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"'\"\\n结果中\\n Batch会对元素类型(元素即最内层的数据,raw_sent为str,word_str_lst为str,words为int, label为int)为int或者float的数据进行默认\\n padding,而非int或float的则不进行padding。但若每个Instance中该field为二维数据,也不进行padding。因为二维数据的padding涉及到\\n 两个维度的padding,不容易自动判断padding的形式。\\n'" | |||||
] | |||||
}, | |||||
"execution_count": 2, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 将field设置为input或者target\n", | |||||
"dataset.set_input('word_str_lst')\n", | |||||
"dataset.set_input('words')\n", | |||||
"dataset.set_target('label')\n", | |||||
"\n", | |||||
"# 使用Batch取出batch数据\n", | |||||
"from fastNLP.core.batch import Batch\n", | |||||
"from fastNLP.core.sampler import RandomSampler\n", | |||||
"\n", | |||||
"batch_iterator = Batch(dataset=dataset, batch_size=2, sampler=RandomSampler())\n", | |||||
"for batch_x, batch_y in batch_iterator:\n", | |||||
" print(\"batch_x has: \", batch_x)\n", | |||||
" print(\"batch_y has: \", batch_y)\n", | |||||
"\"\"\"\"\n", | |||||
"结果中\n", | |||||
" Batch会对元素类型(元素即最内层的数据,raw_sent为str,word_str_lst为str,words为int, label为int)为int或者float的数据进行默认\n", | |||||
" padding,而非int或float的则不进行padding。但若每个Instance中该field为二维数据,也不进行padding。因为二维数据的padding涉及到\n", | |||||
" 两个维度的padding,不容易自动判断padding的形式。\n", | |||||
"\"\"\"" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 3, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"batch_x has: {'word_str_lst': array([list(['it', 'is', 'great', '.']),\n", | |||||
" list(['this', 'is', 'a', 'bad', 'idea', '.'])], dtype=object), 'words': tensor([[ 8, 2, 9, 3, -100, -100],\n", | |||||
" [ 4, 2, 5, 6, 7, 3]])}\n", | |||||
"batch_y has: {'label': tensor([1, 0])}\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 所有的pad_val都默认为0,如果需要修改某一个field的默认pad值,可以通过DataSet.set_pad_val(field_name, pad_val)进行修改\n", | |||||
"# 若需要将word的padding修改为-100\n", | |||||
"dataset.set_pad_val('words', pad_val=-100)\n", | |||||
"batch_iterator = Batch(dataset=dataset, batch_size=2, sampler=RandomSampler())\n", | |||||
"for batch_x, batch_y in batch_iterator:\n", | |||||
" print(\"batch_x has: \", batch_x)\n", | |||||
" print(\"batch_y has: \", batch_y)\n", | |||||
"# pad的值修改为-100了" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 4, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"DataSet({'raw_sent': this is a bad idea . type=str,\n", | |||||
"'label': 0 type=int,\n", | |||||
"'word_str_lst': ['this', 'is', 'a', 'bad', 'idea', '.'] type=list,\n", | |||||
"'words': [4, 2, 5, 6, 7, 3] type=list,\n", | |||||
"'char_str_lst': [['t', 'h', 'i', 's'], ['i', 's'], ['a'], ['b', 'a', 'd'], ['i', 'd', 'e', 'a'], ['.']] type=list,\n", | |||||
"'chars': [[4, 9, 2, 5], [2, 5], [3], [10, 3, 6], [2, 6, 7, 3], [8]] type=list},\n", | |||||
"{'raw_sent': it is great . type=str,\n", | |||||
"'label': 1 type=int,\n", | |||||
"'word_str_lst': ['it', 'is', 'great', '.'] type=list,\n", | |||||
"'words': [8, 2, 9, 3] type=list,\n", | |||||
"'char_str_lst': [['i', 't'], ['i', 's'], ['g', 'r', 'e', 'a', 't'], ['.']] type=list,\n", | |||||
"'chars': [[2, 4], [2, 5], [11, 12, 7, 3, 4], [8]] type=list})" | |||||
] | |||||
}, | |||||
"execution_count": 4, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 若需要使用二维padding或指定padding方式,可以通过设置该field的padder实现,下面以英文的character padding为例。在某些场景下,可能想要\n", | |||||
"# 使用英文word的character作为特征,character的padding为二维padding,fastNLP默认只会进行一维padding。\n", | |||||
"\n", | |||||
"dataset.apply(lambda x: [[c for c in word] for word in x['word_str_lst']], new_field_name='char_str_lst')\n", | |||||
"char_vocab = Vocabulary()\n", | |||||
"dataset.apply(lambda x:[char_vocab.update(chars) for chars in x['char_str_lst']])\n", | |||||
"dataset.apply(lambda x:[[char_vocab.to_index(c) for c in chars] for chars in x['char_str_lst']],new_field_name='chars')\n", | |||||
"dataset[:2]" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 5, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"batch_x has: {'word_str_lst': array([list(['this', 'is', 'a', 'bad', 'idea', '.']),\n", | |||||
" list(['it', 'is', 'great', '.'])], dtype=object), 'words': tensor([[ 4, 2, 5, 6, 7, 3],\n", | |||||
" [ 8, 2, 9, 3, -100, -100]]), 'chars': array([list([[4, 9, 2, 5], [2, 5], [3], [10, 3, 6], [2, 6, 7, 3], [8]]),\n", | |||||
" list([[2, 4], [2, 5], [11, 12, 7, 3, 4], [8]])], dtype=object)}\n", | |||||
"batch_y has: {'label': tensor([0, 1])}\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"'\\n 其它field与之前的是相同的。chars因为存在两个维度需要padding,不能自动决定padding方式,所以直接输出了原始形式。\\n'" | |||||
] | |||||
}, | |||||
"execution_count": 5, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 如果不针对二维的character指定padding方法\n", | |||||
"dataset.set_input('chars')\n", | |||||
"batch_iterator = Batch(dataset=dataset, batch_size=2, sampler=RandomSampler())\n", | |||||
"for batch_x, batch_y in batch_iterator:\n", | |||||
" print(\"batch_x has: \", batch_x)\n", | |||||
" print(\"batch_y has: \", batch_y)\n", | |||||
" \n", | |||||
"\"\"\"\n", | |||||
" 其它field与之前的是相同的。chars因为存在两个维度需要padding,不能自动决定padding方式,所以直接输出了原始形式。\n", | |||||
"\"\"\"" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 6, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"batch_x has: {'word_str_lst': array([list(['this', 'is', 'a', 'bad', 'idea', '.']),\n", | |||||
" list(['it', 'is', 'great', '.'])], dtype=object), 'words': tensor([[ 4, 2, 5, 6, 7, 3],\n", | |||||
" [ 8, 2, 9, 3, -100, -100]]), 'chars': tensor([[[ 4, 9, 2, 5],\n", | |||||
" [ 2, 5, 0, 0],\n", | |||||
" [ 3, 0, 0, 0],\n", | |||||
" [10, 3, 6, 0],\n", | |||||
" [ 2, 6, 7, 3],\n", | |||||
" [ 8, 0, 0, 0]],\n", | |||||
"\n", | |||||
" [[ 2, 4, 0, 0],\n", | |||||
" [ 2, 5, 0, 0],\n", | |||||
" [11, 12, 7, 3],\n", | |||||
" [ 8, 0, 0, 0],\n", | |||||
" [ 0, 0, 0, 0],\n", | |||||
" [ 0, 0, 0, 0]]])}\n", | |||||
"batch_y has: {'label': tensor([0, 1])}\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"'\\n chars被正确padding了\\n'" | |||||
] | |||||
}, | |||||
"execution_count": 6, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 若要使用二维padding,需要手动设置padding方式\n", | |||||
"from fastNLP.core.fieldarray import EngChar2DPadder\n", | |||||
"dataset.set_padder('chars', EngChar2DPadder())\n", | |||||
"batch_iterator = Batch(dataset=dataset, batch_size=2, sampler=RandomSampler())\n", | |||||
"for batch_x, batch_y in batch_iterator:\n", | |||||
" print(\"batch_x has: \", batch_x)\n", | |||||
" print(\"batch_y has: \", batch_y)\n", | |||||
" \n", | |||||
"\"\"\"\n", | |||||
" chars被正确padding了\n", | |||||
"\"\"\"" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 7, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"batch_x has: {'raw_sent': ['this is a bad idea .', 'it is great . '], 'word_str_lst': array([list(['this', 'is', 'a', 'bad', 'idea', '.']),\n", | |||||
" list(['it', 'is', 'great', '.'])], dtype=object), 'words': tensor([[ 4, 2, 5, 6, 7, 3],\n", | |||||
" [ 8, 2, 9, 3, -100, -100]]), 'chars': tensor([[[ 4, 9, 2, 5],\n", | |||||
" [ 2, 5, 0, 0],\n", | |||||
" [ 3, 0, 0, 0],\n", | |||||
" [10, 3, 6, 0],\n", | |||||
" [ 2, 6, 7, 3],\n", | |||||
" [ 8, 0, 0, 0]],\n", | |||||
"\n", | |||||
" [[ 2, 4, 0, 0],\n", | |||||
" [ 2, 5, 0, 0],\n", | |||||
" [11, 12, 7, 3],\n", | |||||
" [ 8, 0, 0, 0],\n", | |||||
" [ 0, 0, 0, 0],\n", | |||||
" [ 0, 0, 0, 0]]])}\n", | |||||
"batch_y has: {'label': tensor([0, 1])}\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"'\\n raw_sent正确输出,对应内容也进行了pad。\\n'" | |||||
] | |||||
}, | |||||
"execution_count": 7, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 如果AutoPad与EngChar2DPadder不能满足需要,可以自己实现Padder对象。这里举一个例子,比如需要把raw_sentence pad到一样长\n", | |||||
"from fastNLP.core.fieldarray import PadderBase\n", | |||||
"\n", | |||||
"class PadStr(PadderBase):\n", | |||||
" def __init__(self, pad_val=' '):\n", | |||||
" super().__init__(pad_val=pad_val) #让父类管理pad_val的值,这样可以通过DataSet.set_pad_val()修改到该值\n", | |||||
" \n", | |||||
" def __call__(self, contents, field_name, field_ele_dtype):\n", | |||||
" \"\"\"\n", | |||||
" 如果以上面的例子举例,在raw_sent这个field进行pad时,传入的\n", | |||||
" contents:\n", | |||||
" [\n", | |||||
" 'This is a bad idea .',\n", | |||||
" 'It is great .'\n", | |||||
" ]\n", | |||||
" field_name: 'raw_sent',当前field的名称,主要用于帮助debug。\n", | |||||
" field_ele_dtype: np.str. 这个参数基本都用不上,是该field中内部元素的类型\n", | |||||
" \"\"\"\n", | |||||
" max_len = max([len(str_) for str_ in contents])\n", | |||||
" pad_strs = []\n", | |||||
" for content in contents:\n", | |||||
" pad_strs.append(content + (max_len-len(content))*self.pad_val)\n", | |||||
" return pad_strs\n", | |||||
"\n", | |||||
"dataset.set_input('raw_sent')\n", | |||||
"dataset.set_padder('raw_sent', PadStr())\n", | |||||
"batch_iterator = Batch(dataset=dataset, batch_size=2, sampler=RandomSampler())\n", | |||||
"for batch_x, batch_y in batch_iterator:\n", | |||||
" print(\"batch_x has: \", batch_x)\n", | |||||
" print(\"batch_y has: \", batch_y)\n", | |||||
"\n", | |||||
"\"\"\"\n", | |||||
" raw_sent正确输出,对应内容也进行了pad。\n", | |||||
"\"\"\"" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
} | |||||
], | |||||
"metadata": { | |||||
"kernelspec": { | |||||
"display_name": "Python 3", | |||||
"language": "python", | |||||
"name": "python3" | |||||
}, | |||||
"language_info": { | |||||
"codemirror_mode": { | |||||
"name": "ipython", | |||||
"version": 3 | |||||
}, | |||||
"file_extension": ".py", | |||||
"mimetype": "text/x-python", | |||||
"name": "python", | |||||
"nbconvert_exporter": "python", | |||||
"pygments_lexer": "ipython3", | |||||
"version": "3.6.7" | |||||
} | |||||
}, | |||||
"nbformat": 4, | |||||
"nbformat_minor": 2 | |||||
} |
@@ -0,0 +1,97 @@ | |||||
{ | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## fastNLP测试说明\n", | |||||
"### 测试环境\n", | |||||
"fastNLP使用pytest对代码进行单元测试,测试代码在test文件夹下,测试所需数据在test/data_for_tests文件夹下\n", | |||||
"测试的步骤主要分为准备数据,执行测试,比对结果,清除环境四步\n", | |||||
"测试代码以test_xxx.py命名,以DataSet的测试代码为例,测试代码文件名为test_dataset.py" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"import os\n", | |||||
"import unittest # 单元测试需要用到unittest\n", | |||||
"\n", | |||||
"from fastNLP.core.dataset import DataSet\n", | |||||
"from fastNLP.core.fieldarray import FieldArray\n", | |||||
"from fastNLP.core.instance import Instance\n", | |||||
"# 在这个单元测试文件中,需要测试DataSet、FieldArray、以及Instance" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"class TestDataSet(unittest.TestCase): # 类名字以Test打头,继承unittest.TestCase\n", | |||||
"\n", | |||||
" def test_init_v1(self): # 测试样例1, 函数名称以test_打头\n", | |||||
" # 该测试样例测试的是DataSet的初始化\n", | |||||
" ins = Instance(x=[1, 2, 3, 4], y=[5, 6]) # 准备数据\n", | |||||
" ds = DataSet([ins] * 40) # 执行测试(调用DataSet的初始化函数)\n", | |||||
" self.assertTrue(\"x\" in ds.field_arrays and \"y\" in ds.field_arrays) # 比对结果:'x'跟'y'都是ds的field\n", | |||||
" self.assertEqual(ds.field_arrays[\"x\"].content, [[1, 2, 3, 4], ] * 40) # 比对结果: field 'x'的内容正确\n", | |||||
" self.assertEqual(ds.field_arrays[\"y\"].content, [[5, 6], ] * 40) # 比对结果: field 'y'的内容正确\n", | |||||
" \n", | |||||
" def test_init_v2(self): # 测试样例2,该样例测试DataSet的另一种初始化方式\n", | |||||
" ds = DataSet({\"x\": [[1, 2, 3, 4]] * 40, \"y\": [[5, 6]] * 40})\n", | |||||
" self.assertTrue(\"x\" in ds.field_arrays and \"y\" in ds.field_arrays)\n", | |||||
" self.assertEqual(ds.field_arrays[\"x\"].content, [[1, 2, 3, 4], ] * 40)\n", | |||||
" self.assertEqual(ds.field_arrays[\"y\"].content, [[5, 6], ] * 40)\n", | |||||
" \n", | |||||
" def test_init_assert(self): # 测试样例3,该样例测试不规范初始化DataSet时是否会报正确错误\n", | |||||
" with self.assertRaises(AssertionError):\n", | |||||
" _ = DataSet({\"x\": [[1, 2, 3, 4]] * 40, \"y\": [[5, 6]] * 100})\n", | |||||
" with self.assertRaises(AssertionError):\n", | |||||
" _ = DataSet([[1, 2, 3, 4]] * 10)\n", | |||||
" with self.assertRaises(ValueError):\n", | |||||
" _ = DataSet(0.00001)\n", | |||||
" \n", | |||||
" def test_contains(self): # 测试样例4,该样例测试DataSet的contains函数,是功能测试\n", | |||||
" ds = DataSet({\"x\": [[1, 2, 3, 4]] * 40, \"y\": [[5, 6]] * 40})\n", | |||||
" self.assertTrue(\"x\" in ds)\n", | |||||
" self.assertTrue(\"y\" in ds)\n", | |||||
" self.assertFalse(\"z\" in ds)\n", | |||||
" \n", | |||||
" # 更多测试样例见test/core/test_dataset.py" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
} | |||||
], | |||||
"metadata": { | |||||
"kernelspec": { | |||||
"display_name": "Python 3", | |||||
"language": "python", | |||||
"name": "python3" | |||||
}, | |||||
"language_info": { | |||||
"codemirror_mode": { | |||||
"name": "ipython", | |||||
"version": 3 | |||||
}, | |||||
"file_extension": ".py", | |||||
"mimetype": "text/x-python", | |||||
"name": "python", | |||||
"nbconvert_exporter": "python", | |||||
"pygments_lexer": "ipython3", | |||||
"version": "3.6.4" | |||||
} | |||||
}, | |||||
"nbformat": 4, | |||||
"nbformat_minor": 2 | |||||
} |