@@ -0,0 +1,5 @@ | |||||
ignore: | |||||
- "reproduction" # ignore folders and all its contents | |||||
- "setup.py" | |||||
- "docs" | |||||
- "tutorials" |
@@ -1,7 +1,8 @@ | |||||
fastNLP上手教程 | |||||
fastNLP 10分钟上手教程 | |||||
=============== | =============== | ||||
教程原文见 https://github.com/fastnlp/fastNLP/blob/master/tutorials/fastnlp_10min_tutorial.ipynb | |||||
fastNLP提供方便的数据预处理,训练和测试模型的功能 | fastNLP提供方便的数据预处理,训练和测试模型的功能 | ||||
DataSet & Instance | DataSet & Instance | ||||
@@ -2,6 +2,8 @@ | |||||
FastNLP 1分钟上手教程 | FastNLP 1分钟上手教程 | ||||
===================== | ===================== | ||||
教程原文见 https://github.com/fastnlp/fastNLP/blob/master/tutorials/fastnlp_1min_tutorial.ipynb | |||||
step 1 | step 1 | ||||
------ | ------ | ||||
@@ -0,0 +1,5 @@ | |||||
fastNLP 进阶教程 | |||||
=============== | |||||
教程原文见 https://github.com/fastnlp/fastNLP/blob/master/tutorials/fastnlp_advanced_tutorial/advance_tutorial.ipynb | |||||
@@ -0,0 +1,5 @@ | |||||
fastNLP 开发者指南 | |||||
=============== | |||||
原文见 https://github.com/fastnlp/fastNLP/blob/master/tutorials/tutorial_for_developer.md | |||||
@@ -5,6 +5,7 @@ Installation | |||||
.. contents:: | .. contents:: | ||||
:local: | :local: | ||||
Make sure your environment meets the requirements listed in https://github.com/fastnlp/fastNLP/blob/master/requirements.txt . | |||||
Run the following commands to install fastNLP package: | Run the following commands to install fastNLP package: | ||||
@@ -6,4 +6,6 @@ Quickstart | |||||
../tutorials/fastnlp_1_minute_tutorial | ../tutorials/fastnlp_1_minute_tutorial | ||||
../tutorials/fastnlp_10tmin_tutorial | ../tutorials/fastnlp_10tmin_tutorial | ||||
../tutorials/fastnlp_advanced_tutorial | |||||
../tutorials/fastnlp_developer_guide | |||||
@@ -18,26 +18,27 @@ print(cws.predict(text)) | |||||
# ['编者 按 : 7月 12日 , 英国 航空 航天 系统 公司 公布 了 该 公司 研制 的 第一 款 高 科技 隐形 无人 机雷电 之 神 。', '这 款 飞行 从 外型 上 来 看 酷似 电影 中 的 太空 飞行器 , 据 英国 方面 介绍 , 可以 实现 洲际 远程 打击 。', '那么 这 款 无人 机 到底 有 多 厉害 ?'] | # ['编者 按 : 7月 12日 , 英国 航空 航天 系统 公司 公布 了 该 公司 研制 的 第一 款 高 科技 隐形 无人 机雷电 之 神 。', '这 款 飞行 从 外型 上 来 看 酷似 电影 中 的 太空 飞行器 , 据 英国 方面 介绍 , 可以 实现 洲际 远程 打击 。', '那么 这 款 无人 机 到底 有 多 厉害 ?'] | ||||
``` | ``` | ||||
### 中文分词+词性标注 | |||||
### 词性标注 | |||||
```python | ```python | ||||
text = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||||
'那么这款无人机到底有多厉害?'] | |||||
# 输入已分词序列 | |||||
text = [['编者', '按:', '7月', '12日', ',', '英国', '航空', '航天', '系统', '公司', '公布', '了', '该', '公司', | |||||
'研制', '的', '第一款', '高科技', '隐形', '无人机', '雷电之神', '。'], | |||||
['那么', '这', '款', '无人机', '到底', '有', '多', '厉害', '?']] | |||||
from fastNLP.api import POS | from fastNLP.api import POS | ||||
pos = POS(device='cpu') | pos = POS(device='cpu') | ||||
print(pos.predict(text)) | print(pos.predict(text)) | ||||
# [['编者/NN', '按/P', ':/PU', '7月/NT', '12日/NR', ',/PU', '英国/NR', '航空/NN', '航天/NN', '系统/NN', '公司/NN', '公布/VV', '了/AS', '该/DT', '公司/NN', '研制/VV', '的/DEC', '第一/OD', '款高/NN', '科技/NN', '隐形/NN', '无/VE', '人机/NN', '雷电/NN', '之/DEG', '神/NN', '。/PU'], ['这/DT', '款/NN', '飞行/VV', '从/P', '外型/NN', '上/LC', '来/MSP', '看/VV', '酷似/VV', '电影/NN', '中/LC', '的/DEG', '太空/NN', '飞行器/NN', ',/PU', '据/P', '英国/NR', '方面/NN', '介绍/VV', ',/PU', '可以/VV', '实现/VV', '洲际/NN', '远程/NN', '打击/NN', '。/PU'], ['那么/AD', '这/DT', '款/NN', '无/VE', '人机/NN', '到底/AD', '有/VE', '多/CD', '厉害/NN', '?/PU']] | |||||
# [['编者/NN', '按:/NN', '7月/NT', '12日/NT', ',/PU', '英国/NR', '航空/NN', '航天/NN', '系统/NN', '公司/NN', '公布/VV', '了/AS', '该/DT', '公司/NN', '研制/VV', '的/DEC', '第一款/NN', '高科技/NN', '隐形/AD', '无人机/VV', '雷电之神/NN', '。/PU'], ['那么/AD', '这/DT', '款/NN', '无人机/VV', '到底/AD', '有/VE', '多/AD', '厉害/VA', '?/PU']] | |||||
``` | ``` | ||||
### 中文分词+词性标注+句法分析 | |||||
### 句法分析 | |||||
```python | ```python | ||||
text = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||||
'那么这款无人机到底有多厉害?'] | |||||
text = [['编者', '按:', '7月', '12日', ',', '英国', '航空', '航天', '系统', '公司', '公布', '了', '该', '公司', | |||||
'研制', '的', '第一款', '高科技', '隐形', '无人机', '雷电之神', '。'], | |||||
['那么', '这', '款', '无人机', '到底', '有', '多', '厉害', '?']] | |||||
from fastNLP.api import Parser | from fastNLP.api import Parser | ||||
parser = Parser(device='cpu') | parser = Parser(device='cpu') | ||||
print(parser.predict(text)) | print(parser.predict(text)) | ||||
# [['12/nsubj', '12/prep', '2/punct', '5/nn', '2/pobj', '12/punct', '11/nn', '11/nn', '11/nn', '11/nn', '2/pobj', '0/root', '12/asp', '15/det', '16/nsubj', '21/rcmod', '16/cpm', '21/nummod', '21/nn', '21/nn', '22/top', '12/ccomp', '24/nn', '26/assmod', '24/assm', '22/dobj', '12/punct'], ['2/det', '8/xsubj', '8/mmod', '8/prep', '6/lobj', '4/plmod', '8/prtmod', '0/root', '8/ccomp', '11/lobj', '14/assmod', '11/assm', '14/nn', '9/dobj', '8/punct', '22/prep', '18/nn', '19/nsubj', '16/pccomp', '22/punct', '22/mmod', '8/dep', '25/nn', '25/nn', '22/dobj', '8/punct'], ['4/advmod', '3/det', '4/nsubj', '0/root', '4/dobj', '7/advmod', '4/conj', '9/nummod', '7/dobj', '4/punct']] | |||||
# [['2/nn', '4/nn', '4/nn', '20/tmod', '11/punct', '10/nn', '10/nn', '10/nn', '10/nn', '11/nsubj', '20/dep', '11/asp', '14/det', '15/nsubj', '18/rcmod', '15/cpm', '18/nn', '11/dobj', '20/advmod', '0/root', '20/dobj', '20/punct'], ['4/advmod', '3/det', '8/xsubj', '8/dep', '8/advmod', '8/dep', '8/advmod', '0/root', '8/punct']] | |||||
``` | ``` | ||||
完整样例见`examples.py` | 完整样例见`examples.py` |
@@ -9,7 +9,7 @@ from fastNLP.core.dataset import DataSet | |||||
from fastNLP.api.utils import load_url | from fastNLP.api.utils import load_url | ||||
from fastNLP.api.processor import ModelProcessor | from fastNLP.api.processor import ModelProcessor | ||||
from fastNLP.io.dataset_loader import ConllCWSReader, ConllxDataLoader, add_seg_tag | |||||
from fastNLP.io.dataset_loader import ConllCWSReader, ConllxDataLoader | |||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.api.pipeline import Pipeline | from fastNLP.api.pipeline import Pipeline | ||||
from fastNLP.core.metrics import SpanFPreRecMetric | from fastNLP.core.metrics import SpanFPreRecMetric | ||||
@@ -17,9 +17,9 @@ from fastNLP.api.processor import IndexerProcessor | |||||
# TODO add pretrain urls | # TODO add pretrain urls | ||||
model_urls = { | model_urls = { | ||||
"cws": "http://123.206.98.91:8888/download/cws_crf_1_11-457fc899.pkl", | |||||
"cws": "http://123.206.98.91:8888/download/cws_lstm_ctb9_1_20-09908656.pkl", | |||||
"pos": "http://123.206.98.91:8888/download/pos_tag_model_20190119-43f8b435.pkl", | "pos": "http://123.206.98.91:8888/download/pos_tag_model_20190119-43f8b435.pkl", | ||||
"parser": "http://123.206.98.91:8888/download/biaffine_parser-3a2f052c.pkl" | |||||
"parser": "http://123.206.98.91:8888/download/parser_20190204-c72ca5c0.pkl" | |||||
} | } | ||||
@@ -90,38 +90,28 @@ class POS(API): | |||||
# 3. 使用pipeline | # 3. 使用pipeline | ||||
self.pipeline(dataset) | self.pipeline(dataset) | ||||
# def decode_tags(ins): | |||||
# pred_tags = ins["tag"] | |||||
# chars = ins["words"] | |||||
# words = [] | |||||
# start_idx = 0 | |||||
# for idx, tag in enumerate(pred_tags): | |||||
# if tag[0] == "S": | |||||
# words.append(chars[start_idx:idx + 1] + "/" + tag[2:]) | |||||
# start_idx = idx + 1 | |||||
# elif tag[0] == "E": | |||||
# words.append("".join(chars[start_idx:idx + 1]) + "/" + tag[2:]) | |||||
# start_idx = idx + 1 | |||||
# return words | |||||
# | |||||
# dataset.apply(decode_tags, new_field_name="tag_output") | |||||
def merge_tag(words_list, tags_list):
    """Join each word with its tag as ``"word/tag"``, sentence by sentence.

    :param words_list: list of sentences, each a list of word strings
    :param tags_list: list of tag sequences aligned with ``words_list``
    :return: list of lists of ``"word/tag"`` strings
    """
    # zip() stops at the shorter sequence, both across and within sentences
    return [
        [word + "/" + tag for word, tag in zip(sent_words, sent_tags)]
        for sent_words, sent_tags in zip(words_list, tags_list)
    ]
