Browse Source

Merge branch 'dev' of github.com:choosewhatulike/fastNLP-private into dev

tags/v0.4.10
yh 5 years ago
parent
commit
853bea5812
26 changed files with 422 additions and 3624 deletions
  1. +5
    -0
      codecov.yml
  2. +3
    -2
      docs/source/tutorials/fastnlp_10tmin_tutorial.rst
  3. +2
    -0
      docs/source/tutorials/fastnlp_1_minute_tutorial.rst
  4. +5
    -0
      docs/source/tutorials/fastnlp_advanced_tutorial.rst
  5. +5
    -0
      docs/source/tutorials/fastnlp_developer_guide.rst
  6. +1
    -0
      docs/source/user/installation.rst
  7. +2
    -0
      docs/source/user/quickstart.rst
  8. +11
    -10
      fastNLP/api/README.md
  9. +66
    -32
      fastNLP/api/api.py
  10. +27
    -4
      fastNLP/api/examples.py
  11. +8
    -1
      fastNLP/api/processor.py
  12. +4
    -0
      fastNLP/core/dataset.py
  13. +6
    -0
      fastNLP/io/base_loader.py
  14. +1
    -8
      fastNLP/io/config_io.py
  15. +73
    -167
      fastNLP/io/dataset_loader.py
  16. +21
    -1
      fastNLP/models/bert.py
  17. +10
    -8
      fastNLP/models/biaffine_parser.py
  18. +1
    -1
      reproduction/Biaffine_parser/cfg.cfg
  19. +1
    -1
      reproduction/LSTM+self_attention_sentiment_analysis/main.py
  20. +48
    -2
      test/api/test_processor.py
  21. +4
    -1
      test/core/test_batch.py
  22. +2
    -1
      test/core/test_trainer.py
  23. +0
    -3370
      test/data_for_tests/charlm.txt
  24. +0
    -2
      test/data_for_tests/people_infer.txt
  25. +100
    -0
      test/data_for_tests/zh_sample.conllx
  26. +16
    -13
      test/io/test_dataset_loader.py

+ 5
- 0
codecov.yml View File

@@ -0,0 +1,5 @@
ignore:
- "reproduction" # ignore folders and all its contents
- "setup.py"
- "docs"
- "tutorials"

+ 3
- 2
docs/source/tutorials/fastnlp_10tmin_tutorial.rst View File

@@ -1,7 +1,8 @@

fastNLP上手教程
fastNLP 10分钟上手教程
=============== ===============


教程原文见 https://github.com/fastnlp/fastNLP/blob/master/tutorials/fastnlp_10min_tutorial.ipynb

fastNLP提供方便的数据预处理,训练和测试模型的功能 fastNLP提供方便的数据预处理,训练和测试模型的功能


DataSet & Instance DataSet & Instance


+ 2
- 0
docs/source/tutorials/fastnlp_1_minute_tutorial.rst View File

@@ -2,6 +2,8 @@
FastNLP 1分钟上手教程 FastNLP 1分钟上手教程
===================== =====================


教程原文见 https://github.com/fastnlp/fastNLP/blob/master/tutorials/fastnlp_1min_tutorial.ipynb

step 1 step 1
------ ------




+ 5
- 0
docs/source/tutorials/fastnlp_advanced_tutorial.rst View File

@@ -0,0 +1,5 @@
fastNLP 进阶教程
===============

教程原文见 https://github.com/fastnlp/fastNLP/blob/master/tutorials/fastnlp_advanced_tutorial/advance_tutorial.ipynb


+ 5
- 0
docs/source/tutorials/fastnlp_developer_guide.rst View File

@@ -0,0 +1,5 @@
fastNLP 开发者指南
===============

原文见 https://github.com/fastnlp/fastNLP/blob/master/tutorials/tutorial_for_developer.md


+ 1
- 0
docs/source/user/installation.rst View File

@@ -5,6 +5,7 @@ Installation
.. contents:: .. contents::
:local: :local:


Make sure your environment satisfies https://github.com/fastnlp/fastNLP/blob/master/requirements.txt .


Run the following commands to install fastNLP package: Run the following commands to install fastNLP package:




+ 2
- 0
docs/source/user/quickstart.rst View File

@@ -6,4 +6,6 @@ Quickstart


../tutorials/fastnlp_1_minute_tutorial ../tutorials/fastnlp_1_minute_tutorial
../tutorials/fastnlp_10tmin_tutorial ../tutorials/fastnlp_10tmin_tutorial
../tutorials/fastnlp_advanced_tutorial
../tutorials/fastnlp_developer_guide



+ 11
- 10
fastNLP/api/README.md View File

@@ -18,26 +18,27 @@ print(cws.predict(text))
# ['编者 按 : 7月 12日 , 英国 航空 航天 系统 公司 公布 了 该 公司 研制 的 第一 款 高 科技 隐形 无人 机雷电 之 神 。', '这 款 飞行 从 外型 上 来 看 酷似 电影 中 的 太空 飞行器 , 据 英国 方面 介绍 , 可以 实现 洲际 远程 打击 。', '那么 这 款 无人 机 到底 有 多 厉害 ?'] # ['编者 按 : 7月 12日 , 英国 航空 航天 系统 公司 公布 了 该 公司 研制 的 第一 款 高 科技 隐形 无人 机雷电 之 神 。', '这 款 飞行 从 外型 上 来 看 酷似 电影 中 的 太空 飞行器 , 据 英国 方面 介绍 , 可以 实现 洲际 远程 打击 。', '那么 这 款 无人 机 到底 有 多 厉害 ?']
``` ```


### 中文分词+词性标注
### 词性标注
```python ```python
text = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
'那么这款无人机到底有多厉害?']
# 输入已分词序列
text = [['编者', '按:', '7月', '12日', ',', '英国', '航空', '航天', '系统', '公司', '公布', '了', '该', '公司',
'研制', '的', '第一款', '高科技', '隐形', '无人机', '雷电之神', '。'],
['那么', '这', '款', '无人机', '到底', '有', '多', '厉害', '?']]
from fastNLP.api import POS from fastNLP.api import POS
pos = POS(device='cpu') pos = POS(device='cpu')
print(pos.predict(text)) print(pos.predict(text))
# [['编者/NN', '按/P', ':/PU', '7月/NT', '12日/NR', ',/PU', '英国/NR', '航空/NN', '航天/NN', '系统/NN', '公司/NN', '公布/VV', '了/AS', '该/DT', '公司/NN', '研制/VV', '的/DEC', '第一/OD', '款高/NN', '科技/NN', '隐形/NN', '无/VE', '人机/NN', '雷电/NN', '之/DEG', '神/NN', '。/PU'], ['这/DT', '款/NN', '飞行/VV', '从/P', '外型/NN', '上/LC', '来/MSP', '看/VV', '酷似/VV', '电影/NN', '中/LC', '的/DEG', '太空/NN', '飞行器/NN', ',/PU', '据/P', '英国/NR', '方面/NN', '介绍/VV', ',/PU', '可以/VV', '实现/VV', '洲际/NN', '远程/NN', '打击/NN', '。/PU'], ['那么/AD', '这/DT', '款/NN', '无/VE', '人机/NN', '到底/AD', '有/VE', '多/CD', '厉害/NN', '?/PU']]
# [['编者/NN', '按:/NN', '7月/NT', '12日/NT', ',/PU', '英国/NR', '航空/NN', '航天/NN', '系统/NN', '公司/NN', '公布/VV', '了/AS', '该/DT', '公司/NN', '研制/VV', '的/DEC', '第一款/NN', '高科技/NN', '隐形/AD', '无人机/VV', '雷电之神/NN', '。/PU'], ['那么/AD', '这/DT', '款/NN', '无人机/VV', '到底/AD', '有/VE', '多/AD', '厉害/VA', '?/PU']]
``` ```