output = dataset.field_arrays["tag"].content | output = dataset.field_arrays["tag"].content | ||||
if isinstance(content, str): | if isinstance(content, str): | ||||
return output[0] | return output[0] | ||||
elif isinstance(content, list): | elif isinstance(content, list): | ||||
return output | |||||
return merge_tag(content, output) | |||||
def test(self, file_path): | def test(self, file_path): | ||||
test_data = ConllxDataLoader().load(file_path) | test_data = ConllxDataLoader().load(file_path) | ||||
with open("model_pp_0117.pkl", "rb") as f: | |||||
save_dict = torch.load(f) | |||||
save_dict = self._dict | |||||
tag_vocab = save_dict["tag_vocab"] | tag_vocab = save_dict["tag_vocab"] | ||||
pipeline = save_dict["pipeline"] | pipeline = save_dict["pipeline"] | ||||
index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False) | index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False) | ||||
pipeline.pipeline = [index_tag] + pipeline.pipeline | pipeline.pipeline = [index_tag] + pipeline.pipeline | ||||
test_data.rename_field("pos_tags", "tag") | |||||
pipeline(test_data) | pipeline(test_data) | ||||
test_data.set_target("truth") | test_data.set_target("truth") | ||||
prediction = test_data.field_arrays["predict"].content | prediction = test_data.field_arrays["predict"].content | ||||
@@ -235,7 +225,7 @@ class CWS(API): | |||||
rec = eval_res['BMESF1PreRecMetric']['rec'] | rec = eval_res['BMESF1PreRecMetric']['rec'] | ||||
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) | # print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) | ||||
return f1, pre, rec | |||||
return {"F1": f1, "precision": pre, "recall": rec} | |||||
class Parser(API): | class Parser(API): | ||||
@@ -260,6 +250,7 @@ class Parser(API): | |||||
dataset.add_field('wp', pos_out) | dataset.add_field('wp', pos_out) | ||||
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[0] for w in x['wp']], new_field_name='words') | dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[0] for w in x['wp']], new_field_name='words') | ||||
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[1] for w in x['wp']], new_field_name='pos') | dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[1] for w in x['wp']], new_field_name='pos') | ||||
dataset.rename_field("words", "raw_words") | |||||
# 3. 使用pipeline | # 3. 使用pipeline | ||||
self.pipeline(dataset) | self.pipeline(dataset) | ||||
@@ -269,31 +260,74 @@ class Parser(API): | |||||
# output like: [['2/top', '0/root', '4/nn', '2/dep']] | # output like: [['2/top', '0/root', '4/nn', '2/dep']] | ||||
return dataset.field_arrays['output'].content | return dataset.field_arrays['output'].content | ||||
def test(self, filepath): | |||||
data = ConllxDataLoader().load(filepath) | |||||
ds = DataSet() | |||||
for ins1, ins2 in zip(add_seg_tag(data), data): | |||||
ds.append(Instance(words=ins1[0], tag=ins1[1], | |||||
gold_words=ins2[0], gold_pos=ins2[1], | |||||
gold_heads=ins2[2], gold_head_tags=ins2[3])) | |||||
def load_test_file(self, path):
    """Read a tab-separated CoNLL-X style test file and return its usable sentences.

    Sentences are separated by blank lines and lines starting with ``#`` are
    skipped. Line endings are stripped before splitting so the last column
    never carries a trailing newline.

    :param str path: path to the UTF-8 encoded test file
    :return: list of ``(words, pos_tags, heads, head_tags)`` tuples built from
        columns 1, 3, 6 and 7 (0-based) of each sentence; ``heads`` is
        converted to ``int``. Sentences containing a ``'_'`` head tag are
        reported on stdout and dropped.
    """

    def get_one(sample):
        # transpose the token rows into per-column lists
        columns = list(map(list, zip(*sample)))
        if len(columns) == 0:
            return None
        for head_tag in columns[7]:
            if head_tag == '_':
                print('Error Sample {}'.format(sample))
                return None
        # word_seq, pos_seq, head_seq, head_tag_seq
        return columns[1], columns[3], list(map(int, columns[6])), columns[7]

    datalist = []
    with open(path, 'r', encoding='utf-8') as f:
        sample = []
        for line in f:
            line = line.rstrip('\r\n')
            if not line.strip():
                # blank (or whitespace-only) line closes the current sentence
                if sample:
                    datalist.append(sample)
                sample = []
            elif line.startswith('#'):
                continue
            else:
                sample.append(line.split('\t'))
        if sample:
            # file did not end with a blank line
            datalist.append(sample)

    parsed = (get_one(sample) for sample in datalist)
    return [item for item in parsed if item is not None]
def test(self, filepath): | |||||
data = self.load_test_file(filepath) | |||||
def convert(data):
    """Build a DataSet from parsed samples, prepending a root entry to every sequence.

    Each sample is indexed as (words, pos tags, heads, head tags); the words,
    pos tags and head tags get a leading ``'<BOS>'`` and the heads get a leading ``0``.
    """
    root_token = '<BOS>'
    converted = DataSet()
    for item in data:
        head_ids = [0] + item[2]
        # the same head list is stored under both gold_heads and arc_true
        converted.append(Instance(raw_words=[root_token] + item[0],
                                  pos=[root_token] + item[1],
                                  gold_heads=head_ids,
                                  arc_true=head_ids,
                                  tags=[root_token] + item[3]))
    return converted