### 中文分词+词性标注+句法分析
### 句法分析
```python ```python
text = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
'那么这款无人机到底有多厉害?']
text = [['编者', '按:', '7月', '12日', '', '英国', '航空', '航天', '系统', '公司', '公布', '', '', '公司',
'研制', '的', '第一款', '高科技', '隐形', '无人机', '雷电之神', '。'],
['那么', '', '', '无人机', '到底', '', '', '厉害', '?']]
from fastNLP.api import Parser from fastNLP.api import Parser
parser = Parser(device='cpu') parser = Parser(device='cpu')
print(parser.predict(text)) print(parser.predict(text))
# [['12/nsubj', '12/prep', '2/punct', '5/nn', '2/pobj', '12/punct', '11/nn', '11/nn', '11/nn', '11/nn', '2/pobj', '0/root', '12/asp', '15/det', '16/nsubj', '21/rcmod', '16/cpm', '21/nummod', '21/nn', '21/nn', '22/top', '12/ccomp', '24/nn', '26/assmod', '24/assm', '22/dobj', '12/punct'], ['2/det', '8/xsubj', '8/mmod', '8/prep', '6/lobj', '4/plmod', '8/prtmod', '0/root', '8/ccomp', '11/lobj', '14/assmod', '11/assm', '14/nn', '9/dobj', '8/punct', '22/prep', '18/nn', '19/nsubj', '16/pccomp', '22/punct', '22/mmod', '8/dep', '25/nn', '25/nn', '22/dobj', '8/punct'], ['4/advmod', '3/det', '4/nsubj', '0/root', '4/dobj', '7/advmod', '4/conj', '9/nummod', '7/dobj', '4/punct']]
# [['2/nn', '4/nn', '4/nn', '20/tmod', '11/punct', '10/nn', '10/nn', '10/nn', '10/nn', '11/nsubj', '20/dep', '11/asp', '14/det', '15/nsubj', '18/rcmod', '15/cpm', '18/nn', '11/dobj', '20/advmod', '0/root', '20/dobj', '20/punct'], ['4/advmod', '3/det', '8/xsubj', '8/dep', '8/advmod', '8/dep', '8/advmod', '0/root', '8/punct']]
``` ```


完整样例见`examples.py` 完整样例见`examples.py`

+ 66
- 32
fastNLP/api/api.py View File

@@ -9,7 +9,7 @@ from fastNLP.core.dataset import DataSet


from fastNLP.api.utils import load_url from fastNLP.api.utils import load_url
from fastNLP.api.processor import ModelProcessor from fastNLP.api.processor import ModelProcessor
from fastNLP.io.dataset_loader import ConllCWSReader, ConllxDataLoader, add_seg_tag
from fastNLP.io.dataset_loader import ConllCWSReader, ConllxDataLoader
from fastNLP.core.instance import Instance from fastNLP.core.instance import Instance
from fastNLP.api.pipeline import Pipeline from fastNLP.api.pipeline import Pipeline
from fastNLP.core.metrics import SpanFPreRecMetric from fastNLP.core.metrics import SpanFPreRecMetric
@@ -17,9 +17,9 @@ from fastNLP.api.processor import IndexerProcessor


# TODO add pretrain urls # TODO add pretrain urls
model_urls = { model_urls = {
"cws": "http://123.206.98.91:8888/download/cws_crf_1_11-457fc899.pkl",
"cws": "http://123.206.98.91:8888/download/cws_lstm_ctb9_1_20-09908656.pkl",
"pos": "http://123.206.98.91:8888/download/pos_tag_model_20190119-43f8b435.pkl", "pos": "http://123.206.98.91:8888/download/pos_tag_model_20190119-43f8b435.pkl",
"parser": "http://123.206.98.91:8888/download/biaffine_parser-3a2f052c.pkl"
"parser": "http://123.206.98.91:8888/download/parser_20190204-c72ca5c0.pkl"
} }




@@ -90,38 +90,28 @@ class POS(API):
# 3. 使用pipeline # 3. 使用pipeline
self.pipeline(dataset) self.pipeline(dataset)


# def decode_tags(ins):
# pred_tags = ins["tag"]
# chars = ins["words"]
# words = []
# start_idx = 0
# for idx, tag in enumerate(pred_tags):
# if tag[0] == "S":
# words.append(chars[start_idx:idx + 1] + "/" + tag[2:])
# start_idx = idx + 1
# elif tag[0] == "E":
# words.append("".join(chars[start_idx:idx + 1]) + "/" + tag[2:])
# start_idx = idx + 1
# return words
#
# dataset.apply(decode_tags, new_field_name="tag_output")
def merge_tag(words_list, tags_list):
rtn = []
for words, tags in zip(words_list, tags_list):
rtn.append([w + "/" + t for w, t in zip(words, tags)])
return rtn


output = dataset.field_arrays["tag"].content output = dataset.field_arrays["tag"].content
if isinstance(content, str): if isinstance(content, str):
return output[0] return output[0]
elif isinstance(content, list): elif isinstance(content, list):
return output
return merge_tag(content, output)


def test(self, file_path): def test(self, file_path):
test_data = ConllxDataLoader().load(file_path) test_data = ConllxDataLoader().load(file_path)


with open("model_pp_0117.pkl", "rb") as f:
save_dict = torch.load(f)
save_dict = self._dict
tag_vocab = save_dict["tag_vocab"] tag_vocab = save_dict["tag_vocab"]
pipeline = save_dict["pipeline"] pipeline = save_dict["pipeline"]
index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False) index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False)
pipeline.pipeline = [index_tag] + pipeline.pipeline pipeline.pipeline = [index_tag] + pipeline.pipeline


test_data.rename_field("pos_tags", "tag")
pipeline(test_data) pipeline(test_data)
test_data.set_target("truth") test_data.set_target("truth")
prediction = test_data.field_arrays["predict"].content prediction = test_data.field_arrays["predict"].content
@@ -235,7 +225,7 @@ class CWS(API):
rec = eval_res['BMESF1PreRecMetric']['rec'] rec = eval_res['BMESF1PreRecMetric']['rec']
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) # print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec))


return f1, pre, rec
return {"F1": f1, "precision": pre, "recall": rec}




class Parser(API): class Parser(API):
@@ -260,6 +250,7 @@ class Parser(API):
dataset.add_field('wp', pos_out) dataset.add_field('wp', pos_out)
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[0] for w in x['wp']], new_field_name='words') dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[0] for w in x['wp']], new_field_name='words')
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[1] for w in x['wp']], new_field_name='pos') dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[1] for w in x['wp']], new_field_name='pos')
dataset.rename_field("words", "raw_words")


# 3. 使用pipeline # 3. 使用pipeline
self.pipeline(dataset) self.pipeline(dataset)
@@ -269,31 +260,74 @@ class Parser(API):
# output like: [['2/top', '0/root', '4/nn', '2/dep']] # output like: [['2/top', '0/root', '4/nn', '2/dep']]
return dataset.field_arrays['output'].content return dataset.field_arrays['output'].content


def test(self, filepath):
data = ConllxDataLoader().load(filepath)
ds = DataSet()
for ins1, ins2 in zip(add_seg_tag(data), data):
ds.append(Instance(words=ins1[0], tag=ins1[1],
gold_words=ins2[0], gold_pos=ins2[1],
gold_heads=ins2[2], gold_head_tags=ins2[3]))
def load_test_file(self, path):
def get_one(sample):
sample = list(map(list, zip(*sample)))
if len(sample) == 0:
return None
for w in sample[7]:
if w == '_':
print('Error Sample {}'.format(sample))
return None
# return word_seq, pos_seq, head_seq, head_tag_seq
return sample[1], sample[3], list(map(int, sample[6])), sample[7]

datalist = []
with open(path, 'r', encoding='utf-8') as f:
sample = []
for line in f:
if line.startswith('\n'):
datalist.append(sample)
sample = []
elif line.startswith('#'):
continue
else:
sample.append(line.split('\t'))
if len(sample) > 0:
datalist.append(sample)

data = [get_one(sample) for sample in datalist]
data_list = list(filter(lambda x: x is not None, data))
return data_list


def test(self, filepath):
data = self.load_test_file(filepath)

def convert(data):
BOS = '<BOS>'
dataset = DataSet()
for sample in data:
word_seq = [BOS] + sample[0]
pos_seq = [BOS] + sample[1]
heads = [0] + sample[2]
head_tags = [BOS] + sample[3]
dataset.append(Instance(raw_words=word_seq,
pos=pos_seq,
gold_heads=heads,
arc_true=heads,
tags=head_tags))
return dataset

ds = convert(data)
pp = self.pipeline pp = self.pipeline
for p in pp: for p in pp:
if p.field_name == 'word_list': if p.field_name == 'word_list':
p.field_name = 'gold_words' p.field_name = 'gold_words'
elif p.field_name == 'pos_list': elif p.field_name == 'pos_list':
p.field_name = 'gold_pos' p.field_name = 'gold_pos'
# ds.rename_field("words", "raw_words")
# ds.rename_field("tag", "pos")
pp(ds) pp(ds)
head_cor, label_cor, total = 0, 0, 0 head_cor, label_cor, total = 0, 0, 0
for ins in ds: for ins in ds:
head_gold = ins['gold_heads'] head_gold = ins['gold_heads']
head_pred = ins['heads']
head_pred = ins['arc_pred']
length = len(head_gold) length = len(head_gold)
total += length total += length
for i in range(length): for i in range(length):
head_cor += 1 if head_pred[i] == head_gold[i] else 0 head_cor += 1 if head_pred[i] == head_gold[i] else 0
uas = head_cor / total uas = head_cor / total
print('uas:{:.2f}'.format(uas))
# print('uas:{:.2f}'.format(uas))


for p in pp: for p in pp:
if p.field_name == 'gold_words': if p.field_name == 'gold_words':
@@ -301,7 +335,7 @@ class Parser(API):
elif p.field_name == 'gold_pos': elif p.field_name == 'gold_pos':
p.field_name = 'pos_list' p.field_name = 'pos_list'


return uas
return {"USA": round(uas, 5)}




class Analyzer: class Analyzer:


+ 27
- 4
fastNLP/api/examples.py View File

@@ -15,19 +15,42 @@ def chinese_word_segmentation():
print(cws.predict(text)) print(cws.predict(text))




def chinese_word_segmentation_test():
cws = CWS(device='cpu')
print(cws.test("../../test/data_for_tests/zh_sample.conllx"))


def pos_tagging(): def pos_tagging():
# 输入已分词序列 # 输入已分词序列
text = ['编者 按: 7月 12日 , 英国 航空 航天 系统 公司 公布 了 该 公司 研制 的 第一款 高科技 隐形 无人机 雷电之神 。']
text = [text[0].split()]
print(text)
text = [['编者', '按:', '7月', '12日', ',', '英国', '航空', '航天', '系统', '公司', '公布', '了', '该', '公司',
'研制', '的', '第一款', '高科技', '隐形', '无人机', '雷电之神', '。'],
['那么', '这', '款', '无人机', '到底', '有', '多', '厉害', '?']]
pos = POS(device='cpu') pos = POS(device='cpu')
print(pos.predict(text)) print(pos.predict(text))




def pos_tagging_test():
pos = POS(device='cpu')
print(pos.test("../../test/data_for_tests/zh_sample.conllx"))


def syntactic_parsing(): def syntactic_parsing():
text = [['编者', '按:', '7月', '12日', ',', '英国', '航空', '航天', '系统', '公司', '公布', '了', '该', '公司',
'研制', '的', '第一款', '高科技', '隐形', '无人机', '雷电之神', '。'],
['那么', '这', '款', '无人机', '到底', '有', '多', '厉害', '?']]
parser = Parser(device='cpu') parser = Parser(device='cpu')
print(parser.predict(text)) print(parser.predict(text))




def syntactic_parsing_test():
parser = Parser(device='cpu')
print(parser.test("../../test/data_for_tests/zh_sample.conllx"))


if __name__ == "__main__": if __name__ == "__main__":
pos_tagging()
# chinese_word_segmentation()
# chinese_word_segmentation_test()
# pos_tagging()
# pos_tagging_test()
syntactic_parsing()
# syntactic_parsing_test()

+ 8
- 1
fastNLP/api/processor.py View File

@@ -102,6 +102,7 @@ class PreAppendProcessor(Processor):
[data] + instance[field_name] [data] + instance[field_name]


""" """

def __init__(self, data, field_name, new_added_field_name=None): def __init__(self, data, field_name, new_added_field_name=None):
super(PreAppendProcessor, self).__init__(field_name, new_added_field_name) super(PreAppendProcessor, self).__init__(field_name, new_added_field_name)
self.data = data self.data = data
@@ -116,6 +117,7 @@ class SliceProcessor(Processor):
从某个field中只取部分内容。等价于instance[field_name][start:end:step] 从某个field中只取部分内容。等价于instance[field_name][start:end:step]


""" """

def __init__(self, start, end, step, field_name, new_added_field_name=None): def __init__(self, start, end, step, field_name, new_added_field_name=None):
super(SliceProcessor, self).__init__(field_name, new_added_field_name) super(SliceProcessor, self).__init__(field_name, new_added_field_name)
for o in (start, end, step): for o in (start, end, step):
@@ -132,6 +134,7 @@ class Num2TagProcessor(Processor):
将一句话中的数字转换为某个tag。 将一句话中的数字转换为某个tag。


""" """

def __init__(self, tag, field_name, new_added_field_name=None): def __init__(self, tag, field_name, new_added_field_name=None):
""" """


@@ -163,6 +166,7 @@ class IndexerProcessor(Processor):
给定一个vocabulary , 将指定field转换为index形式。指定field应该是一维的list,比如 给定一个vocabulary , 将指定field转换为index形式。指定field应该是一维的list,比如
['我', '是', xxx] ['我', '是', xxx]
""" """

def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False, is_input=True): def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False, is_input=True):


assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))
@@ -215,6 +219,7 @@ class SeqLenProcessor(Processor):
根据某个field新增一个sequence length的field。取该field的第一维 根据某个field新增一个sequence length的field。取该field的第一维


""" """

def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True): def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True):
super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) super(SeqLenProcessor, self).__init__(field_name, new_added_field_name)
self.is_input = is_input self.is_input = is_input
@@ -229,6 +234,7 @@ class SeqLenProcessor(Processor):


from fastNLP.core.utils import _build_args from fastNLP.core.utils import _build_args



class ModelProcessor(Processor): class ModelProcessor(Processor):
def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32): def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32):
""" """
@@ -292,6 +298,7 @@ class Index2WordProcessor(Processor):
将DataSet中某个为index的field根据vocab转换为str 将DataSet中某个为index的field根据vocab转换为str


""" """

def __init__(self, vocab, field_name, new_added_field_name): def __init__(self, vocab, field_name, new_added_field_name):
super(Index2WordProcessor, self).__init__(field_name, new_added_field_name) super(Index2WordProcessor, self).__init__(field_name, new_added_field_name)
self.vocab = vocab self.vocab = vocab
@@ -303,7 +310,6 @@ class Index2WordProcessor(Processor):




class SetTargetProcessor(Processor): class SetTargetProcessor(Processor):
# TODO; remove it.
def __init__(self, *fields, flag=True): def __init__(self, *fields, flag=True):
super(SetTargetProcessor, self).__init__(None, None) super(SetTargetProcessor, self).__init__(None, None)
self.fields = fields self.fields = fields
@@ -313,6 +319,7 @@ class SetTargetProcessor(Processor):
dataset.set_target(*self.fields, flag=self.flag) dataset.set_target(*self.fields, flag=self.flag)
return dataset return dataset



class SetInputProcessor(Processor): class SetInputProcessor(Processor):
def __init__(self, *fields, flag=True): def __init__(self, *fields, flag=True):
super(SetInputProcessor, self).__init__(None, None) super(SetInputProcessor, self).__init__(None, None)


+ 4
- 0
fastNLP/core/dataset.py View File

@@ -92,6 +92,10 @@ class DataSet(object):
data_set.add_field(name=field.name, fields=field.content[idx], padder=field.padder, data_set.add_field(name=field.name, fields=field.content[idx], padder=field.padder,
is_input=field.is_input, is_target=field.is_target) is_input=field.is_input, is_target=field.is_target)
return data_set return data_set
elif isinstance(idx, str):
if idx not in self:
raise KeyError("No such field called {} in DataSet.".format(idx))
return self.field_arrays[idx]
else: else:
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))




+ 6
- 0
fastNLP/io/base_loader.py View File

@@ -11,18 +11,24 @@ class BaseLoader(object):


@staticmethod @staticmethod
def load_lines(data_path): def load_lines(data_path):
"""按行读取,舍弃每行两侧空白字符,返回list of str
"""
with open(data_path, "r", encoding="utf=8") as f: with open(data_path, "r", encoding="utf=8") as f:
text = f.readlines() text = f.readlines()
return [line.strip() for line in text] return [line.strip() for line in text]


@classmethod @classmethod
def load(cls, data_path): def load(cls, data_path):
"""先按行读取,去除一行两侧空白,再提取每行的字符。返回list of list of str
"""
with open(data_path, "r", encoding="utf-8") as f: with open(data_path, "r", encoding="utf-8") as f:
text = f.readlines() text = f.readlines()
return [[word for word in sent.strip()] for sent in text] return [[word for word in sent.strip()] for sent in text]


@classmethod @classmethod
def load_with_cache(cls, data_path, cache_path): def load_with_cache(cls, data_path, cache_path):
"""缓存版的load
"""
if os.path.isfile(cache_path) and os.path.getmtime(data_path) < os.path.getmtime(cache_path): if os.path.isfile(cache_path) and os.path.getmtime(data_path) < os.path.getmtime(cache_path):
with open(cache_path, 'rb') as f: with open(cache_path, 'rb') as f:
return pickle.load(f) return pickle.load(f)


+ 1
- 8
fastNLP/io/config_io.py View File

@@ -11,7 +11,6 @@ class ConfigLoader(BaseLoader):
:param str data_path: path to the config :param str data_path: path to the config


""" """

def __init__(self, data_path=None): def __init__(self, data_path=None):
super(ConfigLoader, self).__init__() super(ConfigLoader, self).__init__()
if data_path is not None: if data_path is not None:
@@ -30,7 +29,7 @@ class ConfigLoader(BaseLoader):
Example:: Example::


test_args = ConfigSection() test_args = ConfigSection()
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})
ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args})


""" """
assert isinstance(sections, dict) assert isinstance(sections, dict)
@@ -202,8 +201,6 @@ class ConfigSaver(object):
continue continue


if '=' not in line: if '=' not in line:
# log = create_logger(__name__, './config_saver.log')
# log.error("can NOT load config file [%s]" % self.file_path)
raise RuntimeError("can NOT load config file {}".__format__(self.file_path)) raise RuntimeError("can NOT load config file {}".__format__(self.file_path))


key = line.split('=', maxsplit=1)[0].strip() key = line.split('=', maxsplit=1)[0].strip()
@@ -263,10 +260,6 @@ class ConfigSaver(object):
change_file = True change_file = True
break break
if section_file[k] != section[k]: if section_file[k] != section[k]:
# logger = create_logger(__name__, "./config_loader.log")
# logger.warning("section [%s] in config file [%s] has been changed" % (
# section_name, self.file_path
# ))
change_file = True change_file = True
break break
if not change_file: if not change_file:


+ 73
- 167
fastNLP/io/dataset_loader.py View File

@@ -126,8 +126,8 @@ class RawDataSetLoader(DataSetLoader):
DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata') DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata')




class POSDataSetLoader(DataSetLoader):
"""Dataset Loader for a POS Tag dataset.
class DummyPOSReader(DataSetLoader):
"""A simple reader for a dummy POS tagging dataset.


In these datasets, each line are divided by "\t". The first Col is the vocabulary and the second In these datasets, each line are divided by "\t". The first Col is the vocabulary and the second
Col is the label. Different sentence are divided by an empty line. Col is the label. Different sentence are divided by an empty line.
@@ -146,7 +146,7 @@ class POSDataSetLoader(DataSetLoader):
""" """


def __init__(self): def __init__(self):
super(POSDataSetLoader, self).__init__()
super(DummyPOSReader, self).__init__()


def load(self, data_path): def load(self, data_path):
""" """
@@ -194,16 +194,14 @@ class POSDataSetLoader(DataSetLoader):
return convert_seq2seq_dataset(data) return convert_seq2seq_dataset(data)




DataLoaderRegister.set_reader(POSDataSetLoader, 'read_pos')
DataLoaderRegister.set_reader(DummyPOSReader, 'read_pos')




class TokenizeDataSetLoader(DataSetLoader):
class DummyCWSReader(DataSetLoader):
"""Load pku dataset for Chinese word segmentation.
""" """
Data set loader for tokenization data sets
"""

def __init__(self): def __init__(self):
super(TokenizeDataSetLoader, self).__init__()
super(DummyCWSReader, self).__init__()


def load(self, data_path, max_seq_len=32): def load(self, data_path, max_seq_len=32):
"""Load pku dataset for Chinese word segmentation. """Load pku dataset for Chinese word segmentation.
@@ -256,11 +254,11 @@ class TokenizeDataSetLoader(DataSetLoader):
return convert_seq2seq_dataset(data) return convert_seq2seq_dataset(data)