ds = convert(data) | |||||
pp = self.pipeline | pp = self.pipeline | ||||
for p in pp: | for p in pp: | ||||
if p.field_name == 'word_list': | if p.field_name == 'word_list': | ||||
p.field_name = 'gold_words' | p.field_name = 'gold_words' | ||||
elif p.field_name == 'pos_list': | elif p.field_name == 'pos_list': | ||||
p.field_name = 'gold_pos' | p.field_name = 'gold_pos' | ||||
# ds.rename_field("words", "raw_words") | |||||
# ds.rename_field("tag", "pos") | |||||
pp(ds) | pp(ds) | ||||
head_cor, label_cor, total = 0, 0, 0 | head_cor, label_cor, total = 0, 0, 0 | ||||
for ins in ds: | for ins in ds: | ||||
head_gold = ins['gold_heads'] | head_gold = ins['gold_heads'] | ||||
head_pred = ins['heads'] | |||||
head_pred = ins['arc_pred'] | |||||
length = len(head_gold) | length = len(head_gold) | ||||
total += length | total += length | ||||
for i in range(length): | for i in range(length): | ||||
head_cor += 1 if head_pred[i] == head_gold[i] else 0 | head_cor += 1 if head_pred[i] == head_gold[i] else 0 | ||||
uas = head_cor / total | uas = head_cor / total | ||||
print('uas:{:.2f}'.format(uas)) | |||||
# print('uas:{:.2f}'.format(uas)) | |||||
for p in pp: | for p in pp: | ||||
if p.field_name == 'gold_words': | if p.field_name == 'gold_words': | ||||
@@ -301,7 +335,7 @@ class Parser(API): | |||||
elif p.field_name == 'gold_pos': | elif p.field_name == 'gold_pos': | ||||
p.field_name = 'pos_list' | p.field_name = 'pos_list' | ||||
return uas | |||||
return {"USA": round(uas, 5)} | |||||
class Analyzer: | class Analyzer: | ||||
@@ -15,19 +15,42 @@ def chinese_word_segmentation(): | |||||
print(cws.predict(text)) | print(cws.predict(text)) | ||||
def chinese_word_segmentation_test():
    """Run ``CWS.test`` on the sample CoNLL-X file and print what it returns."""
    sample_path = "../../test/data_for_tests/zh_sample.conllx"  # resolved against the current working directory
    segmenter = CWS(device='cpu')
    result = segmenter.test(sample_path)
    print(result)
def pos_tagging(): | def pos_tagging(): | ||||
# 输入已分词序列 | # 输入已分词序列 | ||||
text = ['编者 按: 7月 12日 , 英国 航空 航天 系统 公司 公布 了 该 公司 研制 的 第一款 高科技 隐形 无人机 雷电之神 。'] | |||||
text = [text[0].split()] | |||||
print(text) | |||||
text = [['编者', '按:', '7月', '12日', ',', '英国', '航空', '航天', '系统', '公司', '公布', '了', '该', '公司', | |||||
'研制', '的', '第一款', '高科技', '隐形', '无人机', '雷电之神', '。'], | |||||
['那么', '这', '款', '无人机', '到底', '有', '多', '厉害', '?']] | |||||
pos = POS(device='cpu') | pos = POS(device='cpu') | ||||
print(pos.predict(text)) | print(pos.predict(text)) | ||||
def pos_tagging_test():
    """Run ``POS.test`` on the sample CoNLL-X file and print what it returns."""
    sample_path = "../../test/data_for_tests/zh_sample.conllx"  # resolved against the current working directory
    tagger = POS(device='cpu')
    result = tagger.test(sample_path)
    print(result)
def syntactic_parsing(): | def syntactic_parsing(): | ||||
text = [['编者', '按:', '7月', '12日', ',', '英国', '航空', '航天', '系统', '公司', '公布', '了', '该', '公司', | |||||
'研制', '的', '第一款', '高科技', '隐形', '无人机', '雷电之神', '。'], | |||||
['那么', '这', '款', '无人机', '到底', '有', '多', '厉害', '?']] | |||||
parser = Parser(device='cpu') | parser = Parser(device='cpu') | ||||
print(parser.predict(text)) | print(parser.predict(text)) | ||||
def syntactic_parsing_test():
    """Run ``Parser.test`` on the sample CoNLL-X file and print what it returns."""
    sample_path = "../../test/data_for_tests/zh_sample.conllx"  # resolved against the current working directory
    dep_parser = Parser(device='cpu')
    result = dep_parser.test(sample_path)
    print(result)