class ClassDataSetLoader(DataSetLoader):
class DummyClassificationReader(DataSetLoader):
"""Loader for a dummy classification data set""" """Loader for a dummy classification data set"""


def __init__(self): def __init__(self):
super(ClassDataSetLoader, self).__init__()
super(DummyClassificationReader, self).__init__()


def load(self, data_path): def load(self, data_path):
assert os.path.exists(data_path) assert os.path.exists(data_path)
@@ -271,7 +269,7 @@ class ClassDataSetLoader(DataSetLoader):


@staticmethod @staticmethod
def parse(lines): def parse(lines):
"""
"""每行第一个token是标签,其余是字/词;由空格分隔。


:param lines: lines from dataset :param lines: lines from dataset
:return: list(list(list())): the three level of lists are words, sentence, and dataset :return: list(list(list())): the three level of lists are words, sentence, and dataset
@@ -327,16 +325,11 @@ class ConllLoader(DataSetLoader):
pass pass




class LMDataSetLoader(DataSetLoader):
"""Language Model Dataset Loader

This loader produces data for language model training in a supervised way.
That means it has X and Y.

class DummyLMReader(DataSetLoader):
"""A Dummy Language Model Dataset Reader
""" """

def __init__(self): def __init__(self):
super(LMDataSetLoader, self).__init__()
super(DummyLMReader, self).__init__()


def load(self, data_path): def load(self, data_path):
if not os.path.exists(data_path): if not os.path.exists(data_path):
@@ -364,19 +357,25 @@ class LMDataSetLoader(DataSetLoader):




class PeopleDailyCorpusLoader(DataSetLoader): class PeopleDailyCorpusLoader(DataSetLoader):
"""人民日报数据集
""" """
People Daily Corpus: Chinese word segmentation, POS tag, NER
"""

def __init__(self): def __init__(self):
super(PeopleDailyCorpusLoader, self).__init__() super(PeopleDailyCorpusLoader, self).__init__()
self.pos = True
self.ner = True


def load(self, data_path):
def load(self, data_path, pos=True, ner=True):
"""

:param str data_path: 数据路径
:param bool pos: 是否使用词性标签
:param bool ner: 是否使用命名实体标签
:return: a DataSet object
"""
self.pos, self.ner = pos, ner
with open(data_path, "r", encoding="utf-8") as f: with open(data_path, "r", encoding="utf-8") as f:
sents = f.readlines() sents = f.readlines()

pos_tag_examples = []
ner_examples = []
examples = []
for sent in sents: for sent in sents:
if len(sent) <= 2: if len(sent) <= 2:
continue continue
@@ -410,40 +409,44 @@ class PeopleDailyCorpusLoader(DataSetLoader):
sent_ner.append(ner_tag) sent_ner.append(ner_tag)
sent_pos_tag.append(pos) sent_pos_tag.append(pos)
sent_words.append(token) sent_words.append(token)
pos_tag_examples.append([sent_words, sent_pos_tag])
ner_examples.append([sent_words, sent_ner])
# List[List[List[str], List[str]]]
# ner_examples not used
return self.convert(pos_tag_examples)
example = [sent_words]
if self.pos is True:
example.append(sent_pos_tag)
if self.ner is True:
example.append(sent_ner)
examples.append(example)
return self.convert(examples)


def convert(self, data): def convert(self, data):
data_set = DataSet() data_set = DataSet()
for item in data: for item in data:
sent_words, sent_pos_tag = item[0], item[1]
data_set.append(Instance(words=sent_words, tags=sent_pos_tag))
data_set.apply(lambda ins: len(ins), new_field_name="seq_len")
data_set.set_target("tags")
data_set.set_input("sent_words")
data_set.set_input("seq_len")
sent_words = item[0]
if self.pos is True and self.ner is True:
instance = Instance(words=sent_words, pos_tags=item[1], ner=item[2])
elif self.pos is True:
instance = Instance(words=sent_words, pos_tags=item[1])
elif self.ner is True:
instance = Instance(words=sent_words, ner=item[1])
else:
instance = Instance(words=sent_words)
data_set.append(instance)
data_set.apply(lambda ins: len(ins["words"]), new_field_name="seq_len")
return data_set return data_set




class Conll2003Loader(DataSetLoader): class Conll2003Loader(DataSetLoader):
"""Self-defined loader of conll2003 dataset
"""Loader for conll2003 dataset
More information about the given dataset cound be found on More information about the given dataset cound be found on
https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data
""" """

def __init__(self): def __init__(self):
super(Conll2003Loader, self).__init__() super(Conll2003Loader, self).__init__()


def load(self, dataset_path): def load(self, dataset_path):
with open(dataset_path, "r", encoding="utf-8") as f: with open(dataset_path, "r", encoding="utf-8") as f:
lines = f.readlines() lines = f.readlines()

##Parse the dataset line by line
parsed_data = [] parsed_data = []
sentence = [] sentence = []
tokens = [] tokens = []
@@ -470,21 +473,20 @@ class Conll2003Loader(DataSetLoader):
lambda labels: labels[1], sample[1])) lambda labels: labels[1], sample[1]))
label2_list = list(map( label2_list = list(map(
lambda labels: labels[2], sample[1])) lambda labels: labels[2], sample[1]))
dataset.append(Instance(token_list=sample[0],
label0_list=label0_list,
label1_list=label1_list,
label2_list=label2_list))
dataset.append(Instance(tokens=sample[0],
pos=label0_list,
chucks=label1_list,
ner=label2_list))


return dataset return dataset




class SNLIDataSetLoader(DataSetLoader):
class SNLIDataSetReader(DataSetLoader):
"""A data set loader for SNLI data set. """A data set loader for SNLI data set.


""" """

def __init__(self): def __init__(self):
super(SNLIDataSetLoader, self).__init__()
super(SNLIDataSetReader, self).__init__()


def load(self, path_list): def load(self, path_list):
""" """
@@ -553,6 +555,8 @@ class ConllCWSReader(object):
""" """
返回的DataSet只包含raw_sentence这个field,内容为str。 返回的DataSet只包含raw_sentence这个field,内容为str。
假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即
::

1 编者按 编者按 NN O 11 nmod:topic 1 编者按 编者按 NN O 11 nmod:topic
2 : : PU O 11 punct 2 : : PU O 11 punct
3 7月 7月 NT DATE 4 compound:nn 3 7月 7月 NT DATE 4 compound:nn
@@ -564,6 +568,7 @@ class ConllCWSReader(object):
3 飞行 飞行 NN O 8 nsubj 3 飞行 飞行 NN O 8 nsubj
4 从 从 P O 5 case 4 从 从 P O 5 case
5 外型 外型 NN O 8 nmod:prep 5 外型 外型 NN O 8 nmod:prep

""" """
datalist = [] datalist = []
with open(path, 'r', encoding='utf-8') as f: with open(path, 'r', encoding='utf-8') as f:
@@ -575,7 +580,7 @@ class ConllCWSReader(object):
elif line.startswith('#'): elif line.startswith('#'):
continue continue
else: else:
sample.append(line.split('\t'))
sample.append(line.strip().split())
if len(sample) > 0: if len(sample) > 0:
datalist.append(sample) datalist.append(sample)


@@ -592,7 +597,6 @@ class ConllCWSReader(object):
sents = [line] sents = [line]
for raw_sentence in sents: for raw_sentence in sents:
ds.append(Instance(raw_sentence=raw_sentence)) ds.append(Instance(raw_sentence=raw_sentence))

return ds return ds


def get_char_lst(self, sample): def get_char_lst(self, sample):
@@ -607,70 +611,22 @@ class ConllCWSReader(object):
return text return text




class POSCWSReader(DataSetLoader):
"""
支持读取以下的情况, 即每一行是一个词, 用空行作为两句话的界限.
迈 N
向 N
充 N
...
泽 I-PER
民 I-PER

( N
一 N
九 N
...


:param filepath:
:return:
"""

def __init__(self, in_word_splitter=None):
super().__init__()
self.in_word_splitter = in_word_splitter

def load(self, filepath, in_word_splitter=None, cut_long_sent=False):
if in_word_splitter is None:
in_word_splitter = self.in_word_splitter
dataset = DataSet()
with open(filepath, 'r') as f:
words = []
for line in f:
line = line.strip()
if len(line) == 0: # new line
if len(words) == 0: # 不能接受空行
continue
line = ' '.join(words)
if cut_long_sent:
sents = cut_long_sentence(line)
else:
sents = [line]
for sent in sents:
instance = Instance(raw_sentence=sent)
dataset.append(instance)
words = []
else:
line = line.split()[0]
if in_word_splitter is None:
words.append(line)
else:
words.append(line.split(in_word_splitter)[0])
return dataset


class NaiveCWSReader(DataSetLoader): class NaiveCWSReader(DataSetLoader):
""" """
这个reader假设了分词数据集为以下形式, 即已经用空格分割好内容了 这个reader假设了分词数据集为以下形式, 即已经用空格分割好内容了
例如::

这是 fastNLP , 一个 非常 good 的 包 . 这是 fastNLP , 一个 非常 good 的 包 .

或者,即每个part后面还有一个pos tag 或者,即每个part后面还有一个pos tag
例如::

也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY 也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY

""" """


def __init__(self, in_word_splitter=None): def __init__(self, in_word_splitter=None):
super().__init__()

super(NaiveCWSReader, self).__init__()
self.in_word_splitter = in_word_splitter self.in_word_splitter = in_word_splitter


def load(self, filepath, in_word_splitter=None, cut_long_sent=False): def load(self, filepath, in_word_splitter=None, cut_long_sent=False):
@@ -680,8 +636,10 @@ class NaiveCWSReader(DataSetLoader):
也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY 也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY
如果splitter不为None则认为是第二种情况, 且我们会按splitter分割"也/D", 然后取第一部分. 例如"也/D".split('/')[0] 如果splitter不为None则认为是第二种情况, 且我们会按splitter分割"也/D", 然后取第一部分. 例如"也/D".split('/')[0]

:param filepath: :param filepath:
:param in_word_splitter: :param in_word_splitter:
:param cut_long_sent:
:return: :return:
""" """
if in_word_splitter == None: if in_word_splitter == None:
@@ -740,7 +698,9 @@ def cut_long_sentence(sent, max_sample_length=200):




class ZhConllPOSReader(object): class ZhConllPOSReader(object):
# 中文colln格式reader
"""读取中文Conll格式。返回“字级别”的标签,使用BMES记号扩展原来的词级别标签。

"""
def __init__(self): def __init__(self):
pass pass


@@ -750,6 +710,8 @@ class ZhConllPOSReader(object):
words:list of str, words:list of str,
tag: list of str, 被加入了BMES tag, 比如原来的序列为['VP', 'NN', 'NN', ..],会被认为是["S-VP", "B-NN", "M-NN",..] tag: list of str, 被加入了BMES tag, 比如原来的序列为['VP', 'NN', 'NN', ..],会被认为是["S-VP", "B-NN", "M-NN",..]
假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即
::

1 编者按 编者按 NN O 11 nmod:topic 1 编者按 编者按 NN O 11 nmod:topic
2 : : PU O 11 punct 2 : : PU O 11 punct
3 7月 7月 NT DATE 4 compound:nn 3 7月 7月 NT DATE 4 compound:nn
@@ -761,6 +723,7 @@ class ZhConllPOSReader(object):
3 飞行 飞行 NN O 8 nsubj 3 飞行 飞行 NN O 8 nsubj
4 从 从 P O 5 case 4 从 从 P O 5 case
5 外型 外型 NN O 8 nmod:prep 5 外型 外型 NN O 8 nmod:prep

""" """
datalist = [] datalist = []
with open(path, 'r', encoding='utf-8') as f: with open(path, 'r', encoding='utf-8') as f:
@@ -815,67 +778,10 @@ class ZhConllPOSReader(object):
return text, pos_tags return text, pos_tags




class ConllPOSReader(object):
# 返回的Dataset包含words(list of list, 里层的list是character), tag两个field(list of str, str是标有BIO的tag)。
def __init__(self):
pass

def load(self, path):
datalist = []
with open(path, 'r', encoding='utf-8') as f:
sample = []
for line in f:
if line.startswith('\n'):
datalist.append(sample)
sample = []
elif line.startswith('#'):
continue
else:
sample.append(line.split('\t'))
if len(sample) > 0:
datalist.append(sample)

ds = DataSet()
for sample in datalist:
# print(sample)
res = self.get_one(sample)
if res is None:
continue
char_seq = []
pos_seq = []
for word, tag in zip(res[0], res[1]):
if len(word) == 1:
char_seq.append(word)
pos_seq.append('S-{}'.format(tag))
elif len(word) > 1:
pos_seq.append('B-{}'.format(tag))
for _ in range(len(word) - 2):
pos_seq.append('M-{}'.format(tag))
pos_seq.append('E-{}'.format(tag))
char_seq.extend(list(word))
else:
raise ValueError("Zero length of word detected.")

ds.append(Instance(words=char_seq,
tag=pos_seq))
return ds

def get_one(self, sample):
if len(sample) == 0:
return None
text = []
pos_tags = []
for w in sample:
t1, t2, t3, t4 = w[1], w[3], w[6], w[7]
if t3 == '_':
return None
text.append(t1)
pos_tags.append(t2)
return text, pos_tags



class ConllxDataLoader(object): class ConllxDataLoader(object):
"""返回“词级别”的标签信息,包括词、词性、(句法)头依赖、(句法)边标签。跟``ZhConllPOSReader``完全不同。

"""
def load(self, path): def load(self, path):
""" """




+ 21
- 1
fastNLP/models/bert.py View File

@@ -1,3 +1,7 @@
"""
bert.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0.

"""
import copy import copy
import json import json
import math import math
@@ -220,7 +224,23 @@ class BertPooler(nn.Module):




class BertModel(nn.Module): class BertModel(nn.Module):
"""BERT model ("Bidirectional Embedding Representations from a Transformer").
"""Bidirectional Embedding Representations from Transformers.

If you want to use pre-trained weights, please download from the following sources provided by pytorch-pretrained-BERT.
sources::

'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",


Construct a BERT model with pre-trained weights::

model = BertModel.from_pretrained("path/to/weights/directory")


""" """




+ 10
- 8
fastNLP/models/biaffine_parser.py View File