if __name__ == "__main__": | if __name__ == "__main__": | ||||
pos_tagging() | |||||
# chinese_word_segmentation() | |||||
# chinese_word_segmentation_test() | |||||
# pos_tagging() | |||||
# pos_tagging_test() | |||||
syntactic_parsing() | |||||
# syntactic_parsing_test() |
@@ -102,6 +102,7 @@ class PreAppendProcessor(Processor): | |||||
[data] + instance[field_name] | [data] + instance[field_name] | ||||
""" | """ | ||||
def __init__(self, data, field_name, new_added_field_name=None): | def __init__(self, data, field_name, new_added_field_name=None): | ||||
super(PreAppendProcessor, self).__init__(field_name, new_added_field_name) | super(PreAppendProcessor, self).__init__(field_name, new_added_field_name) | ||||
self.data = data | self.data = data | ||||
@@ -116,6 +117,7 @@ class SliceProcessor(Processor): | |||||
从某个field中只取部分内容。等价于instance[field_name][start:end:step] | 从某个field中只取部分内容。等价于instance[field_name][start:end:step] | ||||
""" | """ | ||||
def __init__(self, start, end, step, field_name, new_added_field_name=None): | def __init__(self, start, end, step, field_name, new_added_field_name=None): | ||||
super(SliceProcessor, self).__init__(field_name, new_added_field_name) | super(SliceProcessor, self).__init__(field_name, new_added_field_name) | ||||
for o in (start, end, step): | for o in (start, end, step): | ||||
@@ -132,6 +134,7 @@ class Num2TagProcessor(Processor): | |||||
将一句话中的数字转换为某个tag。 | 将一句话中的数字转换为某个tag。 | ||||
""" | """ | ||||
def __init__(self, tag, field_name, new_added_field_name=None): | def __init__(self, tag, field_name, new_added_field_name=None): | ||||
""" | """ | ||||
@@ -163,6 +166,7 @@ class IndexerProcessor(Processor): | |||||
给定一个vocabulary , 将指定field转换为index形式。指定field应该是一维的list,比如 | 给定一个vocabulary , 将指定field转换为index形式。指定field应该是一维的list,比如 | ||||
['我', '是', xxx] | ['我', '是', xxx] | ||||
""" | """ | ||||
def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False, is_input=True): | def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False, is_input=True): | ||||
assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) | assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) | ||||
@@ -215,6 +219,7 @@ class SeqLenProcessor(Processor): | |||||
根据某个field新增一个sequence length的field。取该field的第一维 | 根据某个field新增一个sequence length的field。取该field的第一维 | ||||
""" | """ | ||||
def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True): | def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True): | ||||
super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) | super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) | ||||
self.is_input = is_input | self.is_input = is_input | ||||
@@ -229,6 +234,7 @@ class SeqLenProcessor(Processor): | |||||
from fastNLP.core.utils import _build_args | from fastNLP.core.utils import _build_args | ||||
class ModelProcessor(Processor): | class ModelProcessor(Processor): | ||||
def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32): | def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32): | ||||
""" | """ | ||||
@@ -292,6 +298,7 @@ class Index2WordProcessor(Processor): | |||||
将DataSet中某个为index的field根据vocab转换为str | 将DataSet中某个为index的field根据vocab转换为str | ||||
""" | """ | ||||
def __init__(self, vocab, field_name, new_added_field_name): | def __init__(self, vocab, field_name, new_added_field_name): | ||||
super(Index2WordProcessor, self).__init__(field_name, new_added_field_name) | super(Index2WordProcessor, self).__init__(field_name, new_added_field_name) | ||||
self.vocab = vocab | self.vocab = vocab | ||||
@@ -303,7 +310,6 @@ class Index2WordProcessor(Processor): | |||||
class SetTargetProcessor(Processor): | class SetTargetProcessor(Processor): | ||||
# TODO; remove it. | |||||
def __init__(self, *fields, flag=True): | def __init__(self, *fields, flag=True): | ||||
super(SetTargetProcessor, self).__init__(None, None) | super(SetTargetProcessor, self).__init__(None, None) | ||||
self.fields = fields | self.fields = fields | ||||
@@ -313,6 +319,7 @@ class SetTargetProcessor(Processor): | |||||
dataset.set_target(*self.fields, flag=self.flag) | dataset.set_target(*self.fields, flag=self.flag) | ||||
return dataset | return dataset | ||||
class SetInputProcessor(Processor): | class SetInputProcessor(Processor): | ||||
def __init__(self, *fields, flag=True): | def __init__(self, *fields, flag=True): | ||||
super(SetInputProcessor, self).__init__(None, None) | super(SetInputProcessor, self).__init__(None, None) | ||||
@@ -92,6 +92,10 @@ class DataSet(object): | |||||
data_set.add_field(name=field.name, fields=field.content[idx], padder=field.padder, | data_set.add_field(name=field.name, fields=field.content[idx], padder=field.padder, | ||||
is_input=field.is_input, is_target=field.is_target) | is_input=field.is_input, is_target=field.is_target) | ||||
return data_set | return data_set | ||||
elif isinstance(idx, str): | |||||
if idx not in self: | |||||
raise KeyError("No such field called {} in DataSet.".format(idx)) | |||||
return self.field_arrays[idx] | |||||
else: | else: | ||||
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | ||||
@@ -11,18 +11,24 @@ class BaseLoader(object): | |||||
@staticmethod | @staticmethod | ||||
def load_lines(data_path): | def load_lines(data_path): | ||||
"""按行读取,舍弃每行两侧空白字符,返回list of str | |||||
""" | |||||
with open(data_path, "r", encoding="utf=8") as f: | with open(data_path, "r", encoding="utf=8") as f: | ||||
text = f.readlines() | text = f.readlines() | ||||
return [line.strip() for line in text] | return [line.strip() for line in text] | ||||
@classmethod | @classmethod | ||||
def load(cls, data_path): | def load(cls, data_path): | ||||
"""先按行读取,去除一行两侧空白,再提取每行的字符。返回list of list of str | |||||
""" | |||||
with open(data_path, "r", encoding="utf-8") as f: | with open(data_path, "r", encoding="utf-8") as f: | ||||
text = f.readlines() | text = f.readlines() | ||||
return [[word for word in sent.strip()] for sent in text] | return [[word for word in sent.strip()] for sent in text] | ||||
@classmethod | @classmethod | ||||
def load_with_cache(cls, data_path, cache_path): | def load_with_cache(cls, data_path, cache_path): | ||||
"""缓存版的load | |||||
""" | |||||
if os.path.isfile(cache_path) and os.path.getmtime(data_path) < os.path.getmtime(cache_path): | if os.path.isfile(cache_path) and os.path.getmtime(data_path) < os.path.getmtime(cache_path): | ||||
with open(cache_path, 'rb') as f: | with open(cache_path, 'rb') as f: | ||||
return pickle.load(f) | return pickle.load(f) | ||||
@@ -11,7 +11,6 @@ class ConfigLoader(BaseLoader): | |||||
:param str data_path: path to the config | :param str data_path: path to the config | ||||
""" | """ | ||||
def __init__(self, data_path=None): | def __init__(self, data_path=None): | ||||
super(ConfigLoader, self).__init__() | super(ConfigLoader, self).__init__() | ||||
if data_path is not None: | if data_path is not None: | ||||
@@ -30,7 +29,7 @@ class ConfigLoader(BaseLoader): | |||||
Example:: | Example:: | ||||
test_args = ConfigSection() | test_args = ConfigSection() | ||||
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||||
ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||||
""" | """ | ||||
assert isinstance(sections, dict) | assert isinstance(sections, dict) | ||||
@@ -202,8 +201,6 @@ class ConfigSaver(object): | |||||
continue | continue | ||||
if '=' not in line: | if '=' not in line: | ||||
# log = create_logger(__name__, './config_saver.log') | |||||
# log.error("can NOT load config file [%s]" % self.file_path) | |||||
raise RuntimeError("can NOT load config file {}".__format__(self.file_path)) | raise RuntimeError("can NOT load config file {}".__format__(self.file_path)) | ||||
key = line.split('=', maxsplit=1)[0].strip() | key = line.split('=', maxsplit=1)[0].strip() | ||||
@@ -263,10 +260,6 @@ class ConfigSaver(object): | |||||
change_file = True | change_file = True | ||||
break | break | ||||
if section_file[k] != section[k]: | if section_file[k] != section[k]: | ||||
# logger = create_logger(__name__, "./config_loader.log") | |||||
# logger.warning("section [%s] in config file [%s] has been changed" % ( | |||||
# section_name, self.file_path | |||||
# )) | |||||
change_file = True | change_file = True | ||||
break | break | ||||
if not change_file: | if not change_file: | ||||
@@ -126,8 +126,8 @@ class RawDataSetLoader(DataSetLoader): | |||||
DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata') | DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata') | ||||
class POSDataSetLoader(DataSetLoader): | |||||
"""Dataset Loader for a POS Tag dataset. | |||||
class DummyPOSReader(DataSetLoader): | |||||
"""A simple reader for a dummy POS tagging dataset. | |||||
In these datasets, each line are divided by "\t". The first Col is the vocabulary and the second | In these datasets, each line are divided by "\t". The first Col is the vocabulary and the second | ||||
Col is the label. Different sentence are divided by an empty line. | Col is the label. Different sentence are divided by an empty line. | ||||
@@ -146,7 +146,7 @@ class POSDataSetLoader(DataSetLoader): | |||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
super(POSDataSetLoader, self).__init__() | |||||
super(DummyPOSReader, self).__init__() | |||||
def load(self, data_path): | def load(self, data_path): | ||||
""" | """ | ||||
@@ -194,16 +194,14 @@ class POSDataSetLoader(DataSetLoader): | |||||
return convert_seq2seq_dataset(data) | return convert_seq2seq_dataset(data) | ||||
DataLoaderRegister.set_reader(POSDataSetLoader, 'read_pos') | |||||
DataLoaderRegister.set_reader(DummyPOSReader, 'read_pos') | |||||
class TokenizeDataSetLoader(DataSetLoader): | |||||
class DummyCWSReader(DataSetLoader): | |||||
"""Load pku dataset for Chinese word segmentation. | |||||
""" | """ | ||||
Data set loader for tokenization data sets | |||||
""" | |||||
def __init__(self): | def __init__(self): | ||||
super(TokenizeDataSetLoader, self).__init__() | |||||
super(DummyCWSReader, self).__init__() | |||||
def load(self, data_path, max_seq_len=32): | def load(self, data_path, max_seq_len=32): | ||||
"""Load pku dataset for Chinese word segmentation. | """Load pku dataset for Chinese word segmentation. | ||||
@@ -256,11 +254,11 @@ class TokenizeDataSetLoader(DataSetLoader): | |||||
return convert_seq2seq_dataset(data) | return convert_seq2seq_dataset(data) | ||||
class ClassDataSetLoader(DataSetLoader): | |||||
class DummyClassificationReader(DataSetLoader): | |||||
"""Loader for a dummy classification data set""" | """Loader for a dummy classification data set""" | ||||
def __init__(self): | def __init__(self): | ||||
super(ClassDataSetLoader, self).__init__() | |||||
super(DummyClassificationReader, self).__init__() | |||||
def load(self, data_path): | def load(self, data_path): | ||||
assert os.path.exists(data_path) | assert os.path.exists(data_path) | ||||
@@ -271,7 +269,7 @@ class ClassDataSetLoader(DataSetLoader): | |||||
@staticmethod | @staticmethod | ||||
def parse(lines): | def parse(lines): | ||||
""" | |||||
"""每行第一个token是标签,其余是字/词;由空格分隔。 | |||||
:param lines: lines from dataset | :param lines: lines from dataset | ||||
:return: list(list(list())): the three level of lists are words, sentence, and dataset | :return: list(list(list())): the three level of lists are words, sentence, and dataset | ||||
@@ -327,16 +325,11 @@ class ConllLoader(DataSetLoader): | |||||
pass | pass | ||||
class LMDataSetLoader(DataSetLoader): | |||||
"""Language Model Dataset Loader | |||||
This loader produces data for language model training in a supervised way. | |||||
That means it has X and Y. | |||||
class DummyLMReader(DataSetLoader): | |||||
"""A Dummy Language Model Dataset Reader | |||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
super(LMDataSetLoader, self).__init__() | |||||
super(DummyLMReader, self).__init__() | |||||
def load(self, data_path): | def load(self, data_path): | ||||
if not os.path.exists(data_path): | if not os.path.exists(data_path): | ||||
@@ -364,19 +357,25 @@ class LMDataSetLoader(DataSetLoader): | |||||
class PeopleDailyCorpusLoader(DataSetLoader): | class PeopleDailyCorpusLoader(DataSetLoader): | ||||
"""人民日报数据集 | |||||
""" | """ | ||||
People Daily Corpus: Chinese word segmentation, POS tag, NER | |||||
""" | |||||
def __init__(self): | def __init__(self): | ||||
super(PeopleDailyCorpusLoader, self).__init__() | super(PeopleDailyCorpusLoader, self).__init__() | ||||
self.pos = True | |||||
self.ner = True | |||||
def load(self, data_path): | |||||
def load(self, data_path, pos=True, ner=True): | |||||
""" | |||||
:param str data_path: 数据路径 | |||||
:param bool pos: 是否使用词性标签 | |||||
:param bool ner: 是否使用命名实体标签 | |||||
:return: a DataSet object | |||||
""" | |||||
self.pos, self.ner = pos, ner | |||||
with open(data_path, "r", encoding="utf-8") as f: | with open(data_path, "r", encoding="utf-8") as f: | ||||
sents = f.readlines() | sents = f.readlines() | ||||
pos_tag_examples = [] | |||||
ner_examples = [] | |||||
examples = [] | |||||
for sent in sents: | for sent in sents: | ||||
if len(sent) <= 2: | if len(sent) <= 2: | ||||
continue | continue | ||||
@@ -410,40 +409,44 @@ class PeopleDailyCorpusLoader(DataSetLoader): | |||||
sent_ner.append(ner_tag) | sent_ner.append(ner_tag) | ||||
sent_pos_tag.append(pos) | sent_pos_tag.append(pos) | ||||
sent_words.append(token) | sent_words.append(token) | ||||
pos_tag_examples.append([sent_words, sent_pos_tag]) | |||||
ner_examples.append([sent_words, sent_ner]) | |||||
# List[List[List[str], List[str]]] | |||||
# ner_examples not used | |||||
return self.convert(pos_tag_examples) | |||||
example = [sent_words] | |||||
if self.pos is True: | |||||
example.append(sent_pos_tag) | |||||
if self.ner is True: | |||||
example.append(sent_ner) | |||||
examples.append(example) | |||||
return self.convert(examples) | |||||
def convert(self, data): | def convert(self, data): | ||||
data_set = DataSet() | data_set = DataSet() | ||||
for item in data: | for item in data: | ||||
sent_words, sent_pos_tag = item[0], item[1] | |||||
data_set.append(Instance(words=sent_words, tags=sent_pos_tag)) | |||||
data_set.apply(lambda ins: len(ins), new_field_name="seq_len") | |||||
data_set.set_target("tags") | |||||
data_set.set_input("sent_words") | |||||
data_set.set_input("seq_len") | |||||
sent_words = item[0] | |||||
if self.pos is True and self.ner is True: | |||||
instance = Instance(words=sent_words, pos_tags=item[1], ner=item[2]) | |||||
elif self.pos is True: | |||||
instance = Instance(words=sent_words, pos_tags=item[1]) | |||||
elif self.ner is True: | |||||
instance = Instance(words=sent_words, ner=item[1]) | |||||
else: | |||||
instance = Instance(words=sent_words) | |||||
data_set.append(instance) | |||||
data_set.apply(lambda ins: len(ins["words"]), new_field_name="seq_len") | |||||
return data_set | return data_set | ||||
class Conll2003Loader(DataSetLoader): | class Conll2003Loader(DataSetLoader): | ||||
"""Self-defined loader of conll2003 dataset | |||||
"""Loader for conll2003 dataset | |||||
More information about the given dataset can be found at | More information about the given dataset can be found at | ||||
https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data | https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data | ||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
super(Conll2003Loader, self).__init__() | super(Conll2003Loader, self).__init__() | ||||
def load(self, dataset_path): | def load(self, dataset_path): | ||||
with open(dataset_path, "r", encoding="utf-8") as f: | with open(dataset_path, "r", encoding="utf-8") as f: | ||||
lines = f.readlines() | lines = f.readlines() | ||||
##Parse the dataset line by line | |||||
parsed_data = [] | parsed_data = [] | ||||
sentence = [] | sentence = [] | ||||
tokens = [] | tokens = [] | ||||
@@ -470,21 +473,20 @@ class Conll2003Loader(DataSetLoader): | |||||
lambda labels: labels[1], sample[1])) | lambda labels: labels[1], sample[1])) | ||||
label2_list = list(map( | label2_list = list(map( | ||||
lambda labels: labels[2], sample[1])) | lambda labels: labels[2], sample[1])) | ||||
dataset.append(Instance(token_list=sample[0], | |||||
label0_list=label0_list, | |||||
label1_list=label1_list, | |||||
label2_list=label2_list)) | |||||
dataset.append(Instance(tokens=sample[0], | |||||
pos=label0_list, | |||||
chucks=label1_list, | |||||
ner=label2_list)) | |||||
return dataset | return dataset | ||||
class SNLIDataSetLoader(DataSetLoader): | |||||
class SNLIDataSetReader(DataSetLoader): | |||||
"""A data set loader for SNLI data set. | """A data set loader for SNLI data set. | ||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
super(SNLIDataSetLoader, self).__init__() | |||||
super(SNLIDataSetReader, self).__init__() | |||||
def load(self, path_list): | def load(self, path_list): | ||||
""" | """ | ||||
@@ -553,6 +555,8 @@ class ConllCWSReader(object): | |||||
""" | """ | ||||
返回的DataSet只包含raw_sentence这个field,内容为str。 | 返回的DataSet只包含raw_sentence这个field,内容为str。 | ||||
假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 | 假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 | ||||
:: | |||||
1 编者按 编者按 NN O 11 nmod:topic | 1 编者按 编者按 NN O 11 nmod:topic | ||||
2 : : PU O 11 punct | 2 : : PU O 11 punct | ||||
3 7月 7月 NT DATE 4 compound:nn | 3 7月 7月 NT DATE 4 compound:nn | ||||
@@ -564,6 +568,7 @@ class ConllCWSReader(object): | |||||
3 飞行 飞行 NN O 8 nsubj | 3 飞行 飞行 NN O 8 nsubj | ||||
4 从 从 P O 5 case | 4 从 从 P O 5 case | ||||
5 外型 外型 NN O 8 nmod:prep | 5 外型 外型 NN O 8 nmod:prep | ||||
""" | """ | ||||
datalist = [] | datalist = [] | ||||
with open(path, 'r', encoding='utf-8') as f: | with open(path, 'r', encoding='utf-8') as f: | ||||
@@ -575,7 +580,7 @@ class ConllCWSReader(object): | |||||
elif line.startswith('#'): | elif line.startswith('#'): | ||||
continue | continue | ||||
else: | else: | ||||
sample.append(line.split('\t')) | |||||
sample.append(line.strip().split()) | |||||
if len(sample) > 0: | if len(sample) > 0: | ||||
datalist.append(sample) | datalist.append(sample) | ||||
@@ -592,7 +597,6 @@ class ConllCWSReader(object): | |||||
sents = [line] | sents = [line] | ||||
for raw_sentence in sents: | for raw_sentence in sents: | ||||
ds.append(Instance(raw_sentence=raw_sentence)) | ds.append(Instance(raw_sentence=raw_sentence)) | ||||
return ds | return ds | ||||
def get_char_lst(self, sample): | def get_char_lst(self, sample): | ||||
@@ -607,70 +611,22 @@ class ConllCWSReader(object): | |||||
return text | return text | ||||
class POSCWSReader(DataSetLoader): | |||||
""" | |||||
支持读取以下的情况, 即每一行是一个词, 用空行作为两句话的界限. | |||||
迈 N | |||||
向 N | |||||
充 N | |||||
... | |||||
泽 I-PER | |||||
民 I-PER | |||||
( N | |||||
一 N | |||||
九 N | |||||
... | |||||
:param filepath: | |||||
:return: | |||||
""" | |||||
def __init__(self, in_word_splitter=None): | |||||
super().__init__() | |||||
self.in_word_splitter = in_word_splitter | |||||
def load(self, filepath, in_word_splitter=None, cut_long_sent=False): | |||||
if in_word_splitter is None: | |||||
in_word_splitter = self.in_word_splitter | |||||
dataset = DataSet() | |||||
with open(filepath, 'r') as f: | |||||
words = [] | |||||
for line in f: | |||||
line = line.strip() | |||||
if len(line) == 0: # new line | |||||
if len(words) == 0: # 不能接受空行 | |||||
continue | |||||
line = ' '.join(words) | |||||
if cut_long_sent: | |||||
sents = cut_long_sentence(line) | |||||
else: | |||||
sents = [line] | |||||
for sent in sents: | |||||
instance = Instance(raw_sentence=sent) | |||||
dataset.append(instance) | |||||
words = [] | |||||
else: | |||||
line = line.split()[0] | |||||
if in_word_splitter is None: | |||||
words.append(line) | |||||
else: | |||||
words.append(line.split(in_word_splitter)[0]) | |||||
return dataset | |||||
class NaiveCWSReader(DataSetLoader): | class NaiveCWSReader(DataSetLoader): | ||||
""" | """ | ||||
这个reader假设了分词数据集为以下形式, 即已经用空格分割好内容了 | 这个reader假设了分词数据集为以下形式, 即已经用空格分割好内容了 | ||||
例如:: | |||||
这是 fastNLP , 一个 非常 good 的 包 . | 这是 fastNLP , 一个 非常 good 的 包 . | ||||
或者,即每个part后面还有一个pos tag | 或者,即每个part后面还有一个pos tag | ||||
例如:: | |||||
也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY | 也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY | ||||
""" | """ | ||||
def __init__(self, in_word_splitter=None): | def __init__(self, in_word_splitter=None): | ||||
super().__init__() | |||||
super(NaiveCWSReader, self).__init__() | |||||
self.in_word_splitter = in_word_splitter | self.in_word_splitter = in_word_splitter | ||||
def load(self, filepath, in_word_splitter=None, cut_long_sent=False): | def load(self, filepath, in_word_splitter=None, cut_long_sent=False): | ||||
@@ -680,8 +636,10 @@ class NaiveCWSReader(DataSetLoader): | |||||
和 | 和 | ||||
也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY | 也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY | ||||
如果splitter不为None则认为是第二种情况, 且我们会按splitter分割"也/D", 然后取第一部分. 例如"也/D".split('/')[0] | 如果splitter不为None则认为是第二种情况, 且我们会按splitter分割"也/D", 然后取第一部分. 例如"也/D".split('/')[0] | ||||
:param filepath: | :param filepath: | ||||
:param in_word_splitter: | :param in_word_splitter: | ||||
:param cut_long_sent: | |||||
:return: | :return: | ||||
""" | """ | ||||
if in_word_splitter == None: | if in_word_splitter == None: | ||||
@@ -740,7 +698,9 @@ def cut_long_sentence(sent, max_sample_length=200): | |||||
class ZhConllPOSReader(object): | class ZhConllPOSReader(object): | ||||
# 中文colln格式reader | |||||
"""读取中文Conll格式。返回“字级别”的标签,使用BMES记号扩展原来的词级别标签。 | |||||
""" | |||||
def __init__(self): | def __init__(self): | ||||
pass | pass | ||||
@@ -750,6 +710,8 @@ class ZhConllPOSReader(object): | |||||
words:list of str, | words:list of str, | ||||
tag: list of str, 被加入了BMES tag, 比如原来的序列为['VP', 'NN', 'NN', ..],会被认为是["S-VP", "B-NN", "M-NN",..] | tag: list of str, 被加入了BMES tag, 比如原来的序列为['VP', 'NN', 'NN', ..],会被认为是["S-VP", "B-NN", "M-NN",..] | ||||
假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 | 假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 | ||||
:: | |||||
1 编者按 编者按 NN O 11 nmod:topic | 1 编者按 编者按 NN O 11 nmod:topic | ||||
2 : : PU O 11 punct | 2 : : PU O 11 punct | ||||
3 7月 7月 NT DATE 4 compound:nn | 3 7月 7月 NT DATE 4 compound:nn | ||||
@@ -761,6 +723,7 @@ class ZhConllPOSReader(object): | |||||
3 飞行 飞行 NN O 8 nsubj | 3 飞行 飞行 NN O 8 nsubj | ||||
4 从 从 P O 5 case | 4 从 从 P O 5 case | ||||
5 外型 外型 NN O 8 nmod:prep | 5 外型 外型 NN O 8 nmod:prep | ||||
""" | """ | ||||
datalist = [] | datalist = [] | ||||
with open(path, 'r', encoding='utf-8') as f: | with open(path, 'r', encoding='utf-8') as f: | ||||
@@ -815,67 +778,10 @@ class ZhConllPOSReader(object): | |||||
return text, pos_tags | return text, pos_tags | ||||
class ConllPOSReader(object): | |||||
# 返回的Dataset包含words(list of list, 里层的list是character), tag两个field(list of str, str是标有BIO的tag)。 | |||||
def __init__(self): | |||||
pass | |||||
def load(self, path): | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
ds = DataSet() | |||||
for sample in datalist: | |||||
# print(sample) | |||||
res = self.get_one(sample) | |||||
if res is None: | |||||
continue | |||||
char_seq = [] | |||||
pos_seq = [] | |||||
for word, tag in zip(res[0], res[1]): | |||||
if len(word) == 1: | |||||
char_seq.append(word) | |||||
pos_seq.append('S-{}'.format(tag)) | |||||
elif len(word) > 1: | |||||
pos_seq.append('B-{}'.format(tag)) | |||||
for _ in range(len(word) - 2): | |||||
pos_seq.append('M-{}'.format(tag)) | |||||
pos_seq.append('E-{}'.format(tag)) | |||||
char_seq.extend(list(word)) | |||||
else: | |||||
raise ValueError("Zero length of word detected.") | |||||
ds.append(Instance(words=char_seq, | |||||
tag=pos_seq)) | |||||
return ds | |||||
def get_one(self, sample): | |||||
if len(sample) == 0: | |||||
return None | |||||
text = [] | |||||
pos_tags = [] | |||||
for w in sample: | |||||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | |||||
if t3 == '_': | |||||
return None | |||||
text.append(t1) | |||||
pos_tags.append(t2) | |||||
return text, pos_tags | |||||
class ConllxDataLoader(object): | class ConllxDataLoader(object): | ||||
"""返回“词级别”的标签信息,包括词、词性、(句法)头依赖、(句法)边标签。跟``ZhConllPOSReader``完全不同。 | |||||
""" | |||||
def load(self, path): | def load(self, path): | ||||
""" | """ | ||||
@@ -1,3 +1,7 @@ | |||||
""" | |||||
bert.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0. | |||||
""" | |||||
import copy | import copy | ||||
import json | import json | ||||
import math | import math | ||||
@@ -220,7 +224,23 @@ class BertPooler(nn.Module): | |||||
class BertModel(nn.Module): | class BertModel(nn.Module): | ||||
"""BERT model ("Bidirectional Embedding Representations from a Transformer"). | |||||
"""Bidirectional Encoder Representations from Transformers. | |||||
If you want to use pre-trained weights, please download from the following sources provided by pytorch-pretrained-BERT. | |||||
sources:: | |||||
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", | |||||
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", | |||||
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", | |||||
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", | |||||
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", | |||||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", | |||||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", | |||||
Construct a BERT model with pre-trained weights:: | |||||
model = BertModel.from_pretrained("path/to/weights/directory") | |||||
""" | """ | ||||
@@ -1,18 +1,20 @@ | |||||
import copy | |||||
from collections import defaultdict | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
from collections import defaultdict | |||||
from torch import nn | from torch import nn | ||||
from torch.nn import functional as F | from torch.nn import functional as F | ||||
from fastNLP.modules.utils import initial_parameter | |||||
from fastNLP.modules.encoder.variational_rnn import VarLSTM | |||||
from fastNLP.modules.encoder.transformer import TransformerEncoder | |||||
from fastNLP.modules.dropout import TimestepDropout | |||||
from fastNLP.models.base_model import BaseModel | |||||
from fastNLP.modules.utils import seq_mask | |||||
from fastNLP.core.losses import LossFunc | from fastNLP.core.losses import LossFunc | ||||
from fastNLP.core.metrics import MetricBase | from fastNLP.core.metrics import MetricBase | ||||
from fastNLP.core.utils import seq_lens_to_masks | from fastNLP.core.utils import seq_lens_to_masks | ||||
from fastNLP.models.base_model import BaseModel | |||||
from fastNLP.modules.dropout import TimestepDropout | |||||
from fastNLP.modules.encoder.transformer import TransformerEncoder | |||||
from fastNLP.modules.encoder.variational_rnn import VarLSTM | |||||
from fastNLP.modules.utils import initial_parameter | |||||
from fastNLP.modules.utils import seq_mask | |||||
def mst(scores): | def mst(scores): | ||||
""" | """ | ||||
@@ -1,5 +1,5 @@ | |||||
[train] | [train] | ||||
n_epochs = 1 | |||||
n_epochs = 20 | |||||
batch_size = 32 | batch_size = 32 | ||||
use_cuda = true | use_cuda = true | ||||
use_tqdm=true | use_tqdm=true | ||||
@@ -4,7 +4,7 @@ from fastNLP.core.trainer import ClassificationTrainer | |||||
from fastNLP.core.utils import ClassPreprocess as Preprocess | from fastNLP.core.utils import ClassPreprocess as Preprocess | ||||
from fastNLP.io.config_io import ConfigLoader | from fastNLP.io.config_io import ConfigLoader | ||||
from fastNLP.io.config_io import ConfigSection | from fastNLP.io.config_io import ConfigSection | ||||
from fastNLP.io.dataset_loader import ClassDataSetLoader as Dataset_loader | |||||
from fastNLP.io.dataset_loader import DummyClassificationReader as Dataset_loader | |||||
from fastNLP.models.base_model import BaseModel | from fastNLP.models.base_model import BaseModel | ||||
from fastNLP.modules.aggregator.self_attention import SelfAttention | from fastNLP.modules.aggregator.self_attention import SelfAttention | ||||
from fastNLP.modules.decoder.MLP import MLP | from fastNLP.modules.decoder.MLP import MLP | ||||
@@ -1,9 +1,12 @@ | |||||
import random | import random | ||||
import unittest | import unittest | ||||
from fastNLP import Vocabulary | |||||
import numpy as np | |||||
from fastNLP import Vocabulary, Instance | |||||
from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor, PreAppendProcessor, SliceProcessor, Num2TagProcessor, \ | from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor, PreAppendProcessor, SliceProcessor, Num2TagProcessor, \ | ||||
IndexerProcessor, VocabProcessor, SeqLenProcessor | |||||
IndexerProcessor, VocabProcessor, SeqLenProcessor, ModelProcessor, Index2WordProcessor, SetTargetProcessor, \ | |||||
SetInputProcessor, VocabIndexerProcessor | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
@@ -53,3 +56,46 @@ class TestProcessor(unittest.TestCase): | |||||
ds = proc(ds) | ds = proc(ds) | ||||
for data in ds.field_arrays["len"].content: | for data in ds.field_arrays["len"].content: | ||||
self.assertEqual(data, 30) | self.assertEqual(data, 30) | ||||
def test_ModelProcessor(self): | |||||
from fastNLP.models.cnn_text_classification import CNNText | |||||
model = CNNText(100, 100, 5) | |||||
ins_list = [] | |||||
for _ in range(64): | |||||
seq_len = np.random.randint(5, 30) | |||||
ins_list.append(Instance(word_seq=[np.random.randint(0, 100) for _ in range(seq_len)], seq_lens=seq_len)) | |||||
data_set = DataSet(ins_list) | |||||
data_set.set_input("word_seq", "seq_lens") | |||||
proc = ModelProcessor(model) | |||||
data_set = proc(data_set) | |||||
self.assertTrue("pred" in data_set) | |||||
def test_Index2WordProcessor(self): | |||||
vocab = Vocabulary() | |||||
vocab.add_word_lst(["a", "b", "c", "d", "e"]) | |||||
proc = Index2WordProcessor(vocab, "tag_id", "tag") | |||||
data_set = DataSet([Instance(tag_id=[np.random.randint(0, 7) for _ in range(32)])]) | |||||
data_set = proc(data_set) | |||||
self.assertTrue("tag" in data_set) | |||||
def test_SetTargetProcessor(self): | |||||
proc = SetTargetProcessor("a", "b", "c") | |||||
data_set = DataSet({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]}) | |||||
data_set = proc(data_set) | |||||
self.assertTrue(data_set["a"].is_target) | |||||
self.assertTrue(data_set["b"].is_target) | |||||
self.assertTrue(data_set["c"].is_target) | |||||
def test_SetInputProcessor(self): | |||||
proc = SetInputProcessor("a", "b", "c") | |||||
data_set = DataSet({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]}) | |||||
data_set = proc(data_set) | |||||
self.assertTrue(data_set["a"].is_input) | |||||
self.assertTrue(data_set["b"].is_input) | |||||
self.assertTrue(data_set["c"].is_input) | |||||
def test_VocabIndexerProcessor(self): | |||||
proc = VocabIndexerProcessor("word_seq", "word_ids") | |||||
data_set = DataSet([Instance(word_seq=["a", "b", "c", "d", "e"])]) | |||||
data_set = proc(data_set) | |||||
self.assertTrue("word_ids" in data_set) |
@@ -138,6 +138,7 @@ class TestCase1(unittest.TestCase): | |||||
for batch_x, batch_y in batch: | for batch_x, batch_y in batch: | ||||
time.sleep(pause_seconds) | time.sleep(pause_seconds) | ||||
""" | |||||
def test_multi_workers_batch(self): | def test_multi_workers_batch(self): | ||||
batch_size = 32 | batch_size = 32 | ||||
pause_seconds = 0.01 | pause_seconds = 0.01 | ||||
@@ -154,7 +155,8 @@ class TestCase1(unittest.TestCase): | |||||
end1 = time.time() | end1 = time.time() | ||||
for batch_x, batch_y in batch: | for batch_x, batch_y in batch: | ||||
time.sleep(pause_seconds) | time.sleep(pause_seconds) | ||||
""" | |||||
""" | |||||
def test_pin_memory(self): | def test_pin_memory(self): | ||||
batch_size = 32 | batch_size = 32 | ||||
pause_seconds = 0.01 | pause_seconds = 0.01 | ||||
@@ -172,3 +174,4 @@ class TestCase1(unittest.TestCase): | |||||
# 这里发生OOM | # 这里发生OOM | ||||
# for batch_x, batch_y in batch: | # for batch_x, batch_y in batch: | ||||
# time.sleep(pause_seconds) | # time.sleep(pause_seconds) | ||||
""" |
@@ -237,6 +237,7 @@ class TrainerTestGround(unittest.TestCase): | |||||
use_tqdm=False, | use_tqdm=False, | ||||
print_every=2) | print_every=2) | ||||
""" | |||||
def test_trainer_multiprocess(self): | def test_trainer_multiprocess(self): | ||||
dataset = prepare_fake_dataset2('x1', 'x2') | dataset = prepare_fake_dataset2('x1', 'x2') | ||||
dataset.set_input('x1', 'x2', 'y', flag=True) | dataset.set_input('x1', 'x2', 'y', flag=True) | ||||
@@ -264,4 +265,4 @@ class TrainerTestGround(unittest.TestCase): | |||||
timeout=0, | timeout=0, | ||||
) | ) | ||||
trainer.train() | trainer.train() | ||||
""" |
@@ -1,2 +0,0 @@ | |||||
迈向充满希望的新世纪——一九九八年新年讲话 | |||||
(附图片1张) |
@@ -0,0 +1,100 @@ | |||||
1 上海 _ NR NR _ 3 nsubj _ _ | |||||
2 积极 _ AD AD _ 3 advmod _ _ | |||||
3 准备 _ VV VV _ 0 root _ _ | |||||
4 迎接 _ VV VV _ 3 ccomp _ _ | |||||
5 欧元 _ NN NN _ 6 nn _ _ | |||||
6 诞生 _ NN NN _ 4 dobj _ _ | |||||
1 新华社 _ NR NR _ 7 dep _ _ | |||||
2 上海 _ NR NR _ 7 dep _ _ | |||||
3 十二月 _ NT NT _ 7 dep _ _ | |||||
4 三十日 _ NT NT _ 7 dep _ _ | |||||
5 电 _ NN NN _ 7 dep _ _ | |||||
6 ( _ PU PU _ 7 punct _ _ | |||||
7 记者 _ NN NN _ 0 root _ _ | |||||
8 潘清 _ NR NR _ 7 dep _ _ | |||||
9 ) _ PU PU _ 7 punct _ _ | |||||
1 即将 _ AD AD _ 2 advmod _ _ | |||||
2 诞生 _ VV VV _ 4 rcmod _ _ | |||||
3 的 _ DEC DEC _ 2 cpm _ _ | |||||
4 欧元 _ NN NN _ 6 nsubj _ _ | |||||
5 , _ PU PU _ 6 punct _ _ | |||||
6 引起 _ VV VV _ 0 root _ _ | |||||
7 了 _ AS AS _ 6 asp _ _ | |||||
8 上海 _ NR NR _ 14 nn _ _ | |||||
9 这 _ DT DT _ 14 det _ _ | |||||
10 个 _ M M _ 9 clf _ _ | |||||
11 中国 _ NR NR _ 13 nn _ _ | |||||
12 金融 _ NN NN _ 13 nn _ _ | |||||
13 中心 _ NN NN _ 14 nn _ _ | |||||
14 城市 _ NN NN _ 16 assmod _ _ | |||||
15 的 _ DEG DEG _ 14 assm _ _ | |||||
16 关注 _ NN NN _ 6 dobj _ _ | |||||
17 。 _ PU PU _ 6 punct _ _ | |||||
1 上海 _ NR NR _ 2 nn _ _ | |||||
2 银行界 _ NN NN _ 4 nsubj _ _ | |||||
3 纷纷 _ AD AD _ 4 advmod _ _ | |||||
4 推出 _ VV VV _ 0 root _ _ | |||||
5 了 _ AS AS _ 4 asp _ _ | |||||
6 与 _ P P _ 8 prep _ _ | |||||
7 之 _ PN PN _ 6 pobj _ _ | |||||
8 相关 _ VA VA _ 15 rcmod _ _ | |||||
9 的 _ DEC DEC _ 8 cpm _ _ | |||||
10 外汇 _ NN NN _ 15 nn _ _ | |||||
11 业务 _ NN NN _ 15 nn _ _ | |||||
12 品种 _ NN NN _ 15 conj _ _ | |||||
13 和 _ CC CC _ 15 cc _ _ | |||||
14 服务 _ NN NN _ 15 nn _ _ | |||||
15 举措 _ NN NN _ 4 dobj _ _ | |||||
16 , _ PU PU _ 4 punct _ _ | |||||
17 积极 _ AD AD _ 18 advmod _ _ | |||||
18 准备 _ VV VV _ 4 dep _ _ | |||||
19 启动 _ VV VV _ 18 ccomp _ _ | |||||
20 欧元 _ NN NN _ 21 nn _ _ | |||||
21 业务 _ NN NN _ 19 dobj _ _ | |||||
22 。 _ PU PU _ 4 punct _ _ | |||||
1 一些 _ CD CD _ 8 nummod _ _ | |||||
2 热衷于 _ VV VV _ 8 rcmod _ _ | |||||
3 个人 _ NN NN _ 5 nn _ _ | |||||
4 外汇 _ NN NN _ 5 nn _ _ | |||||
5 交易 _ NN NN _ 2 dobj _ _ | |||||
6 的 _ DEC DEC _ 2 cpm _ _ | |||||
7 上海 _ NR NR _ 8 nn _ _ | |||||
8 市民 _ NN NN _ 13 nsubj _ _ | |||||
9 , _ PU PU _ 13 punct _ _ | |||||
10 也 _ AD AD _ 13 advmod _ _ | |||||
11 对 _ P P _ 13 prep _ _ | |||||
12 欧元 _ NN NN _ 11 pobj _ _ | |||||
13 表示 _ VV VV _ 0 root _ _ | |||||
14 出 _ VV VV _ 13 rcomp _ _ | |||||
15 极 _ AD AD _ 16 advmod _ _ | |||||
16 大 _ VA VA _ 18 rcmod _ _ | |||||
17 的 _ DEC DEC _ 16 cpm _ _ | |||||
18 兴趣 _ NN NN _ 13 dobj _ _ | |||||
19 。 _ PU PU _ 13 punct _ _ | |||||
1 继 _ P P _ 38 prep _ _ | |||||
2 上海 _ NR NR _ 6 nn _ _ | |||||
3 大众 _ NR NR _ 6 nn _ _ | |||||
4 汽车 _ NN NN _ 6 nn _ _ | |||||
5 有限 _ JJ JJ _ 6 amod _ _ | |||||
6 公司 _ NN NN _ 13 nsubj _ _ | |||||
7 十八日 _ NT NT _ 13 tmod _ _ | |||||
8 在 _ P P _ 13 prep _ _ | |||||
9 中国 _ NR NR _ 10 nn _ _ | |||||
10 银行 _ NN NN _ 12 nn _ _ | |||||
11 上海 _ NR NR _ 12 nn _ _ | |||||
12 分行 _ NN NN _ 8 pobj _ _ | |||||
13 开立 _ VV VV _ 19 lccomp _ _ | |||||
14 上海 _ NR NR _ 16 dep _ _ | |||||
15 第一 _ OD OD _ 16 ordmod _ _ | |||||
16 个 _ M M _ 18 clf _ _ | |||||
17 欧元 _ NN NN _ 18 nn _ _ | |||||
18 帐户 _ NN NN _ 13 dobj _ _ | |||||
19 后 _ LC LC _ 1 plmod _ _ | |||||
20 , _ PU PU _ 38 punct _ _ | |||||
21 工商 _ NN NN _ 28 nn _ _ | |||||
22 银行 _ NN NN _ 28 conj _ _ |
@@ -1,24 +1,27 @@ | |||||
import unittest | import unittest | ||||
from fastNLP.io.dataset_loader import Conll2003Loader | |||||
from fastNLP.io.dataset_loader import Conll2003Loader, PeopleDailyCorpusLoader, ConllCWSReader, \ | |||||
ZhConllPOSReader, ConllxDataLoader | |||||
class TestDatasetLoader(unittest.TestCase): | class TestDatasetLoader(unittest.TestCase): | ||||
def test_case_1(self): | |||||
''' | |||||
def test_Conll2003Loader(self): | |||||
""" | |||||
Test the loader of the Conll2003 dataset | Test the loader of the Conll2003 dataset | ||||
''' | |||||
""" | |||||
dataset_path = "test/data_for_tests/conll_2003_example.txt" | dataset_path = "test/data_for_tests/conll_2003_example.txt" | ||||
loader = Conll2003Loader() | loader = Conll2003Loader() | ||||
dataset_2003 = loader.load(dataset_path) | dataset_2003 = loader.load(dataset_path) | ||||
for item in dataset_2003: | |||||
len0 = len(item["label0_list"]) | |||||
len1 = len(item["label1_list"]) | |||||
len2 = len(item["label2_list"]) | |||||
lentoken = len(item["token_list"]) | |||||
self.assertNotEqual(len0, 0) | |||||
self.assertEqual(len0, len1) | |||||
self.assertEqual(len1, len2) | |||||
def test_PeopleDailyCorpusLoader(self): | |||||
data_set = PeopleDailyCorpusLoader().load("test/data_for_tests/people_daily_raw.txt") | |||||
def test_ConllCWSReader(self): | |||||
dataset = ConllCWSReader().load("test/data_for_tests/conll_example.txt") | |||||
def test_ZhConllPOSReader(self): | |||||
dataset = ZhConllPOSReader().load("test/data_for_tests/zh_sample.conllx") | |||||
def test_ConllxDataLoader(self): | |||||
dataset = ConllxDataLoader().load("test/data_for_tests/zh_sample.conllx") |