@@ -1,18 +1,20 @@
import copy
from collections import defaultdict

import numpy as np import numpy as np
import torch import torch
from collections import defaultdict
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from fastNLP.modules.utils import initial_parameter
from fastNLP.modules.encoder.variational_rnn import VarLSTM
from fastNLP.modules.encoder.transformer import TransformerEncoder
from fastNLP.modules.dropout import TimestepDropout
from fastNLP.models.base_model import BaseModel
from fastNLP.modules.utils import seq_mask

from fastNLP.core.losses import LossFunc from fastNLP.core.losses import LossFunc
from fastNLP.core.metrics import MetricBase from fastNLP.core.metrics import MetricBase
from fastNLP.core.utils import seq_lens_to_masks from fastNLP.core.utils import seq_lens_to_masks
from fastNLP.models.base_model import BaseModel
from fastNLP.modules.dropout import TimestepDropout
from fastNLP.modules.encoder.transformer import TransformerEncoder
from fastNLP.modules.encoder.variational_rnn import VarLSTM
from fastNLP.modules.utils import initial_parameter
from fastNLP.modules.utils import seq_mask



def mst(scores): def mst(scores):
""" """


+ 1
- 1
reproduction/Biaffine_parser/cfg.cfg View File

@@ -1,5 +1,5 @@
[train] [train]
n_epochs = 1
n_epochs = 20
batch_size = 32 batch_size = 32
use_cuda = true use_cuda = true
use_tqdm=true use_tqdm=true


+ 1
- 1
reproduction/LSTM+self_attention_sentiment_analysis/main.py View File

@@ -4,7 +4,7 @@ from fastNLP.core.trainer import ClassificationTrainer
from fastNLP.core.utils import ClassPreprocess as Preprocess from fastNLP.core.utils import ClassPreprocess as Preprocess
from fastNLP.io.config_io import ConfigLoader from fastNLP.io.config_io import ConfigLoader
from fastNLP.io.config_io import ConfigSection from fastNLP.io.config_io import ConfigSection
from fastNLP.io.dataset_loader import ClassDataSetLoader as Dataset_loader
from fastNLP.io.dataset_loader import DummyClassificationReader as Dataset_loader
from fastNLP.models.base_model import BaseModel from fastNLP.models.base_model import BaseModel
from fastNLP.modules.aggregator.self_attention import SelfAttention from fastNLP.modules.aggregator.self_attention import SelfAttention
from fastNLP.modules.decoder.MLP import MLP from fastNLP.modules.decoder.MLP import MLP


+ 48
- 2
test/api/test_processor.py View File

@@ -1,9 +1,12 @@
import random import random
import unittest import unittest


from fastNLP import Vocabulary
import numpy as np

from fastNLP import Vocabulary, Instance
from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor, PreAppendProcessor, SliceProcessor, Num2TagProcessor, \ from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor, PreAppendProcessor, SliceProcessor, Num2TagProcessor, \
IndexerProcessor, VocabProcessor, SeqLenProcessor
IndexerProcessor, VocabProcessor, SeqLenProcessor, ModelProcessor, Index2WordProcessor, SetTargetProcessor, \
SetInputProcessor, VocabIndexerProcessor
from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import DataSet




@@ -53,3 +56,46 @@ class TestProcessor(unittest.TestCase):
ds = proc(ds) ds = proc(ds)
for data in ds.field_arrays["len"].content: for data in ds.field_arrays["len"].content:
self.assertEqual(data, 30) self.assertEqual(data, 30)

def test_ModelProcessor(self):
from fastNLP.models.cnn_text_classification import CNNText
model = CNNText(100, 100, 5)
ins_list = []
for _ in range(64):
seq_len = np.random.randint(5, 30)
ins_list.append(Instance(word_seq=[np.random.randint(0, 100) for _ in range(seq_len)], seq_lens=seq_len))
data_set = DataSet(ins_list)
data_set.set_input("word_seq", "seq_lens")
proc = ModelProcessor(model)
data_set = proc(data_set)
self.assertTrue("pred" in data_set)

def test_Index2WordProcessor(self):
vocab = Vocabulary()
vocab.add_word_lst(["a", "b", "c", "d", "e"])
proc = Index2WordProcessor(vocab, "tag_id", "tag")
data_set = DataSet([Instance(tag_id=[np.random.randint(0, 7) for _ in range(32)])])
data_set = proc(data_set)
self.assertTrue("tag" in data_set)

def test_SetTargetProcessor(self):
proc = SetTargetProcessor("a", "b", "c")
data_set = DataSet({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]})
data_set = proc(data_set)
self.assertTrue(data_set["a"].is_target)
self.assertTrue(data_set["b"].is_target)
self.assertTrue(data_set["c"].is_target)

def test_SetInputProcessor(self):
proc = SetInputProcessor("a", "b", "c")
data_set = DataSet({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]})
data_set = proc(data_set)
self.assertTrue(data_set["a"].is_input)
self.assertTrue(data_set["b"].is_input)
self.assertTrue(data_set["c"].is_input)

def test_VocabIndexerProcessor(self):
proc = VocabIndexerProcessor("word_seq", "word_ids")
data_set = DataSet([Instance(word_seq=["a", "b", "c", "d", "e"])])
data_set = proc(data_set)
self.assertTrue("word_ids" in data_set)

+ 4
- 1
test/core/test_batch.py View File

@@ -138,6 +138,7 @@ class TestCase1(unittest.TestCase):
for batch_x, batch_y in batch: for batch_x, batch_y in batch:
time.sleep(pause_seconds) time.sleep(pause_seconds)


"""
def test_multi_workers_batch(self): def test_multi_workers_batch(self):
batch_size = 32 batch_size = 32
pause_seconds = 0.01 pause_seconds = 0.01
@@ -154,7 +155,8 @@ class TestCase1(unittest.TestCase):
end1 = time.time() end1 = time.time()
for batch_x, batch_y in batch: for batch_x, batch_y in batch:
time.sleep(pause_seconds) time.sleep(pause_seconds)

"""
"""
def test_pin_memory(self): def test_pin_memory(self):
batch_size = 32 batch_size = 32
pause_seconds = 0.01 pause_seconds = 0.01
@@ -172,3 +174,4 @@ class TestCase1(unittest.TestCase):
# 这里发生OOM # 这里发生OOM
# for batch_x, batch_y in batch: # for batch_x, batch_y in batch:
# time.sleep(pause_seconds) # time.sleep(pause_seconds)
"""

+ 2
- 1
test/core/test_trainer.py View File

@@ -237,6 +237,7 @@ class TrainerTestGround(unittest.TestCase):
use_tqdm=False, use_tqdm=False,
print_every=2) print_every=2)


"""
def test_trainer_multiprocess(self): def test_trainer_multiprocess(self):
dataset = prepare_fake_dataset2('x1', 'x2') dataset = prepare_fake_dataset2('x1', 'x2')
dataset.set_input('x1', 'x2', 'y', flag=True) dataset.set_input('x1', 'x2', 'y', flag=True)
@@ -264,4 +265,4 @@ class TrainerTestGround(unittest.TestCase):
timeout=0, timeout=0,
) )
trainer.train() trainer.train()
"""

+ 0
- 3370
test/data_for_tests/charlm.txt
File diff suppressed because it is too large
View File


+ 0
- 2
test/data_for_tests/people_infer.txt View File

@@ -1,2 +0,0 @@
迈向充满希望的新世纪——一九九八年新年讲话
(附图片1张)

+ 100
- 0
test/data_for_tests/zh_sample.conllx View File

@@ -0,0 +1,100 @@
1 上海 _ NR NR _ 3 nsubj _ _
2 积极 _ AD AD _ 3 advmod _ _
3 准备 _ VV VV _ 0 root _ _
4 迎接 _ VV VV _ 3 ccomp _ _
5 欧元 _ NN NN _ 6 nn _ _
6 诞生 _ NN NN _ 4 dobj _ _

1 新华社 _ NR NR _ 7 dep _ _
2 上海 _ NR NR _ 7 dep _ _
3 十二月 _ NT NT _ 7 dep _ _
4 三十日 _ NT NT _ 7 dep _ _
5 电 _ NN NN _ 7 dep _ _
6 ( _ PU PU _ 7 punct _ _
7 记者 _ NN NN _ 0 root _ _
8 潘清 _ NR NR _ 7 dep _ _
9 ) _ PU PU _ 7 punct _ _

1 即将 _ AD AD _ 2 advmod _ _
2 诞生 _ VV VV _ 4 rcmod _ _
3 的 _ DEC DEC _ 2 cpm _ _
4 欧元 _ NN NN _ 6 nsubj _ _
5 , _ PU PU _ 6 punct _ _
6 引起 _ VV VV _ 0 root _ _
7 了 _ AS AS _ 6 asp _ _
8 上海 _ NR NR _ 14 nn _ _
9 这 _ DT DT _ 14 det _ _
10 个 _ M M _ 9 clf _ _
11 中国 _ NR NR _ 13 nn _ _
12 金融 _ NN NN _ 13 nn _ _
13 中心 _ NN NN _ 14 nn _ _
14 城市 _ NN NN _ 16 assmod _ _
15 的 _ DEG DEG _ 14 assm _ _
16 关注 _ NN NN _ 6 dobj _ _
17 。 _ PU PU _ 6 punct _ _

1 上海 _ NR NR _ 2 nn _ _
2 银行界 _ NN NN _ 4 nsubj _ _
3 纷纷 _ AD AD _ 4 advmod _ _
4 推出 _ VV VV _ 0 root _ _
5 了 _ AS AS _ 4 asp _ _
6 与 _ P P _ 8 prep _ _
7 之 _ PN PN _ 6 pobj _ _
8 相关 _ VA VA _ 15 rcmod _ _
9 的 _ DEC DEC _ 8 cpm _ _
10 外汇 _ NN NN _ 15 nn _ _
11 业务 _ NN NN _ 15 nn _ _
12 品种 _ NN NN _ 15 conj _ _
13 和 _ CC CC _ 15 cc _ _
14 服务 _ NN NN _ 15 nn _ _
15 举措 _ NN NN _ 4 dobj _ _
16 , _ PU PU _ 4 punct _ _
17 积极 _ AD AD _ 18 advmod _ _
18 准备 _ VV VV _ 4 dep _ _
19 启动 _ VV VV _ 18 ccomp _ _
20 欧元 _ NN NN _ 21 nn _ _
21 业务 _ NN NN _ 19 dobj _ _
22 。 _ PU PU _ 4 punct _ _

1 一些 _ CD CD _ 8 nummod _ _
2 热衷于 _ VV VV _ 8 rcmod _ _
3 个人 _ NN NN _ 5 nn _ _
4 外汇 _ NN NN _ 5 nn _ _
5 交易 _ NN NN _ 2 dobj _ _
6 的 _ DEC DEC _ 2 cpm _ _
7 上海 _ NR NR _ 8 nn _ _
8 市民 _ NN NN _ 13 nsubj _ _
9 , _ PU PU _ 13 punct _ _
10 也 _ AD AD _ 13 advmod _ _
11 对 _ P P _ 13 prep _ _
12 欧元 _ NN NN _ 11 pobj _ _
13 表示 _ VV VV _ 0 root _ _
14 出 _ VV VV _ 13 rcomp _ _
15 极 _ AD AD _ 16 advmod _ _
16 大 _ VA VA _ 18 rcmod _ _
17 的 _ DEC DEC _ 16 cpm _ _
18 兴趣 _ NN NN _ 13 dobj _ _
19 。 _ PU PU _ 13 punct _ _

1 继 _ P P _ 38 prep _ _
2 上海 _ NR NR _ 6 nn _ _
3 大众 _ NR NR _ 6 nn _ _
4 汽车 _ NN NN _ 6 nn _ _
5 有限 _ JJ JJ _ 6 amod _ _
6 公司 _ NN NN _ 13 nsubj _ _
7 十八日 _ NT NT _ 13 tmod _ _
8 在 _ P P _ 13 prep _ _
9 中国 _ NR NR _ 10 nn _ _
10 银行 _ NN NN _ 12 nn _ _
11 上海 _ NR NR _ 12 nn _ _
12 分行 _ NN NN _ 8 pobj _ _
13 开立 _ VV VV _ 19 lccomp _ _
14 上海 _ NR NR _ 16 dep _ _
15 第一 _ OD OD _ 16 ordmod _ _
16 个 _ M M _ 18 clf _ _
17 欧元 _ NN NN _ 18 nn _ _
18 帐户 _ NN NN _ 13 dobj _ _
19 后 _ LC LC _ 1 plmod _ _
20 , _ PU PU _ 38 punct _ _
21 工商 _ NN NN _ 28 nn _ _
22 银行 _ NN NN _ 28 conj _ _

+ 16
- 13
test/io/test_dataset_loader.py View File

@@ -1,24 +1,27 @@
import unittest import unittest


from fastNLP.io.dataset_loader import Conll2003Loader
from fastNLP.io.dataset_loader import Conll2003Loader, PeopleDailyCorpusLoader, ConllCWSReader, \
ZhConllPOSReader, ConllxDataLoader




class TestDatasetLoader(unittest.TestCase): class TestDatasetLoader(unittest.TestCase):


def test_case_1(self):
'''
def test_Conll2003Loader(self):
"""
Test the the loader of Conll2003 dataset Test the the loader of Conll2003 dataset
'''

"""
dataset_path = "test/data_for_tests/conll_2003_example.txt" dataset_path = "test/data_for_tests/conll_2003_example.txt"
loader = Conll2003Loader() loader = Conll2003Loader()
dataset_2003 = loader.load(dataset_path) dataset_2003 = loader.load(dataset_path)


for item in dataset_2003:
len0 = len(item["label0_list"])
len1 = len(item["label1_list"])
len2 = len(item["label2_list"])
lentoken = len(item["token_list"])
self.assertNotEqual(len0, 0)
self.assertEqual(len0, len1)
self.assertEqual(len1, len2)
def test_PeopleDailyCorpusLoader(self):
data_set = PeopleDailyCorpusLoader().load("test/data_for_tests/people_daily_raw.txt")

def test_ConllCWSReader(self):
dataset = ConllCWSReader().load("test/data_for_tests/conll_example.txt")

def test_ZhConllPOSReader(self):
dataset = ZhConllPOSReader().load("test/data_for_tests/zh_sample.conllx")

def test_ConllxDataLoader(self):
dataset = ConllxDataLoader().load("test/data_for_tests/zh_sample.conllx")

Loading…
Cancel
Save