@@ -48,8 +48,10 @@ For example:
## Resources
- [Documentation](https://fastnlp.readthedocs.io/en/latest/)
- [Tutorials](https://github.com/fastnlp/fastNLP/tutorials)
- [Source Code](https://github.com/fastnlp/fastNLP)
## Installation
Run the following commands to install the fastNLP package.
```shell
@@ -70,7 +72,7 @@ pip install fastNLP
</tr>
<tr>
<td><b> fastNLP.core </b></td>
<td> data representation & train/test procedure </td>
</tr>
<tr>
<td><b> fastNLP.models </b></td>
@@ -0,0 +1,43 @@
# fastNLP High-level API

### Environment and Configuration
1. Operating system: Linux/Ubuntu (recommended)
2. Programming language: Python >= 3.6
3. Python package dependencies (see the environment-check sketch below)
   - **torch==1.0**
   - numpy>=1.14.2
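A quick way to confirm the dependencies above are in place is to print the installed versions. This is only an illustrative sketch; the version checks mirror the requirements listed above and are not part of fastNLP itself:
```python
# Minimal environment check (illustrative): verify torch and numpy match the listed requirements.
import numpy
import torch

print(torch.__version__)   # expected to start with "1.0"
print(numpy.__version__)   # expected to be at least 1.14.2
```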
### Chinese Word Segmentation
```python | |||||
text = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||||
'那么这款无人机到底有多厉害?'] | |||||
from fastNLP.api import CWS | |||||
cws = CWS(device='cpu') | |||||
print(cws.predict(text)) | |||||
# ['编者 按 : 7月 12日 , 英国 航空 航天 系统 公司 公布 了 该 公司 研制 的 第一 款 高 科技 隐形 无人 机雷电 之 神 。', '这 款 飞行 从 外型 上 来 看 酷似 电影 中 的 太空 飞行器 , 据 英国 方面 介绍 , 可以 实现 洲际 远程 打击 。', '那么 这 款 无人 机 到底 有 多 厉害 ?'] | |||||
``` | |||||
### Chinese Word Segmentation + POS Tagging
```python | |||||
text = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||||
'那么这款无人机到底有多厉害?'] | |||||
from fastNLP.api import POS | |||||
pos = POS(device='cpu') | |||||
print(pos.predict(text)) | |||||
# [['编者/NN', '按/P', ':/PU', '7月/NT', '12日/NR', ',/PU', '英国/NR', '航空/NN', '航天/NN', '系统/NN', '公司/NN', '公布/VV', '了/AS', '该/DT', '公司/NN', '研制/VV', '的/DEC', '第一/OD', '款高/NN', '科技/NN', '隐形/NN', '无/VE', '人机/NN', '雷电/NN', '之/DEG', '神/NN', '。/PU'], ['这/DT', '款/NN', '飞行/VV', '从/P', '外型/NN', '上/LC', '来/MSP', '看/VV', '酷似/VV', '电影/NN', '中/LC', '的/DEG', '太空/NN', '飞行器/NN', ',/PU', '据/P', '英国/NR', '方面/NN', '介绍/VV', ',/PU', '可以/VV', '实现/VV', '洲际/NN', '远程/NN', '打击/NN', '。/PU'], ['那么/AD', '这/DT', '款/NN', '无/VE', '人机/NN', '到底/AD', '有/VE', '多/CD', '厉害/NN', '?/PU']] | |||||
``` | |||||
### Chinese Word Segmentation + POS Tagging + Dependency Parsing
```python | |||||
text = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||||
'那么这款无人机到底有多厉害?'] | |||||
from fastNLP.api import Parser | |||||
parser = Parser(device='cpu') | |||||
print(parser.predict(text)) | |||||
# [['12/nsubj', '12/prep', '2/punct', '5/nn', '2/pobj', '12/punct', '11/nn', '11/nn', '11/nn', '11/nn', '2/pobj', '0/root', '12/asp', '15/det', '16/nsubj', '21/rcmod', '16/cpm', '21/nummod', '21/nn', '21/nn', '22/top', '12/ccomp', '24/nn', '26/assmod', '24/assm', '22/dobj', '12/punct'], ['2/det', '8/xsubj', '8/mmod', '8/prep', '6/lobj', '4/plmod', '8/prtmod', '0/root', '8/ccomp', '11/lobj', '14/assmod', '11/assm', '14/nn', '9/dobj', '8/punct', '22/prep', '18/nn', '19/nsubj', '16/pccomp', '22/punct', '22/mmod', '8/dep', '25/nn', '25/nn', '22/dobj', '8/punct'], ['4/advmod', '3/det', '4/nsubj', '0/root', '4/dobj', '7/advmod', '4/conj', '9/nummod', '7/dobj', '4/punct']] | |||||
``` | |||||
See `examples.py` for the complete examples.
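Besides `predict()`, each of the interfaces above also provides a `test()` method that evaluates the pretrained model on a CoNLL-style file. A minimal sketch (the file paths below are placeholders, not files shipped with fastNLP):
```python
from fastNLP.api import CWS, POS

cws = CWS(device='cpu')
f1, pre, rec = cws.test('path/to/cws_test.conll')   # word-segmentation F1, precision, recall

pos = POS(device='cpu')
print(pos.test('path/to/pos_test.conll'))           # e.g. {'F1': ..., 'precision': ..., 'recall': ...}
```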
@@ -0,0 +1 @@
from .api import CWS, POS, Parser
@@ -7,28 +7,28 @@ import os | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.api.model_zoo import load_url | |||||
from fastNLP.api.utils import load_url | |||||
from fastNLP.api.processor import ModelProcessor | from fastNLP.api.processor import ModelProcessor | ||||
from reproduction.chinese_word_segment.cws_io.cws_reader import ConlluCWSReader | |||||
from reproduction.pos_tag_model.pos_io.pos_reader import ConlluPOSReader | |||||
from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader | |||||
from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader | |||||
from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag | from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag | ||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.core.sampler import SequentialSampler | |||||
from fastNLP.core.batch import Batch | |||||
from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1 | |||||
from fastNLP.api.pipeline import Pipeline | from fastNLP.api.pipeline import Pipeline | ||||
from fastNLP.core.metrics import SeqLabelEvaluator2 | |||||
from fastNLP.core.tester import Tester | |||||
from fastNLP.core.metrics import SpanFPreRecMetric | |||||
from fastNLP.api.processor import IndexerProcessor | |||||
# TODO add pretrain urls | # TODO add pretrain urls | ||||
model_urls = { | model_urls = { | ||||
"cws": "http://123.206.98.91:8888/download/cws_crf_1_11-457fc899.pkl", | |||||
"pos": "http://123.206.98.91:8888/download/pos_tag_model_20190108-f3c60ee5.pkl", | |||||
"parser": "http://123.206.98.91:8888/download/biaffine_parser-3a2f052c.pkl" | |||||
} | } | ||||
class API: | class API: | ||||
def __init__(self): | def __init__(self): | ||||
self.pipeline = None | self.pipeline = None | ||||
self._dict = None | |||||
def predict(self, *args, **kwargs): | def predict(self, *args, **kwargs): | ||||
raise NotImplementedError | raise NotImplementedError | ||||
@@ -38,8 +38,8 @@ class API: | |||||
_dict = torch.load(path, map_location='cpu') | _dict = torch.load(path, map_location='cpu') | ||||
else: | else: | ||||
_dict = load_url(path, map_location='cpu') | _dict = load_url(path, map_location='cpu') | ||||
self.pipeline = _dict['pipeline'] | |||||
self._dict = _dict | self._dict = _dict | ||||
self.pipeline = _dict['pipeline'] | |||||
for processor in self.pipeline.pipeline: | for processor in self.pipeline.pipeline: | ||||
if isinstance(processor, ModelProcessor): | if isinstance(processor, ModelProcessor): | ||||
processor.set_model_device(device) | processor.set_model_device(device) | ||||
@@ -48,6 +48,9 @@ class API: | |||||
class POS(API): | class POS(API): | ||||
"""FastNLP API for Part-Of-Speech tagging. | """FastNLP API for Part-Of-Speech tagging. | ||||
:param str model_path: the path to the model. | |||||
:param str device: device name such as "cpu" or "cuda:0". Use the same notation as PyTorch. | |||||
""" | """ | ||||
def __init__(self, model_path=None, device='cpu'): | def __init__(self, model_path=None, device='cpu'): | ||||
@@ -63,7 +66,7 @@ class POS(API): | |||||
:param content: str or List[str]. Each string is a raw sentence.
:return answer: List[List[str]]. Each string is a "word/tag" pair.
""" | """ | ||||
if not hasattr(self, 'pipeline'): | |||||
if not hasattr(self, "pipeline"): | |||||
raise ValueError("You have to load model first.") | raise ValueError("You have to load model first.") | ||||
sentence_list = [] | sentence_list = [] | ||||
@@ -75,59 +78,72 @@ class POS(API):
# 2. build the DataSet
dataset = DataSet()
dataset.add_field("words", sentence_list)
# 3. run the pipeline
self.pipeline(dataset)
output = dataset['word_pos_output'].content | |||||
def decode_tags(ins): | |||||
pred_tags = ins["tag"] | |||||
chars = ins["words"] | |||||
words = [] | |||||
start_idx = 0 | |||||
for idx, tag in enumerate(pred_tags): | |||||
if tag[0] == "S": | |||||
words.append(chars[start_idx:idx + 1] + "/" + tag[2:]) | |||||
start_idx = idx + 1 | |||||
elif tag[0] == "E": | |||||
words.append("".join(chars[start_idx:idx + 1]) + "/" + tag[2:]) | |||||
start_idx = idx + 1 | |||||
return words | |||||
dataset.apply(decode_tags, new_field_name="tag_output") | |||||
output = dataset.field_arrays["tag_output"].content | |||||
if isinstance(content, str): | if isinstance(content, str): | ||||
return output[0] | return output[0] | ||||
elif isinstance(content, list): | elif isinstance(content, list): | ||||
return output | return output | ||||
def test(self, filepath): | |||||
tag_proc = self._dict['tag_indexer'] | |||||
model = self.pipeline.pipeline[2].model | |||||
pipeline = self.pipeline.pipeline[0:2] | |||||
pipeline.append(tag_proc) | |||||
pp = Pipeline(pipeline) | |||||
reader = ConlluPOSReader() | |||||
te_dataset = reader.load(filepath) | |||||
evaluator = SeqLabelEvaluator2('word_seq_origin_len') | |||||
end_tagidx_set = set() | |||||
tag_proc.vocab.build_vocab() | |||||
for key, value in tag_proc.vocab.word2idx.items(): | |||||
if key.startswith('E-'): | |||||
end_tagidx_set.add(value) | |||||
if key.startswith('S-'): | |||||
end_tagidx_set.add(value) | |||||
evaluator.end_tagidx_set = end_tagidx_set | |||||
default_valid_args = {"batch_size": 64, | |||||
"use_cuda": True, "evaluator": evaluator} | |||||
pp(te_dataset) | |||||
te_dataset.set_target(truth=True) | |||||
tester = Tester(**default_valid_args) | |||||
test_result = tester.test(model, te_dataset) | |||||
f1 = round(test_result['F'] * 100, 2) | |||||
pre = round(test_result['P'] * 100, 2) | |||||
rec = round(test_result['R'] * 100, 2) | |||||
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) | |||||
return f1, pre, rec | |||||
def test(self, file_path): | |||||
test_data = ZhConllPOSReader().load(file_path) | |||||
tag_vocab = self._dict["tag_vocab"] | |||||
pipeline = self._dict["pipeline"] | |||||
index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False) | |||||
pipeline.pipeline = [index_tag] + pipeline.pipeline | |||||
pipeline(test_data) | |||||
test_data.set_target("truth") | |||||
prediction = test_data.field_arrays["predict"].content | |||||
truth = test_data.field_arrays["truth"].content | |||||
seq_len = test_data.field_arrays["word_seq_origin_len"].content | |||||
# padding by hand | |||||
max_length = max([len(seq) for seq in prediction]) | |||||
for idx in range(len(prediction)): | |||||
prediction[idx] = list(prediction[idx]) + ([0] * (max_length - len(prediction[idx]))) | |||||
truth[idx] = list(truth[idx]) + ([0] * (max_length - len(truth[idx]))) | |||||
evaluator = SpanFPreRecMetric(tag_vocab=tag_vocab, pred="predict", target="truth", | |||||
seq_lens="word_seq_origin_len") | |||||
evaluator({"predict": torch.Tensor(prediction), "word_seq_origin_len": torch.Tensor(seq_len)}, | |||||
{"truth": torch.Tensor(truth)}) | |||||
test_result = evaluator.get_metric() | |||||
f1 = round(test_result['f'] * 100, 2) | |||||
pre = round(test_result['pre'] * 100, 2) | |||||
rec = round(test_result['rec'] * 100, 2) | |||||
return {"F1": f1, "precision": pre, "recall": rec} | |||||
class CWS(API): | class CWS(API): | ||||
def __init__(self, model_path=None, device='cpu'): | def __init__(self, model_path=None, device='cpu'): | ||||
""" | |||||
中文分词高级接口。 | |||||
:param model_path: 当model_path为None,使用默认位置的model。如果默认位置不存在,则自动下载模型 | |||||
:param device: str,可以为'cpu', 'cuda'或'cuda:0'等。会将模型load到相应device进行推断。 | |||||
""" | |||||
super(CWS, self).__init__() | super(CWS, self).__init__() | ||||
if model_path is None: | if model_path is None: | ||||
model_path = model_urls['cws'] | model_path = model_urls['cws'] | ||||
@@ -135,7 +151,13 @@ class CWS(API): | |||||
self.load(model_path, device) | self.load(model_path, device) | ||||
def predict(self, content): | def predict(self, content): | ||||
""" | |||||
分词接口。 | |||||
:param content: str或List[str], 例如: "中文分词很重要!", 返回的结果是"中文 分词 很 重要 !"。 如果传入的为List[str],比如 | |||||
[ "中文分词很重要!", ...], 返回的结果["中文 分词 很 重要 !", ...]。 | |||||
:return: str或List[str], 根据输入的的类型决定。 | |||||
""" | |||||
if not hasattr(self, 'pipeline'): | if not hasattr(self, 'pipeline'): | ||||
raise ValueError("You have to load model first.") | raise ValueError("You have to load model first.") | ||||
@@ -153,33 +175,55 @@ class CWS(API): | |||||
# 3. 使用pipeline | # 3. 使用pipeline | ||||
self.pipeline(dataset) | self.pipeline(dataset) | ||||
output = dataset['output'].content | |||||
output = dataset.get_field('output').content | |||||
if isinstance(content, str): | if isinstance(content, str): | ||||
return output[0] | return output[0] | ||||
elif isinstance(content, list): | elif isinstance(content, list): | ||||
return output | return output | ||||
def test(self, filepath):
"""
Given the path to a word-segmentation file, return the segmentation f1, precision and recall on that dataset.
The file should look like:
    1 编者按 编者按 NN O 11 nmod:topic
    2 : : PU O 11 punct
    3 7月 7月 NT DATE 4 compound:nn
    4 12日 12日 NT DATE 11 nmod:tmod
    5 , , PU O 11 punct

    1 这 这 DT O 3 det
    2 款 款 M O 1 mark:clf
    3 飞行 飞行 NN O 8 nsubj
    4 从 从 P O 5 case
    5 外型 外型 NN O 8 nmod:prep
Sentences are separated by a blank line; each non-empty line has 7 columns.

:param filepath: str, path to the file.
:return: float, float, float: f1, precision and recall respectively.
"""
tag_proc = self._dict['tag_proc']
cws_model = self.pipeline.pipeline[-2].model | cws_model = self.pipeline.pipeline[-2].model | ||||
pipeline = self.pipeline.pipeline[:5] | |||||
pipeline = self.pipeline.pipeline[:-2] | |||||
pipeline.insert(1, tag_proc) | pipeline.insert(1, tag_proc) | ||||
pp = Pipeline(pipeline) | pp = Pipeline(pipeline) | ||||
reader = ConlluCWSReader() | |||||
reader = ConllCWSReader() | |||||
# te_filename = '/home/hyan/ctb3/test.conllx' | # te_filename = '/home/hyan/ctb3/test.conllx' | ||||
te_dataset = reader.load(filepath) | te_dataset = reader.load(filepath) | ||||
pp(te_dataset) | pp(te_dataset) | ||||
batch_size = 64 | |||||
te_batcher = Batch(te_dataset, batch_size, SequentialSampler(), use_cuda=False) | |||||
pre, rec, f1 = calculate_pre_rec_f1(cws_model, te_batcher, type='bmes') | |||||
f1 = round(f1 * 100, 2) | |||||
pre = round(pre * 100, 2) | |||||
rec = round(rec * 100, 2) | |||||
from fastNLP.core.tester import Tester | |||||
from fastNLP.core.metrics import BMESF1PreRecMetric | |||||
tester = Tester(data=te_dataset, model=cws_model, metrics=BMESF1PreRecMetric(target='target'), batch_size=64, | |||||
verbose=0) | |||||
eval_res = tester.test() | |||||
f1 = eval_res['BMESF1PreRecMetric']['f'] | |||||
pre = eval_res['BMESF1PreRecMetric']['pre'] | |||||
rec = eval_res['BMESF1PreRecMetric']['rec'] | |||||
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) | # print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) | ||||
return f1, pre, rec | return f1, pre, rec | ||||
@@ -191,30 +235,30 @@ class Parser(API): | |||||
if model_path is None: | if model_path is None: | ||||
model_path = model_urls['parser'] | model_path = model_urls['parser'] | ||||
self.pos_tagger = POS(device=device) | |||||
self.load(model_path, device) | self.load(model_path, device) | ||||
def predict(self, content): | def predict(self, content): | ||||
if not hasattr(self, 'pipeline'): | if not hasattr(self, 'pipeline'): | ||||
raise ValueError("You have to load model first.") | raise ValueError("You have to load model first.") | ||||
sentence_list = [] | |||||
# 1. check the type of the input sentence(s)
if isinstance(content, str): | |||||
sentence_list.append(content) | |||||
elif isinstance(content, list): | |||||
sentence_list = content | |||||
# 1. use POS to obtain the word segmentation and POS-tagging results
pos_out = self.pos_tagger.predict(content)
# pos_out = ['这里/NN 是/VB 分词/NN 结果/NN'.split()]
# 2. build the DataSet
dataset = DataSet()
dataset.add_field('wp', pos_out)
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[0] for w in x['wp']], new_field_name='words')
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[1] for w in x['wp']], new_field_name='pos')
# 3. run the pipeline
self.pipeline(dataset)
for ins in dataset: | |||||
ins['heads'] = ins['heads'].tolist() | |||||
return dataset['heads'], dataset['labels'] | |||||
dataset.apply(lambda x: [str(arc) for arc in x['arc_pred']], new_field_name='arc_pred') | |||||
dataset.apply(lambda x: [arc + '/' + label for arc, label in | |||||
zip(x['arc_pred'], x['label_pred_seq'])][1:], new_field_name='output') | |||||
# output like: [['2/top', '0/root', '4/nn', '2/dep']] | |||||
return dataset.field_arrays['output'].content | |||||
def test(self, filepath): | def test(self, filepath): | ||||
data = ConllxDataLoader().load(filepath) | data = ConllxDataLoader().load(filepath) | ||||
@@ -276,7 +320,7 @@ class Analyzer: | |||||
def test(self, filepath): | def test(self, filepath): | ||||
output_dict = {} | output_dict = {} | ||||
if self.seg: | |||||
if self.cws: | |||||
seg_output = self.cws.test(filepath) | seg_output = self.cws.test(filepath) | ||||
output_dict['seg'] = seg_output | output_dict['seg'] = seg_output | ||||
if self.pos: | if self.pos: | ||||
@@ -287,28 +331,3 @@ class Analyzer: | |||||
output_dict['parser'] = parser_output | output_dict['parser'] = parser_output | ||||
return output_dict | return output_dict | ||||
if __name__ == "__main__": | |||||
# pos_model_path = '../../reproduction/pos_tag_model/pos_crf.pkl' | |||||
# pos = POS(device='cpu') | |||||
# s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。' , | |||||
# '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||||
# '那么这款无人机到底有多厉害?'] | |||||
# print(pos.test('/Users/yh/Desktop/test_data/pos_test.conll')) | |||||
# print(pos.predict(s)) | |||||
# cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl' | |||||
# cws = CWS(device='cpu') | |||||
# s = ['本品是一个抗酸抗胆汁的胃黏膜保护剂' , | |||||
# '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||||
# '那么这款无人机到底有多厉害?'] | |||||
# print(cws.test('/Users/yh/Desktop/test_data/cws_test.conll')) | |||||
# print(cws.predict(s)) | |||||
parser = Parser(device='cpu') | |||||
# print(parser.test('/Users/yh/Desktop/test_data/parser_test2.conll')) | |||||
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||||
'那么这款无人机到底有多厉害?'] | |||||
print(parser.predict(s)) |
@@ -0,0 +1,29 @@ | |||||
""" | |||||
api/example.py contains all API examples provided by fastNLP. | |||||
It is used as a tutorial for API or a test script since it is difficult to test APIs in travis. | |||||
""" | |||||
from fastNLP.api import CWS, POS, Parser | |||||
text = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||||
'那么这款无人机到底有多厉害?'] | |||||
def chinese_word_segmentation(): | |||||
cws = CWS(device='cpu') | |||||
print(cws.predict(text)) | |||||
def pos_tagging(): | |||||
pos = POS(device='cpu') | |||||
print(pos.predict(text)) | |||||
def syntactic_parsing(): | |||||
parser = Parser(device='cpu') | |||||
print(parser.predict(text)) | |||||
if __name__ == "__main__": | |||||
syntactic_parsing() |
@@ -11,6 +11,11 @@ from fastNLP.core.vocabulary import Vocabulary | |||||
class Processor(object): | class Processor(object): | ||||
def __init__(self, field_name, new_added_field_name): | def __init__(self, field_name, new_added_field_name): | ||||
""" | |||||
:param field_name: 处理哪个field | |||||
:param new_added_field_name: 如果为None,则认为是field_name,即覆盖原有的field | |||||
""" | |||||
self.field_name = field_name | self.field_name = field_name | ||||
if new_added_field_name is None: | if new_added_field_name is None: | ||||
self.new_added_field_name = field_name | self.new_added_field_name = field_name | ||||
@@ -92,6 +97,11 @@ class FullSpaceToHalfSpaceProcessor(Processor): | |||||
class PreAppendProcessor(Processor): | class PreAppendProcessor(Processor): | ||||
""" | |||||
向某个field的起始增加data(应该为str类型)。该field需要为list类型。即新增的field为 | |||||
[data] + instance[field_name] | |||||
""" | |||||
def __init__(self, data, field_name, new_added_field_name=None): | def __init__(self, data, field_name, new_added_field_name=None): | ||||
super(PreAppendProcessor, self).__init__(field_name, new_added_field_name) | super(PreAppendProcessor, self).__init__(field_name, new_added_field_name) | ||||
self.data = data | self.data = data | ||||
@@ -102,6 +112,10 @@ class PreAppendProcessor(Processor): | |||||
class SliceProcessor(Processor): | class SliceProcessor(Processor): | ||||
""" | |||||
从某个field中只取部分内容。等价于instance[field_name][start:end:step] | |||||
""" | |||||
def __init__(self, start, end, step, field_name, new_added_field_name=None): | def __init__(self, start, end, step, field_name, new_added_field_name=None): | ||||
super(SliceProcessor, self).__init__(field_name, new_added_field_name) | super(SliceProcessor, self).__init__(field_name, new_added_field_name) | ||||
for o in (start, end, step): | for o in (start, end, step): | ||||
@@ -114,7 +128,17 @@ class SliceProcessor(Processor): | |||||
class Num2TagProcessor(Processor): | class Num2TagProcessor(Processor): | ||||
""" | |||||
将一句话中的数字转换为某个tag。 | |||||
""" | |||||
def __init__(self, tag, field_name, new_added_field_name=None): | def __init__(self, tag, field_name, new_added_field_name=None): | ||||
""" | |||||
:param tag: str, 将数字转换为该tag | |||||
:param field_name: | |||||
:param new_added_field_name: | |||||
""" | |||||
super(Num2TagProcessor, self).__init__(field_name, new_added_field_name) | super(Num2TagProcessor, self).__init__(field_name, new_added_field_name) | ||||
self.tag = tag | self.tag = tag | ||||
self.pattern = r'[-+]?([0-9]+[.]?[0-9]*)+[/eE]?[-+]?([0-9]+[.]?[0-9]*)' | self.pattern = r'[-+]?([0-9]+[.]?[0-9]*)+[/eE]?[-+]?([0-9]+[.]?[0-9]*)' | ||||
@@ -135,6 +159,10 @@ class Num2TagProcessor(Processor): | |||||
class IndexerProcessor(Processor): | class IndexerProcessor(Processor): | ||||
""" | |||||
给定一个vocabulary , 将指定field转换为index形式。指定field应该是一维的list,比如 | |||||
['我', '是', xxx] | |||||
""" | |||||
def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False, is_input=True): | def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False, is_input=True): | ||||
assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) | assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) | ||||
@@ -163,19 +191,19 @@ class IndexerProcessor(Processor): | |||||
class VocabProcessor(Processor): | class VocabProcessor(Processor): | ||||
"""Build vocabulary with a field in the data set. | |||||
""" | |||||
传入若干个DataSet以建立vocabulary。 | |||||
""" | """ | ||||
def __init__(self, field_name): | |||||
def __init__(self, field_name, min_freq=1, max_size=None): | |||||
super(VocabProcessor, self).__init__(field_name, None) | super(VocabProcessor, self).__init__(field_name, None) | ||||
self.vocab = Vocabulary() | |||||
self.vocab = Vocabulary(min_freq=min_freq, max_size=max_size) | |||||
def process(self, *datasets): | def process(self, *datasets): | ||||
for dataset in datasets: | for dataset in datasets: | ||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | ||||
for ins in dataset: | |||||
self.vocab.update(ins[self.field_name]) | |||||
dataset.apply(lambda ins: self.vocab.update(ins[self.field_name])) | |||||
def get_vocab(self): | def get_vocab(self): | ||||
self.vocab.build_vocab() | self.vocab.build_vocab() | ||||
@@ -183,6 +211,10 @@ class VocabProcessor(Processor): | |||||
class SeqLenProcessor(Processor): | class SeqLenProcessor(Processor): | ||||
""" | |||||
根据某个field新增一个sequence length的field。取该field的第一维 | |||||
""" | |||||
def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True): | def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True): | ||||
super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) | super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) | ||||
self.is_input = is_input | self.is_input = is_input | ||||
@@ -195,10 +227,15 @@ class SeqLenProcessor(Processor): | |||||
return dataset | return dataset | ||||
from fastNLP.core.utils import _build_args | |||||
class ModelProcessor(Processor): | class ModelProcessor(Processor): | ||||
def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32): | def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32): | ||||
""" | """ | ||||
迭代模型并将结果的padding drop掉 | |||||
传入一个model,在process()时传入一个dataset,该processor会通过Batch将DataSet的内容输出给model.predict或者model.forward. | |||||
model输出的内容会被增加到dataset中,field_name由model输出决定。如果生成的内容维度不是(Batch_size, )与 | |||||
(Batch_size, 1),则使用seqence length这个field进行unpad | |||||
TODO 这个类需要删除对seq_lens的依赖。 | |||||
:param seq_len_field_name: | :param seq_len_field_name: | ||||
:param batch_size: | :param batch_size: | ||||
@@ -211,13 +248,18 @@ class ModelProcessor(Processor): | |||||
def process(self, dataset): | def process(self, dataset): | ||||
self.model.eval() | self.model.eval() | ||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | ||||
data_iterator = Batch(dataset, batch_size=self.batch_size, sampler=SequentialSampler(), use_cuda=False) | |||||
data_iterator = Batch(dataset, batch_size=self.batch_size, sampler=SequentialSampler()) | |||||
batch_output = defaultdict(list) | batch_output = defaultdict(list) | ||||
if hasattr(self.model, "predict"): | |||||
predict_func = self.model.predict | |||||
else: | |||||
predict_func = self.model.forward | |||||
with torch.no_grad(): | with torch.no_grad(): | ||||
for batch_x, _ in data_iterator: | for batch_x, _ in data_iterator: | ||||
prediction = self.model.predict(**batch_x) | |||||
seq_lens = batch_x[self.seq_len_field_name].cpu().numpy().tolist() | |||||
refined_batch_x = _build_args(predict_func, **batch_x) | |||||
prediction = predict_func(**refined_batch_x) | |||||
seq_lens = batch_x[self.seq_len_field_name].tolist() | |||||
for key, value in prediction.items(): | for key, value in prediction.items(): | ||||
tmp_batch = [] | tmp_batch = [] | ||||
@@ -228,8 +270,8 @@ class ModelProcessor(Processor): | |||||
for idx, seq_len in enumerate(seq_lens): | for idx, seq_len in enumerate(seq_lens): | ||||
tmp_batch.append(value[idx, :seq_len]) | tmp_batch.append(value[idx, :seq_len]) | ||||
batch_output[key].extend(tmp_batch) | batch_output[key].extend(tmp_batch) | ||||
batch_output[self.seq_len_field_name].extend(seq_lens) | |||||
if not self.seq_len_field_name in prediction: | |||||
batch_output[self.seq_len_field_name].extend(seq_lens) | |||||
# TODO: with the current implementation, downstream processors need to know the keys of the model's output
for field_name, fields in batch_output.items(): | for field_name, fields in batch_output.items(): | ||||
@@ -246,6 +288,10 @@ class ModelProcessor(Processor): | |||||
class Index2WordProcessor(Processor): | class Index2WordProcessor(Processor): | ||||
""" | |||||
将DataSet中某个为index的field根据vocab转换为str | |||||
""" | |||||
def __init__(self, vocab, field_name, new_added_field_name): | def __init__(self, vocab, field_name, new_added_field_name): | ||||
super(Index2WordProcessor, self).__init__(field_name, new_added_field_name) | super(Index2WordProcessor, self).__init__(field_name, new_added_field_name) | ||||
self.vocab = vocab | self.vocab = vocab | ||||
@@ -256,15 +302,23 @@ class Index2WordProcessor(Processor): | |||||
return dataset | return dataset | ||||
class SetIsTargetProcessor(Processor): | |||||
class SetTargetProcessor(Processor): | |||||
# TODO; remove it. | # TODO; remove it. | ||||
def __init__(self, field_dict, default=False): | |||||
super(SetIsTargetProcessor, self).__init__(None, None) | |||||
self.field_dict = field_dict | |||||
self.default = default | |||||
def __init__(self, *fields, flag=True): | |||||
super(SetTargetProcessor, self).__init__(None, None) | |||||
self.fields = fields | |||||
self.flag = flag | |||||
def process(self, dataset): | |||||
dataset.set_target(*self.fields, flag=self.flag) | |||||
return dataset | |||||
class SetInputProcessor(Processor): | |||||
def __init__(self, *fields, flag=True): | |||||
super(SetInputProcessor, self).__init__(None, None) | |||||
self.fields = fields | |||||
self.flag = flag | |||||
def process(self, dataset): | def process(self, dataset): | ||||
set_dict = {name: self.default for name in dataset.get_all_fields().keys()} | |||||
set_dict.update(self.field_dict) | |||||
dataset.set_target(**set_dict) | |||||
dataset.set_input(*self.fields, flag=self.flag) | |||||
return dataset | return dataset |
@@ -1,6 +1,8 @@ | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
from fastNLP.core.sampler import RandomSampler | |||||
class Batch(object): | class Batch(object): | ||||
"""Batch is an iterable object which iterates over mini-batches. | """Batch is an iterable object which iterates over mini-batches. | ||||
@@ -17,14 +19,15 @@ class Batch(object): | |||||
""" | """ | ||||
def __init__(self, dataset, batch_size, sampler, as_numpy=False): | |||||
def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False): | |||||
self.dataset = dataset | self.dataset = dataset | ||||
self.batch_size = batch_size | self.batch_size = batch_size | ||||
self.sampler = sampler | self.sampler = sampler | ||||
self.as_numpy = as_numpy | self.as_numpy = as_numpy | ||||
self.idx_list = None | self.idx_list = None | ||||
self.curidx = 0 | self.curidx = 0 | ||||
self.num_batches = len(dataset)//batch_size + int(len(dataset)%batch_size!=0) | |||||
self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0) | |||||
self.cur_batch_indices = None | |||||
def __iter__(self): | def __iter__(self): | ||||
self.idx_list = self.sampler(self.dataset) | self.idx_list = self.sampler(self.dataset) | ||||
@@ -40,6 +43,7 @@ class Batch(object): | |||||
batch_x, batch_y = {}, {} | batch_x, batch_y = {}, {} | ||||
indices = self.idx_list[self.curidx:endidx] | indices = self.idx_list[self.curidx:endidx] | ||||
self.cur_batch_indices = indices | |||||
for field_name, field in self.dataset.get_all_fields().items(): | for field_name, field in self.dataset.get_all_fields().items(): | ||||
if field.is_target or field.is_input: | if field.is_target or field.is_input: | ||||
@@ -58,6 +62,9 @@ class Batch(object): | |||||
def __len__(self): | def __len__(self): | ||||
return self.num_batches | return self.num_batches | ||||
def get_batch_indices(self): | |||||
return self.cur_batch_indices | |||||
def to_tensor(batch, dtype): | def to_tensor(batch, dtype): | ||||
if dtype in (int, np.int8, np.int16, np.int32, np.int64): | if dtype in (int, np.int8, np.int16, np.int32, np.int64): | ||||
@@ -0,0 +1,242 @@ | |||||
class Callback(object): | |||||
"""An Interface for all callbacks. | |||||
Any customized callback should implement at least one of the following methods. | |||||
""" | |||||
def __init__(self): | |||||
super(Callback, self).__init__() | |||||
def before_train(self): | |||||
# before the main training loop | |||||
pass | |||||
def before_epoch(self, cur_epoch, total_epoch): | |||||
# at the beginning of each epoch | |||||
pass | |||||
def before_batch(self, batch_x, batch_y, indices): | |||||
# at the beginning of each step/mini-batch | |||||
pass | |||||
def before_loss(self, batch_y, predict_y): | |||||
# after data_forward, and before loss computation | |||||
pass | |||||
def before_backward(self, loss, model): | |||||
# after loss computation, and before gradient backward | |||||
pass | |||||
def after_backward(self, model): | |||||
pass | |||||
def after_step(self, optimizer): | |||||
pass | |||||
def after_batch(self, *args): | |||||
# at the end of each step/mini-batch | |||||
pass | |||||
def after_valid(self, eval_result, metric_key, optimizer): | |||||
""" | |||||
每次执行验证机的evaluation后会调用。传入eval_result | |||||
:param eval_result: Dict[str: Dict[str: float]], evaluation的结果 | |||||
:param metric_key: str | |||||
:param optimizer: | |||||
:return: | |||||
""" | |||||
pass | |||||
def after_epoch(self, cur_epoch, n_epoch, optimizer): | |||||
""" | |||||
每个epoch结束将会调用该方法 | |||||
:param cur_epoch: int, 当前的batch。从1开始。 | |||||
:param n_epoch: int, 总的batch数 | |||||
:param optimizer: 传入Trainer的optimizer。 | |||||
:return: | |||||
""" | |||||
pass | |||||
def after_train(self, model): | |||||
""" | |||||
训练结束,调用该方法 | |||||
:param model: nn.Module, 传入Trainer的模型 | |||||
:return: | |||||
""" | |||||
pass | |||||
def on_exception(self, exception, model, indices): | |||||
""" | |||||
当训练过程出现异常,会触发该方法 | |||||
:param exception: 某种类型的Exception,比如KeyboardInterrupt等 | |||||
:param model: 传入Trainer的模型 | |||||
:param indices: 当前batch的index | |||||
:return: | |||||
""" | |||||
pass | |||||
def transfer(func): | |||||
"""装饰器,将对CallbackManager的调用转发到各个Callback子类. | |||||
:param func: | |||||
:return: | |||||
""" | |||||
def wrapper(manager, *arg): | |||||
returns = [] | |||||
for callback in manager.callbacks: | |||||
for env_name, env_value in manager.env.items(): | |||||
setattr(callback, env_name, env_value) | |||||
returns.append(getattr(callback, func.__name__)(*arg)) | |||||
return returns | |||||
return wrapper | |||||
class CallbackManager(Callback): | |||||
"""A manager for all callbacks passed into Trainer. | |||||
It collects resources inside Trainer and raise callbacks. | |||||
""" | |||||
def __init__(self, env, callbacks=None): | |||||
""" | |||||
:param dict env: The key is the name of the Trainer attribute(str). The value is the attribute itself. | |||||
:param Callback callbacks: | |||||
""" | |||||
super(CallbackManager, self).__init__() | |||||
# set attribute of trainer environment | |||||
self.env = env | |||||
self.callbacks = [] | |||||
if callbacks is not None: | |||||
if isinstance(callbacks, list): | |||||
if all([isinstance(cb, Callback) for cb in callbacks]) is True: | |||||
self.callbacks.extend(callbacks) | |||||
else: | |||||
obj = [cb for cb in callbacks if not isinstance(cb, Callback)][0]
raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}") | |||||
else: | |||||
raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.") | |||||
@transfer | |||||
def before_train(self): | |||||
pass | |||||
@transfer | |||||
def before_epoch(self, cur_epoch, total_epoch): | |||||
pass | |||||
@transfer | |||||
def before_batch(self, batch_x, batch_y, indices): | |||||
pass | |||||
@transfer | |||||
def before_loss(self, batch_y, predict_y): | |||||
pass | |||||
@transfer | |||||
def before_backward(self, loss, model): | |||||
pass | |||||
@transfer | |||||
def after_backward(self, model): | |||||
pass | |||||
@transfer | |||||
def after_step(self, optimizer): | |||||
pass | |||||
@transfer | |||||
def after_batch(self): | |||||
pass | |||||
@transfer | |||||
def after_valid(self, eval_result, metric_key, optimizer): | |||||
pass | |||||
@transfer | |||||
def after_epoch(self, cur_epoch, n_epoch, optimizer): | |||||
pass | |||||
@transfer | |||||
def after_train(self, model): | |||||
pass | |||||
@transfer | |||||
def on_exception(self, exception, model, indices): | |||||
pass | |||||
class DummyCallback(Callback): | |||||
def before_train(self, *arg): | |||||
print(arg) | |||||
def after_epoch(self, cur_epoch, n_epoch, optimizer): | |||||
print(cur_epoch, n_epoch, optimizer) | |||||
class EchoCallback(Callback): | |||||
def before_train(self): | |||||
print("before_train") | |||||
def before_epoch(self, cur_epoch, total_epoch): | |||||
print("before_epoch") | |||||
def before_batch(self, batch_x, batch_y, indices): | |||||
print("before_batch") | |||||
def before_loss(self, batch_y, predict_y): | |||||
print("before_loss") | |||||
def before_backward(self, loss, model): | |||||
print("before_backward") | |||||
def after_batch(self): | |||||
print("after_batch") | |||||
def after_epoch(self, cur_epoch, n_epoch, optimizer): | |||||
print("after_epoch") | |||||
def after_train(self, model): | |||||
print("after_train") | |||||
class GradientClipCallback(Callback): | |||||
def __init__(self, parameters=None, clip_value=1, clip_type='norm'): | |||||
""" | |||||
每次backward前,将parameter的gradient clip到某个范围。 | |||||
:param parameters: None, torch.Tensor或List[torch.Tensor], 一般通过model.parameters()获得。如果为None则默认对Trainer | |||||
的model中所有参数进行clip | |||||
:param clip_value: float, 将gradient 限制到[-clip_value, clip_value]。clip_value应该为正数 | |||||
:param clip_type: str, 支持'norm', 'value'两种。 | |||||
(1) 'norm', 将gradient的norm rescale到[-clip_value, clip_value] | |||||
(2) 'value', 将gradient限制在[-clip_value, clip_value], 小于-clip_value的gradient被赋值为-clip_value; 大于 | |||||
clip_value的gradient被赋值为clip_value. | |||||
""" | |||||
super().__init__() | |||||
from torch import nn | |||||
if clip_type == 'norm': | |||||
self.clip_fun = nn.utils.clip_grad_norm_ | |||||
elif clip_type == 'value': | |||||
self.clip_fun = nn.utils.clip_grad_value_ | |||||
else: | |||||
raise ValueError("Only supports `norm` or `value` right now.") | |||||
self.parameters = parameters | |||||
self.clip_value = clip_value | |||||
def after_backward(self, model): | |||||
parameters = self.parameters if self.parameters is not None else model.parameters()
self.clip_fun(parameters, self.clip_value)
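# Illustrative usage (a sketch, not part of this file): the callback is passed to the Trainer
# via its `callbacks` argument, e.g.
#     trainer = Trainer(train_data, model, loss=loss, metrics=metrics,
#                       callbacks=[GradientClipCallback(clip_value=5, clip_type='value')])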
if __name__ == "__main__": | |||||
manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()]) | |||||
manager.before_train(10, 11, 12) | |||||
# print(manager.after_epoch()) |
@@ -254,6 +254,8 @@ class DataSet(object): | |||||
:return results: if new_field_name is not passed, returned values of the function over all instances. | :return results: if new_field_name is not passed, returned values of the function over all instances. | ||||
""" | """ | ||||
results = [func(ins) for ins in self._inner_iter()] | results = [func(ins) for ins in self._inner_iter()] | ||||
if len(list(filter(lambda x: x is not None, results))) == 0 and not (new_field_name is None): # all None | |||||
raise ValueError("{} always return None.".format(get_func_signature(func=func))) | |||||
extra_param = {} | extra_param = {} | ||||
if 'is_input' in kwargs: | if 'is_input' in kwargs: | ||||
@@ -261,8 +263,6 @@ class DataSet(object): | |||||
if 'is_target' in kwargs: | if 'is_target' in kwargs: | ||||
extra_param['is_target'] = kwargs['is_target'] | extra_param['is_target'] = kwargs['is_target'] | ||||
if new_field_name is not None: | if new_field_name is not None: | ||||
if len(list(filter(lambda x: x is not None, results))) == 0: # all None | |||||
raise ValueError("{} always return None.".format(get_func_signature(func=func))) | |||||
if new_field_name in self.field_arrays: | if new_field_name in self.field_arrays: | ||||
# overwrite the field, keep same attributes | # overwrite the field, keep same attributes | ||||
old_field = self.field_arrays[new_field_name] | old_field = self.field_arrays[new_field_name] | ||||
@@ -30,5 +30,7 @@ class Instance(object): | |||||
return self.add_field(name, field) | return self.add_field(name, field) | ||||
def __repr__(self): | def __repr__(self): | ||||
s = '\'' | |||||
return "{" + ",\n".join( | return "{" + ",\n".join( | ||||
"\'" + field_name + "\': " + str(self.fields[field_name]) for field_name in self.fields) + "}" | |||||
"\'" + field_name + "\': " + str(self.fields[field_name]) +\ | |||||
f" type={(str(type(self.fields[field_name]))).split(s)[1]}" for field_name in self.fields) + "}" |
@@ -195,7 +195,7 @@ class CrossEntropyLoss(LossBase): | |||||
def get_loss(self, pred, target): | def get_loss(self, pred, target): | ||||
return F.cross_entropy(input=pred, target=target, | return F.cross_entropy(input=pred, target=target, | ||||
ignore_index=self.padding_idx) | |||||
ignore_index=self.padding_idx) | |||||
class L1Loss(LossBase): | class L1Loss(LossBase): | ||||
@@ -250,7 +250,7 @@ class LossInForward(LossBase): | |||||
if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
if not isinstance(loss, torch.Tensor):
raise TypeError(f"Loss is expected to be a torch.Tensor, got {type(loss)}")
raise RuntimeError(f"The size of loss is expected to be torch.Size([]), got {loss.size()}")
return loss | return loss | ||||
@@ -10,6 +10,7 @@ from fastNLP.core.utils import _build_args | |||||
from fastNLP.core.utils import _check_arg_dict_list | from fastNLP.core.utils import _check_arg_dict_list | ||||
from fastNLP.core.utils import get_func_signature | from fastNLP.core.utils import get_func_signature | ||||
from fastNLP.core.utils import seq_lens_to_masks | from fastNLP.core.utils import seq_lens_to_masks | ||||
from fastNLP.core.vocabulary import Vocabulary | |||||
class MetricBase(object): | class MetricBase(object): | ||||
@@ -80,11 +81,6 @@ class MetricBase(object): | |||||
f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the " | f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the " | ||||
f"initialization parameters, or change its signature.") | f"initialization parameters, or change its signature.") | ||||
# evaluate should not have varargs. | |||||
# if func_spect.varargs: | |||||
# raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.evaluate)}(Do not use " | |||||
# f"positional argument.).") | |||||
def get_metric(self, reset=True): | def get_metric(self, reset=True): | ||||
raise NotImplemented | raise NotImplemented | ||||
@@ -108,10 +104,9 @@ class MetricBase(object): | |||||
This method will call self.evaluate method. | This method will call self.evaluate method. | ||||
Before calling self.evaluate, it will first check the validity of output_dict, target_dict | Before calling self.evaluate, it will first check the validity of output_dict, target_dict | ||||
(1) whether self.evaluate has varargs, which is not supported. | |||||
(2) whether params needed by self.evaluate is not included in output_dict,target_dict. | |||||
(3) whether params needed by self.evaluate duplicate in pred_dict, target_dict | |||||
(4) whether params in output_dict, target_dict are not used by evaluate.(Might cause warning) | |||||
(1) whether params needed by self.evaluate is not included in output_dict,target_dict. | |||||
(2) whether params needed by self.evaluate duplicate in pred_dict, target_dict | |||||
(3) whether params in output_dict, target_dict are not used by evaluate.(Might cause warning) | |||||
Besides, before passing params into self.evaluate, this function will filter out params from output_dict and | Besides, before passing params into self.evaluate, this function will filter out params from output_dict and | ||||
target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering | target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering | ||||
will be conducted.) | will be conducted.) | ||||
@@ -299,6 +294,368 @@ class AccuracyMetric(MetricBase): | |||||
self.total = 0 | self.total = 0 | ||||
return evaluate_result | return evaluate_result | ||||
def bmes_tag_to_spans(tags, ignore_labels=None): | |||||
""" | |||||
:param tags: List[str], | |||||
:param ignore_labels: List[str], 在该list中的label将被忽略 | |||||
:return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] | |||||
""" | |||||
ignore_labels = set(ignore_labels) if ignore_labels else set() | |||||
spans = [] | |||||
prev_bmes_tag = None | |||||
for idx, tag in enumerate(tags): | |||||
tag = tag.lower() | |||||
bmes_tag, label = tag[:1], tag[2:] | |||||
if bmes_tag in ('b', 's'): | |||||
spans.append((label, [idx, idx])) | |||||
elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label==spans[-1][0]: | |||||
spans[-1][1][1] = idx | |||||
else: | |||||
spans.append((label, [idx, idx])) | |||||
prev_bmes_tag = bmes_tag | |||||
return [(span[0], (span[1][0], span[1][1])) | |||||
for span in spans | |||||
if span[0] not in ignore_labels | |||||
] | |||||
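# Worked example (illustrative):
#     bmes_tag_to_spans(['B-NN', 'E-NN', 'S-VV']) == [('nn', (0, 1)), ('vv', (2, 2))]
# The B/E pair merges into a single 'nn' span over positions 0-1; the S tag forms its own span.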
def bio_tag_to_spans(tags, ignore_labels=None): | |||||
""" | |||||
:param tags: List[str], | |||||
:param ignore_labels: List[str], 在该list中的label将被忽略 | |||||
:return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] | |||||
""" | |||||
ignore_labels = set(ignore_labels) if ignore_labels else set() | |||||
spans = [] | |||||
prev_bio_tag = None | |||||
for idx, tag in enumerate(tags): | |||||
tag = tag.lower() | |||||
bio_tag, label = tag[:1], tag[2:] | |||||
if bio_tag == 'b': | |||||
spans.append((label, [idx, idx])) | |||||
elif bio_tag == 'i' and prev_bio_tag in ('b', 'i') and label==spans[-1][0]: | |||||
spans[-1][1][1] = idx | |||||
elif bio_tag == 'o': # o tag does not count | |||||
pass | |||||
else: | |||||
spans.append((label, [idx, idx])) | |||||
prev_bio_tag = bio_tag | |||||
return [(span[0], (span[1][0], span[1][1])) | |||||
for span in spans | |||||
if span[0] not in ignore_labels | |||||
] | |||||
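# Worked example (illustrative):
#     bio_tag_to_spans(['B-LOC', 'I-LOC', 'O', 'B-PER']) == [('loc', (0, 1)), ('per', (3, 3))]
# Consecutive B/I tags with the same label form one span; 'O' tags are skipped.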
class SpanFPreRecMetric(MetricBase): | |||||
""" | |||||
在序列标注问题中,以span的方式计算F, pre, rec. | |||||
最后得到的metric结果为 | |||||
{ | |||||
'f': xxx, # 这里使用f考虑以后可以计算f_beta值 | |||||
'pre': xxx, | |||||
'rec':xxx | |||||
} | |||||
若only_gross=False, 即还会返回各个label的metric统计值 | |||||
{ | |||||
'f': xxx, | |||||
'pre': xxx, | |||||
'rec':xxx, | |||||
'f-label': xxx, | |||||
'pre-label': xxx, | |||||
'rec-label':xxx, | |||||
... | |||||
} | |||||
""" | |||||
def __init__(self, tag_vocab, pred=None, target=None, seq_lens=None, encoding_type='bio', ignore_labels=None, | |||||
only_gross=True, f_type='micro', beta=1): | |||||
""" | |||||
:param tag_vocab: Vocabulary, 标签的vocabulary。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN), | |||||
在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'. | |||||
:param pred: str, 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据 | |||||
:param target: str, 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据 | |||||
:param seq_lens: str, 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用'seq_lens'取数据。 | |||||
:param encoding_type: str, 目前支持bio, bmes | |||||
:param ignore_labels, List[str]. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'这 | |||||
个label | |||||
:param only_gross, bool. 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个 | |||||
label的f1, pre, rec | |||||
:param f_type, str. 'micro'或'macro'. 'micro':通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; 'macro': | |||||
分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同) | |||||
:param beta, float. f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5 | |||||
则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||||
""" | |||||
encoding_type = encoding_type.lower() | |||||
if encoding_type not in ('bio', 'bmes'): | |||||
raise ValueError("Only support 'bio' or 'bmes' type.") | |||||
if not isinstance(tag_vocab, Vocabulary): | |||||
raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) | |||||
if f_type not in ('micro', 'macro'): | |||||
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | |||||
self.encoding_type = encoding_type | |||||
if self.encoding_type == 'bmes': | |||||
self.tag_to_span_func = bmes_tag_to_spans | |||||
elif self.encoding_type == 'bio': | |||||
self.tag_to_span_func = bio_tag_to_spans | |||||
self.ignore_labels = ignore_labels | |||||
self.f_type = f_type | |||||
self.beta = beta | |||||
self.beta_square = self.beta**2 | |||||
self.only_gross = only_gross | |||||
super().__init__() | |||||
self._init_param_map(pred=pred, target=target, seq_lens=seq_lens) | |||||
self.tag_vocab = tag_vocab | |||||
self._true_positives = defaultdict(int) | |||||
self._false_positives = defaultdict(int) | |||||
self._false_negatives = defaultdict(int) | |||||
def evaluate(self, pred, target, seq_lens): | |||||
""" | |||||
Much of the design comes from AllenNLP's span-based measures.
:param pred: | |||||
:param target: | |||||
:param seq_lens: | |||||
:return: | |||||
""" | |||||
if not isinstance(pred, torch.Tensor): | |||||
raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||||
f"got {type(pred)}.") | |||||
if not isinstance(target, torch.Tensor): | |||||
raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||||
f"got {type(target)}.") | |||||
if not isinstance(seq_lens, torch.Tensor): | |||||
raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||||
f"got {type(seq_lens)}.") | |||||
if pred.size() == target.size() and len(target.size()) == 2: | |||||
pass | |||||
elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2: | |||||
pred = pred.argmax(dim=-1) | |||||
num_classes = pred.size(-1) | |||||
if (target >= num_classes).any(): | |||||
raise ValueError("A gold label passed to SpanBasedF1Metric contains an " | |||||
"id >= {}, the number of classes.".format(num_classes)) | |||||
else: | |||||
raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have " | |||||
f"size:{pred.size()}, target should have size: {pred.size()} or " | |||||
f"{pred.size()[:-1]}, got {target.size()}.") | |||||
batch_size = pred.size(0) | |||||
for i in range(batch_size): | |||||
pred_tags = pred[i, :int(seq_lens[i])].tolist() | |||||
gold_tags = target[i, :int(seq_lens[i])].tolist() | |||||
pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags] | |||||
gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags] | |||||
pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels) | |||||
gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels) | |||||
for span in pred_spans: | |||||
if span in gold_spans: | |||||
self._true_positives[span[0]] += 1 | |||||
gold_spans.remove(span) | |||||
else: | |||||
self._false_positives[span[0]] += 1 | |||||
for span in gold_spans: | |||||
self._false_negatives[span[0]] += 1 | |||||
def get_metric(self, reset=True): | |||||
evaluate_result = {} | |||||
if not self.only_gross or self.f_type=='macro': | |||||
tags = set(self._false_negatives.keys()) | |||||
tags.update(set(self._false_positives.keys())) | |||||
tags.update(set(self._true_positives.keys())) | |||||
f_sum = 0 | |||||
pre_sum = 0 | |||||
rec_sum = 0 | |||||
for tag in tags: | |||||
tp = self._true_positives[tag] | |||||
fn = self._false_negatives[tag] | |||||
fp = self._false_positives[tag] | |||||
f, pre, rec = self._compute_f_pre_rec(tp, fn, fp) | |||||
f_sum += f | |||||
pre_sum += pre | |||||
rec_sum += rec
if not self.only_gross and tag != '':  # tag != '' guards against tags without a label
f_key = 'f-{}'.format(tag) | |||||
pre_key = 'pre-{}'.format(tag) | |||||
rec_key = 'rec-{}'.format(tag) | |||||
evaluate_result[f_key] = f | |||||
evaluate_result[pre_key] = pre | |||||
evaluate_result[rec_key] = rec | |||||
if self.f_type == 'macro': | |||||
evaluate_result['f'] = f_sum/len(tags) | |||||
evaluate_result['pre'] = pre_sum/len(tags) | |||||
evaluate_result['rec'] = rec_sum/len(tags) | |||||
if self.f_type == 'micro': | |||||
f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()), | |||||
sum(self._false_negatives.values()), | |||||
sum(self._false_positives.values())) | |||||
evaluate_result['f'] = round(f, 6) | |||||
evaluate_result['pre'] = round(pre, 6) | |||||
evaluate_result['rec'] = round(rec, 6) | |||||
if reset: | |||||
self._true_positives = defaultdict(int) | |||||
self._false_positives = defaultdict(int) | |||||
self._false_negatives = defaultdict(int) | |||||
return evaluate_result | |||||
def _compute_f_pre_rec(self, tp, fn, fp): | |||||
""" | |||||
:param tp: int, true positive | |||||
:param fn: int, false negative | |||||
:param fp: int, false positive | |||||
:return: (f, pre, rec) | |||||
""" | |||||
pre = tp / (fp + tp + 1e-13) | |||||
rec = tp / (fn + tp + 1e-13) | |||||
f = (1 + self.beta_square) * pre * rec / (self.beta_square * pre + rec + 1e-13) | |||||
return f, pre, rec | |||||
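# Worked example (illustrative): with beta=1, tp=2, fn=1 and fp=1 give pre = rec = 2/3,
# and f = 2 * pre * rec / (pre + rec) = 2/3 (up to the 1e-13 smoothing terms).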
class BMESF1PreRecMetric(MetricBase): | |||||
""" | |||||
按照BMES标注方式计算f1, precision, recall。由于可能存在非法tag,比如"BS",所以需要用以下的表格做转换,cur_B意思是当前tag是B, | |||||
next_B意思是后一个tag是B。则cur_B=S,即将当前被predict是B的tag标为S;next_M=B, 即将后一个被predict是M的tag标为B | |||||
| | next_B | next_M | next_E | next_S | end | | |||||
|:-----:|:-------:|:--------:|:--------:|:-------:|:-------:| | |||||
| start | 合法 | next_M=B | next_E=S | 合法 | - | | |||||
| cur_B | cur_B=S | 合法 | 合法 | cur_B=S | cur_B=S | | |||||
| cur_M | cur_M=E | 合法 | 合法 | cur_M=E | cur_M=E | | |||||
| cur_E | 合法 | next_M=B | next_E=S | 合法 | 合法 | | |||||
| cur_S | 合法 | next_M=B | next_E=S | 合法 | 合法 | | |||||
举例: | |||||
prediction为BSEMS,会被认为是SSSSS. | |||||
本Metric不检验target的合法性,请务必保证target的合法性。 | |||||
pred的形状应该为(batch_size, max_len) 或 (batch_size, max_len, 4)。 | |||||
target形状为 (batch_size, max_len) | |||||
seq_lens形状为 (batch_size, ) | |||||
""" | |||||
def __init__(self, b_idx=0, m_idx=1, e_idx=2, s_idx=3, pred=None, target=None, seq_lens=None): | |||||
""" | |||||
需要申明BMES这四种tag中,各种tag对应的idx。所有不为b_idx, m_idx, e_idx, s_idx的数字都认为是s_idx。 | |||||
:param b_idx: int, Begin标签所对应的tag idx. | |||||
:param m_idx: int, Middle标签所对应的tag idx. | |||||
:param e_idx: int, End标签所对应的tag idx. | |||||
:param s_idx: int, Single标签所对应的tag idx | |||||
:param pred: str, 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据 | |||||
:param target: str, 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据 | |||||
:param seq_lens: str, 用该key在evaluate()时从传入dict中取出seqence length数据。为None,则使用'seq_lens'取数据。 | |||||
""" | |||||
super().__init__() | |||||
self._init_param_map(pred=pred, target=target, seq_lens=seq_lens) | |||||
self.yt_wordnum = 0 | |||||
self.yp_wordnum = 0 | |||||
self.corr_num = 0 | |||||
self.b_idx = b_idx | |||||
self.m_idx = m_idx | |||||
self.e_idx = e_idx | |||||
self.s_idx = s_idx | |||||
# reconstruct the repair matrix described in __init__
self._valida_matrix = { | |||||
-1: [(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1)], # magic start idx | |||||
self.b_idx:[(0, self.s_idx), (-1, -1), (-1, -1), (0, self.s_idx), (0, self.s_idx)], | |||||
self.m_idx:[(0, self.e_idx), (-1, -1), (-1, -1), (0, self.e_idx), (0, self.e_idx)], | |||||
self.e_idx:[(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)], | |||||
self.s_idx:[(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)], | |||||
} | |||||
def _validate_tags(self, tags): | |||||
""" | |||||
给定一个tag的Tensor,返回合法tag | |||||
:param tags: Tensor, shape: (seq_len, ) | |||||
:return: 返回修改为合法tag的list | |||||
""" | |||||
assert len(tags)!=0 | |||||
assert isinstance(tags, torch.Tensor) and len(tags.size())==1 | |||||
padded_tags = [-1, *tags.tolist(), -1] | |||||
for idx in range(len(padded_tags)-1): | |||||
cur_tag = padded_tags[idx] | |||||
if cur_tag not in self._valida_matrix: | |||||
cur_tag = self.s_idx | |||||
if padded_tags[idx+1] not in self._valida_matrix: | |||||
padded_tags[idx+1] = self.s_idx | |||||
next_tag = padded_tags[idx+1] | |||||
shift_tag = self._valida_matrix[cur_tag][next_tag] | |||||
if shift_tag[0]!=-1: | |||||
padded_tags[idx+shift_tag[0]] = shift_tag[1] | |||||
return padded_tags[1:-1] | |||||
def evaluate(self, pred, target, seq_lens): | |||||
if not isinstance(pred, torch.Tensor): | |||||
raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||||
f"got {type(pred)}.") | |||||
if not isinstance(target, torch.Tensor): | |||||
raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||||
f"got {type(target)}.") | |||||
if not isinstance(seq_lens, torch.Tensor): | |||||
raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||||
f"got {type(seq_lens)}.") | |||||
if pred.size() == target.size() and len(target.size()) == 2: | |||||
pass | |||||
elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2: | |||||
pred = pred.argmax(dim=-1) | |||||
else: | |||||
raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have " | |||||
f"size:{pred.size()}, target should have size: {pred.size()} or " | |||||
f"{pred.size()[:-1]}, got {target.size()}.") | |||||
for idx in range(len(pred)): | |||||
seq_len = seq_lens[idx] | |||||
target_tags = target[idx][:seq_len].tolist() | |||||
pred_tags = pred[idx][:seq_len] | |||||
pred_tags = self._validate_tags(pred_tags) | |||||
start_idx = 0 | |||||
for t_idx, (t_tag, p_tag) in enumerate(zip(target_tags, pred_tags)): | |||||
if t_tag in (self.s_idx, self.e_idx): | |||||
self.yt_wordnum += 1 | |||||
corr_flag = True | |||||
for i in range(start_idx, t_idx+1): | |||||
if target_tags[i]!=pred_tags[i]: | |||||
corr_flag = False | |||||
if corr_flag: | |||||
self.corr_num += 1 | |||||
start_idx = t_idx + 1 | |||||
if p_tag in (self.s_idx, self.e_idx): | |||||
self.yp_wordnum += 1 | |||||
def get_metric(self, reset=True): | |||||
P = self.corr_num / (self.yp_wordnum + 1e-12) | |||||
R = self.corr_num / (self.yt_wordnum + 1e-12) | |||||
F = 2 * P * R / (P + R + 1e-12) | |||||
evaluate_result = {'f': round(F, 6), 'pre':round(P, 6), 'rec': round(R, 6)} | |||||
if reset: | |||||
self.yp_wordnum = 0 | |||||
self.yt_wordnum = 0 | |||||
self.corr_num = 0 | |||||
return evaluate_result | |||||
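The word-level precision/recall/F1 bookkeeping above can be checked with a small standalone sketch. This is a re-implementation of the same counting logic for illustration only: it uses the default tag ids B=0, M=1, E=2, S=3 from `__init__` and skips the tag-validation step, so it is not a call into the metric class itself.

```python
# Standalone sketch of the word-level P/R/F counting above (illustration only;
# it mirrors the evaluate()/get_metric() logic with default tag ids B=0, M=1, E=2, S=3
# and skips the tag-validation step).
def bmes_prf(target_tags, pred_tags):
    yt_wordnum, yp_wordnum, corr_num = 0, 0, 0
    start_idx = 0
    for idx, (t_tag, p_tag) in enumerate(zip(target_tags, pred_tags)):
        if t_tag in (2, 3):  # a gold word ends on E or S
            yt_wordnum += 1
            if target_tags[start_idx:idx + 1] == pred_tags[start_idx:idx + 1]:
                corr_num += 1  # every tag inside the gold span matched
            start_idx = idx + 1
        if p_tag in (2, 3):  # a predicted word ends on E or S
            yp_wordnum += 1
    pre = corr_num / (yp_wordnum + 1e-12)
    rec = corr_num / (yt_wordnum + 1e-12)
    return {'f': 2 * pre * rec / (pre + rec + 1e-12), 'pre': pre, 'rec': rec}

# gold "B E S" (two words) vs. prediction "S S S" (three words): only the last word matches
print(bmes_prf([0, 2, 3], [3, 3, 3]))  # f≈0.4, pre≈0.333, rec=0.5
```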
def _prepare_metrics(metrics): | def _prepare_metrics(metrics): | ||||
""" | """ | ||||
@@ -7,9 +7,14 @@ import numpy as np | |||||
import torch | import torch | ||||
from tensorboardX import SummaryWriter | from tensorboardX import SummaryWriter | ||||
from torch import nn | from torch import nn | ||||
from tqdm.autonotebook import tqdm | |||||
try: | |||||
from tqdm.autonotebook import tqdm | |||||
except: | |||||
from fastNLP.core.utils import pseudo_tqdm as tqdm | |||||
from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
from fastNLP.core.callback import CallbackManager | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.losses import _prepare_losser | from fastNLP.core.losses import _prepare_losser | ||||
from fastNLP.core.metrics import _prepare_metrics | from fastNLP.core.metrics import _prepare_metrics | ||||
@@ -27,7 +32,11 @@ from fastNLP.core.utils import get_func_signature | |||||
class Trainer(object): | class Trainer(object): | ||||
""" | |||||
def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, | |||||
validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0), | |||||
check_code_level=0, metric_key=None, sampler=RandomSampler(), use_tqdm=True, use_cuda=False, | |||||
callbacks=None): | |||||
""" | |||||
:param DataSet train_data: the training data | :param DataSet train_data: the training data | ||||
:param torch.nn.modules.module model: a PyTorch model | :param torch.nn.modules.module model: a PyTorch model | ||||
:param LossBase loss: a loss object | :param LossBase loss: a loss object | ||||
@@ -48,16 +57,10 @@ class Trainer(object): | |||||
smaller, add "-" in front of the string. For example:: | smaller, add "-" in front of the string. For example:: | ||||
metric_key="-PPL" # language model gets better as perplexity gets smaller | metric_key="-PPL" # language model gets better as perplexity gets smaller | ||||
:param BaseSampler sampler: method used to generate batch data. | :param BaseSampler sampler: method used to generate batch data. | ||||
:param bool use_tqdm: whether to use tqdm to show train progress. | :param bool use_tqdm: whether to use tqdm to show train progress. | ||||
""" | |||||
def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, | |||||
validate_every=-1, dev_data=None, use_cuda=False, save_path=None, | |||||
optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0, | |||||
metric_key=None, sampler=RandomSampler(), use_tqdm=True): | |||||
""" | |||||
super(Trainer, self).__init__() | super(Trainer, self).__init__() | ||||
if not isinstance(train_data, DataSet): | if not isinstance(train_data, DataSet): | ||||
@@ -109,9 +112,10 @@ class Trainer(object): | |||||
self.use_cuda = bool(use_cuda) | self.use_cuda = bool(use_cuda) | ||||
self.save_path = save_path | self.save_path = save_path | ||||
self.print_every = int(print_every) | self.print_every = int(print_every) | ||||
self.validate_every = int(validate_every) | |||||
self.validate_every = int(validate_every) if validate_every!=0 else -1 | |||||
self.best_metric_indicator = None | self.best_metric_indicator = None | ||||
self.sampler = sampler | self.sampler = sampler | ||||
self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) | |||||
if isinstance(optimizer, torch.optim.Optimizer): | if isinstance(optimizer, torch.optim.Optimizer): | ||||
self.optimizer = optimizer | self.optimizer = optimizer | ||||
@@ -119,11 +123,7 @@ class Trainer(object): | |||||
self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) | self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) | ||||
self.use_tqdm = use_tqdm | self.use_tqdm = use_tqdm | ||||
if self.use_tqdm: | |||||
tester_verbose = 0 | |||||
self.print_every = abs(self.print_every) | |||||
else: | |||||
tester_verbose = 1 | |||||
self.print_every = abs(self.print_every) | |||||
if self.dev_data is not None: | if self.dev_data is not None: | ||||
self.tester = Tester(model=self.model, | self.tester = Tester(model=self.model, | ||||
@@ -131,7 +131,7 @@ class Trainer(object): | |||||
metrics=self.metrics, | metrics=self.metrics, | ||||
batch_size=self.batch_size, | batch_size=self.batch_size, | ||||
use_cuda=self.use_cuda, | use_cuda=self.use_cuda, | ||||
verbose=tester_verbose) | |||||
verbose=0) | |||||
self.step = 0 | self.step = 0 | ||||
self.start_time = None # start timestamp | self.start_time = None # start timestamp | ||||
@@ -141,20 +141,30 @@ class Trainer(object): | |||||
开始训练过程。主要有以下几个步骤:: | 开始训练过程。主要有以下几个步骤:: | ||||
对于每次循环 | |||||
1. 使用Batch从DataSet中按批取出数据,并自动对DataSet中dtype为float, int的fields进行padding。并转换为Tensor。 | |||||
for epoch in range(num_epochs): | |||||
# 使用Batch从DataSet中按批取出数据,并自动对DataSet中dtype为(float, int)的fields进行padding。并转换为Tensor。 | |||||
非float,int类型的参数将不会被转换为Tensor,且不进行padding。 | 非float,int类型的参数将不会被转换为Tensor,且不进行padding。 | ||||
for batch_x, batch_y in Batch(DataSet) | for batch_x, batch_y in Batch(DataSet) | ||||
# batch_x中为设置为input的field | |||||
# batch_y中为设置为target的field | |||||
2. 将batch_x的数据送入到model.forward函数中,并获取结果 | |||||
3. 将batch_y与model.forward的结果一并送入loss中计算loss | |||||
# batch_x是一个dict, 被设为input的field会出现在这个dict中, | |||||
key为DataSet中的field_name, value为该field的value | |||||
# batch_y也是一个dict,被设为target的field会出现在这个dict中, | |||||
key为DataSet中的field_name, value为该field的value | |||||
2. 将batch_x的数据送入到model.forward函数中,并获取结果。这里我们就是通过匹配batch_x中的key与forward函数的形 | |||||
参完成参数传递。例如, | |||||
forward(self, x, seq_lens) # fastNLP会在batch_x中找到key为"x"的value传递给x,key为"seq_lens"的 | |||||
value传递给seq_lens。若在batch_x中没有找到所有必须要传递的参数,就会报错。如果forward存在默认参数 | |||||
而且默认参数这个key没有在batch_x中,则使用默认参数。 | |||||
3. 将batch_y与model.forward的结果一并送入loss中计算loss。loss计算时一般都涉及到pred与target。但是在不同情况 | |||||
中,可能pred称为output或prediction, target称为y或label。fastNLP通过初始化loss时传入的映射找到pred或 | |||||
target。比如在初始化Trainer时初始化loss为CrossEntropyLoss(pred='output', target='y'), 那么fastNLP计 | |||||
算loss时,就会使用"output"在batch_y与forward的结果中找到pred;使用"y"在batch_y与forward的结果中找target | |||||
, 并完成loss的计算。 | |||||
4. 获取到loss之后,进行反向求导并更新梯度 | 4. 获取到loss之后,进行反向求导并更新梯度 | ||||
如果测试集不为空 | |||||
根据metrics进行evaluation,并根据是否提供了save_path判断是否存储模型 | |||||
                根据需要适时进行验证集测试 |||||
根据metrics进行evaluation,并根据是否提供了save_path判断是否存储模型 | |||||
:param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现最好的 | |||||
模型参数。 | |||||
:param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 | |||||
最好的模型参数。 | |||||
:return results: 返回一个字典类型的数据, 内含以下内容:: | :return results: 返回一个字典类型的数据, 内含以下内容:: | ||||
seconds: float, 表示训练时长 | seconds: float, 表示训练时长 | ||||
@@ -187,10 +197,11 @@ class Trainer(object): | |||||
else: | else: | ||||
path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time)) | path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time)) | ||||
self._summary_writer = SummaryWriter(path) | self._summary_writer = SummaryWriter(path) | ||||
if self.use_tqdm: | |||||
self._tqdm_train() | |||||
else: | |||||
self._print_train() | |||||
self.callback_manager.before_train() | |||||
self._train() | |||||
self.callback_manager.after_train(self.model) | |||||
if self.dev_data is not None: | if self.dev_data is not None: | ||||
print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + | print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + | ||||
self.tester._format_eval_results(self.best_dev_perf),) | self.tester._format_eval_results(self.best_dev_perf),) | ||||
@@ -199,8 +210,11 @@ class Trainer(object): | |||||
results['best_step'] = self.best_dev_step | results['best_step'] = self.best_dev_step | ||||
if load_best_model: | if load_best_model: | ||||
model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]) | model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]) | ||||
# self._load_model(self.model, model_name) | |||||
print("Reloaded the best model.") | |||||
load_succeed = self._load_model(self.model, model_name) | |||||
if load_succeed: | |||||
print("Reloaded the best model.") | |||||
else: | |||||
print("Fail to reload best model.") | |||||
finally: | finally: | ||||
self._summary_writer.close() | self._summary_writer.close() | ||||
del self._summary_writer | del self._summary_writer | ||||
@@ -208,22 +222,43 @@ class Trainer(object): | |||||
return results | return results | ||||
def _tqdm_train(self): | |||||
def _train(self): | |||||
if not self.use_tqdm: | |||||
from fastNLP.core.utils import pseudo_tqdm as inner_tqdm | |||||
else: | |||||
inner_tqdm = tqdm | |||||
self.step = 0 | self.step = 0 | ||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, | |||||
as_numpy=False) | |||||
total_steps = data_iterator.num_batches*self.n_epochs | |||||
with tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: | |||||
start = time.time() | |||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False) | |||||
total_steps = data_iterator.num_batches * self.n_epochs | |||||
with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: | |||||
avg_loss = 0 | avg_loss = 0 | ||||
for epoch in range(1, self.n_epochs+1): | for epoch in range(1, self.n_epochs+1): | ||||
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | ||||
# early stopping | |||||
self.callback_manager.before_epoch(epoch, self.n_epochs) | |||||
for batch_x, batch_y in data_iterator: | for batch_x, batch_y in data_iterator: | ||||
indices = data_iterator.get_batch_indices() | |||||
# negative sampling; replace unknown; re-weight batch_y | |||||
self.callback_manager.before_batch(batch_x, batch_y, indices) | |||||
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) | _move_dict_value_to_device(batch_x, batch_y, device=self._model_device) | ||||
prediction = self._data_forward(self.model, batch_x) | prediction = self._data_forward(self.model, batch_x) | ||||
# edit prediction | |||||
self.callback_manager.before_loss(batch_y, prediction) | |||||
loss = self._compute_loss(prediction, batch_y) | loss = self._compute_loss(prediction, batch_y) | ||||
avg_loss += loss.item() | avg_loss += loss.item() | ||||
# Is loss NaN or inf? requires_grad = False | |||||
self.callback_manager.before_backward(loss, self.model) | |||||
self._grad_backward(loss) | self._grad_backward(loss) | ||||
# gradient clipping | |||||
self.callback_manager.after_backward(self.model) | |||||
self._update() | self._update() | ||||
# lr scheduler; lr_finder; one_cycle | |||||
self.callback_manager.after_step(self.optimizer) | |||||
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) | self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) | ||||
for name, param in self.model.named_parameters(): | for name, param in self.model.named_parameters(): | ||||
if param.requires_grad: | if param.requires_grad: | ||||
@@ -231,65 +266,41 @@ class Trainer(object): | |||||
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step) | # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step) | ||||
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step) | # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step) | ||||
if (self.step+1) % self.print_every == 0: | if (self.step+1) % self.print_every == 0: | ||||
pbar.set_postfix_str("loss:{0:<6.5f}".format(avg_loss / self.print_every)) | |||||
if self.use_tqdm: | |||||
print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every) | |||||
pbar.update(self.print_every) | |||||
else: | |||||
end = time.time() | |||||
diff = timedelta(seconds=round(end - start)) | |||||
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format( | |||||
epoch, self.step, avg_loss, diff) | |||||
pbar.set_postfix_str(print_output) | |||||
avg_loss = 0 | avg_loss = 0 | ||||
pbar.update(self.print_every) | |||||
self.step += 1 | self.step += 1 | ||||
if self.validate_every > 0 and self.step % self.validate_every == 0 \ | |||||
# do nothing | |||||
self.callback_manager.after_batch() | |||||
if ((self.validate_every > 0 and self.step % self.validate_every == 0) or | |||||
                            (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \ | |||||
and self.dev_data is not None: | and self.dev_data is not None: | ||||
eval_res = self._do_validation(epoch=epoch, step=self.step) | eval_res = self._do_validation(epoch=epoch, step=self.step) | ||||
eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \ | |||||
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | |||||
total_steps) + \ | |||||
self.tester._format_eval_results(eval_res) | self.tester._format_eval_results(eval_res) | ||||
pbar.write(eval_str) | pbar.write(eval_str) | ||||
if self.validate_every < 0 and self.dev_data: | |||||
eval_res = self._do_validation(epoch=epoch, step=self.step) | |||||
eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \ | |||||
self.tester._format_eval_results(eval_res) | |||||
pbar.write(eval_str) | |||||
if epoch!=self.n_epochs: | |||||
# if self.validate_every < 0 and self.dev_data: | |||||
# eval_res = self._do_validation(epoch=epoch, step=self.step) | |||||
# eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \ | |||||
# self.tester._format_eval_results(eval_res) | |||||
# pbar.write(eval_str) | |||||
if epoch != self.n_epochs: | |||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, | data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, | ||||
as_numpy=False) | as_numpy=False) | ||||
# lr decay; early stopping | |||||
self.callback_manager.after_epoch(epoch, self.n_epochs, self.optimizer) | |||||
pbar.close() | pbar.close() | ||||
def _print_train(self): | |||||
epoch = 1 | |||||
start = time.time() | |||||
while epoch <= self.n_epochs: | |||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, | |||||
as_numpy=False) | |||||
for batch_x, batch_y in data_iterator: | |||||
# TODO 这里可能会遇到问题,万一用户在model内部修改了prediction的device就会有问题 | |||||
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) | |||||
prediction = self._data_forward(self.model, batch_x) | |||||
loss = self._compute_loss(prediction, batch_y) | |||||
self._grad_backward(loss) | |||||
self._update() | |||||
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) | |||||
for name, param in self.model.named_parameters(): | |||||
if param.requires_grad: | |||||
self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step) | |||||
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step) | |||||
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step) | |||||
if self.print_every > 0 and self.step % self.print_every == 0: | |||||
end = time.time() | |||||
diff = timedelta(seconds=round(end - start)) | |||||
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format( | |||||
epoch, self.step, loss.data, diff) | |||||
print(print_output) | |||||
if (self.validate_every > 0 and self.step % self.validate_every == 0 and | |||||
self.dev_data is not None): | |||||
self._do_validation(epoch=epoch, step=self.step) | |||||
self.step += 1 | |||||
# validate_every override validation at end of epochs | |||||
if self.dev_data and self.validate_every <= 0: | |||||
self._do_validation(epoch=epoch, step=self.step) | |||||
epoch += 1 | |||||
def _do_validation(self, epoch, step): | def _do_validation(self, epoch, step): | ||||
res = self.tester.test() | res = self.tester.test() | ||||
for name, metric in res.items(): | for name, metric in res.items(): | ||||
@@ -300,10 +311,13 @@ class Trainer(object): | |||||
if self.save_path is not None: | if self.save_path is not None: | ||||
self._save_model(self.model, | self._save_model(self.model, | ||||
"best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])) | "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])) | ||||
else: | |||||
self._best_model_states = {name:param.cpu().clone() for name, param in self.model.named_parameters()} | |||||
self.best_dev_perf = res | self.best_dev_perf = res | ||||
self.best_dev_epoch = epoch | self.best_dev_epoch = epoch | ||||
self.best_dev_step = step | self.best_dev_step = step | ||||
# get validation results; adjust optimizer | |||||
self.callback_manager.after_valid(res, self.metric_key, self.optimizer) | |||||
return res | return res | ||||
def _mode(self, model, is_test=False): | def _mode(self, model, is_test=False): | ||||
@@ -359,7 +373,7 @@ class Trainer(object): | |||||
torch.save(model, model_name) | torch.save(model, model_name) | ||||
def _load_model(self, model, model_name, only_param=False): | def _load_model(self, model, model_name, only_param=False): | ||||
# TODO: 这个是不是有问题? | |||||
# 返回bool值指示是否成功reload模型 | |||||
if self.save_path is not None: | if self.save_path is not None: | ||||
model_path = os.path.join(self.save_path, model_name) | model_path = os.path.join(self.save_path, model_name) | ||||
if only_param: | if only_param: | ||||
@@ -367,6 +381,11 @@ class Trainer(object): | |||||
else: | else: | ||||
states = torch.load(model_path).state_dict() | states = torch.load(model_path).state_dict() | ||||
model.load_state_dict(states) | model.load_state_dict(states) | ||||
elif hasattr(self, "_best_model_states"): | |||||
model.load_state_dict(self._best_model_states) | |||||
else: | |||||
return False | |||||
return True | |||||
def _better_eval_result(self, metrics): | def _better_eval_result(self, metrics): | ||||
"""Check if the current epoch yields better validation results. | """Check if the current epoch yields better validation results. | ||||
@@ -472,7 +491,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||||
break | break | ||||
if dev_data is not None: | if dev_data is not None: | ||||
tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics, | |||||
tester = Tester(data=dev_data[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics, | |||||
batch_size=batch_size, verbose=-1) | batch_size=batch_size, verbose=-1) | ||||
evaluate_results = tester.test() | evaluate_results = tester.test() | ||||
_check_eval_results(metrics=evaluate_results, metric_key=metric_key, metric_list=metrics) | _check_eval_results(metrics=evaluate_results, metric_key=metric_key, metric_list=metrics) | ||||
@@ -9,7 +9,7 @@ import numpy as np | |||||
import torch | import torch | ||||
CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', | CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', | ||||
'varargs'], verbose=False) | |||||
'varargs']) | |||||
def save_pickle(obj, pickle_path, file_name): | def save_pickle(obj, pickle_path, file_name): | ||||
@@ -400,7 +400,7 @@ def seq_lens_to_masks(seq_lens, float=False): | |||||
assert len(np.shape(seq_lens)) == 1, f"seq_lens can only have one dimension, got {len(np.shape(seq_lens))}." | assert len(np.shape(seq_lens)) == 1, f"seq_lens can only have one dimension, got {len(np.shape(seq_lens))}." | ||||
assert seq_lens.dtype in (int, np.int32, np.int64), f"seq_lens can only be integer, not {seq_lens.dtype}." | assert seq_lens.dtype in (int, np.int32, np.int64), f"seq_lens can only be integer, not {seq_lens.dtype}." | ||||
raise NotImplemented | raise NotImplemented | ||||
elif isinstance(seq_lens, torch.LongTensor): | |||||
elif isinstance(seq_lens, torch.Tensor): | |||||
assert len(seq_lens.size()) == 1, f"seq_lens can only have one dimension, got {len(seq_lens.size())==1}." | assert len(seq_lens.size()) == 1, f"seq_lens can only have one dimension, got {len(seq_lens.size())==1}." | ||||
batch_size = seq_lens.size(0) | batch_size = seq_lens.size(0) | ||||
max_len = seq_lens.max() | max_len = seq_lens.max() | ||||
@@ -430,3 +430,30 @@ def seq_mask(seq_len, max_len): | |||||
seq_len = seq_len.view(-1, 1).long() # [batch_size, 1] | seq_len = seq_len.view(-1, 1).long() # [batch_size, 1] | ||||
seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1) # [1, max_len] | seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1) # [1, max_len] | ||||
return torch.gt(seq_len, seq_range) # [batch_size, max_len] | return torch.gt(seq_len, seq_range) # [batch_size, max_len] | ||||
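As a quick sanity check of the masking helper above, the comparison in `seq_mask` keeps True exactly for the first `seq_len` positions of each row. A toy trace inlining the same steps as the function body:

```python
import torch

# Inline trace of seq_mask([3, 1], max_len=4): each row is True for its first seq_len positions.
seq_len = torch.tensor([3, 1]).view(-1, 1).long()              # [batch_size, 1]
seq_range = torch.arange(0, 4, dtype=torch.long).view(1, -1)   # [[0, 1, 2, 3]]
mask = torch.gt(seq_len, seq_range)                            # [batch_size, max_len]
print(mask)  # rows [1, 1, 1, 0] and [1, 0, 0, 0] (a ByteTensor under torch 1.0)
```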
class pseudo_tqdm: | |||||
""" | |||||
当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据 | |||||
""" | |||||
def __init__(self, **kwargs): | |||||
pass | |||||
def write(self, info): | |||||
print(info) | |||||
def set_postfix_str(self, info): | |||||
print(info) | |||||
def __getattr__(self, item): | |||||
def pass_func(*args, **kwargs): | |||||
pass | |||||
return pass_func | |||||
def __enter__(self): | |||||
return self | |||||
def __exit__(self, exc_type, exc_val, exc_tb): | |||||
del self |
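The hooks that `CallbackManager` fires inside `_train()` above suggest the following usage pattern. This is a hedged sketch: it assumes `fastNLP.core.callback` also defines a `Callback` base class whose hook names and arguments match the manager calls shown in the diff, which is not itself visible here; the callback name and its behaviour are purely illustrative.

```python
# Hypothetical callback sketch; assumes a Callback base class in fastNLP.core.callback
# whose hook names/arguments mirror the CallbackManager calls in Trainer._train() above.
from fastNLP.core.callback import Callback  # assumed import

class LossWatcher(Callback):
    def before_epoch(self, cur_epoch, total_epoch):
        print("epoch {}/{}".format(cur_epoch, total_epoch))

    def before_backward(self, loss, model):
        if loss != loss:  # NaN check, the kind of guard hinted at by the inline comment
            print("warning: loss is NaN at this step")

    def after_valid(self, eval_result, metric_key, optimizer):
        print("dev results:", eval_result)

# trainer = Trainer(train_data, model, loss=..., metrics=..., dev_data=dev_data,
#                   callbacks=[LossWatcher()])
```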
@@ -270,8 +270,8 @@ class ClassDataSetLoader(DataSetLoader): | |||||
def parse(lines): | def parse(lines): | ||||
""" | """ | ||||
:param list lines: lines from dataset | |||||
:return: a 3-D list, indicating words, sentence, and dataset respectively. | |||||
:param lines: lines from dataset | |||||
        :return: list(list(list())): the three levels of lists are words, sentences, and the dataset, respectively |||||
""" | """ | ||||
dataset = list() | dataset = list() | ||||
for line in lines: | for line in lines: | ||||
@@ -96,7 +96,7 @@ class EmbedLoader(BaseLoader): | |||||
:param int emb_dim: the dimension of the embedding. Should be the same as pre-trained embedding. | :param int emb_dim: the dimension of the embedding. Should be the same as pre-trained embedding. | ||||
:param str emb_file: the pre-trained embedding file path. | :param str emb_file: the pre-trained embedding file path. | ||||
:param Vocabulary vocab: a mapping from word to index, can be provided by user or built from pre-trained embedding | :param Vocabulary vocab: a mapping from word to index, can be provided by user or built from pre-trained embedding | ||||
:return: the embedding matrix, numpy.ndarray | |||||
:return embedding_matrix: numpy.ndarray | |||||
""" | """ | ||||
if vocab is None: | if vocab is None: | ||||
@@ -3,4 +3,4 @@ from .biaffine_parser import BiaffineParser, GraphParser | |||||
from .char_language_model import CharLM | from .char_language_model import CharLM | ||||
from .cnn_text_classification import CNNText | from .cnn_text_classification import CNNText | ||||
from .sequence_modeling import SeqLabeling, AdvSeqLabel | from .sequence_modeling import SeqLabeling, AdvSeqLabel | ||||
from .snli import SNLI | |||||
from .snli import ESIM |
@@ -134,17 +134,13 @@ class GraphParser(BaseModel): | |||||
def _mst_decoder(self, arc_matrix, mask=None): | def _mst_decoder(self, arc_matrix, mask=None): | ||||
batch_size, seq_len, _ = arc_matrix.shape | batch_size, seq_len, _ = arc_matrix.shape | ||||
matrix = torch.zeros_like(arc_matrix).copy_(arc_matrix) | |||||
matrix = arc_matrix.clone() | |||||
ans = matrix.new_zeros(batch_size, seq_len).long() | ans = matrix.new_zeros(batch_size, seq_len).long() | ||||
lens = (mask.long()).sum(1) if mask is not None else torch.zeros(batch_size) + seq_len | lens = (mask.long()).sum(1) if mask is not None else torch.zeros(batch_size) + seq_len | ||||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=lens.device) | batch_idx = torch.arange(batch_size, dtype=torch.long, device=lens.device) | ||||
mask[batch_idx, lens-1] = 0 | |||||
for i, graph in enumerate(matrix): | for i, graph in enumerate(matrix): | ||||
len_i = lens[i] | len_i = lens[i] | ||||
if len_i == seq_len: | |||||
ans[i] = torch.as_tensor(mst(graph.cpu().numpy()), device=ans.device) | |||||
else: | |||||
ans[i, :len_i] = torch.as_tensor(mst(graph[:len_i, :len_i].cpu().numpy()), device=ans.device) | |||||
ans[i, :len_i] = torch.as_tensor(mst(graph.detach()[:len_i, :len_i].cpu().numpy()), device=ans.device) | |||||
if mask is not None: | if mask is not None: | ||||
ans *= mask.long() | ans *= mask.long() | ||||
return ans | return ans | ||||
@@ -219,6 +215,7 @@ class BiaffineParser(GraphParser): | |||||
self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim) | self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim) | ||||
self.word_norm = nn.LayerNorm(word_hid_dim) | self.word_norm = nn.LayerNorm(word_hid_dim) | ||||
self.pos_norm = nn.LayerNorm(pos_hid_dim) | self.pos_norm = nn.LayerNorm(pos_hid_dim) | ||||
self.use_var_lstm = use_var_lstm | |||||
if use_var_lstm: | if use_var_lstm: | ||||
self.lstm = VarLSTM(input_size=word_hid_dim + pos_hid_dim, | self.lstm = VarLSTM(input_size=word_hid_dim + pos_hid_dim, | ||||
hidden_size=rnn_hidden_size, | hidden_size=rnn_hidden_size, | ||||
@@ -249,10 +246,9 @@ class BiaffineParser(GraphParser): | |||||
self.label_dep_mlp = copy.deepcopy(self.label_head_mlp) | self.label_dep_mlp = copy.deepcopy(self.label_head_mlp) | ||||
self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) | self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) | ||||
self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) | self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) | ||||
self.normal_dropout = nn.Dropout(p=dropout) | |||||
self.use_greedy_infer = use_greedy_infer | self.use_greedy_infer = use_greedy_infer | ||||
self.reset_parameters() | self.reset_parameters() | ||||
self.explore_p = 0.2 | |||||
self.dropout = dropout | |||||
def reset_parameters(self): | def reset_parameters(self): | ||||
for m in self.modules(): | for m in self.modules(): | ||||
@@ -278,18 +274,15 @@ class BiaffineParser(GraphParser): | |||||
head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads | head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads | ||||
""" | """ | ||||
# prepare embeddings | # prepare embeddings | ||||
device = self.parameters().__next__().device | |||||
word_seq = word_seq.long().to(device) | |||||
pos_seq = pos_seq.long().to(device) | |||||
seq_lens = seq_lens.long().to(device).view(-1) | |||||
batch_size, seq_len = word_seq.shape | batch_size, seq_len = word_seq.shape | ||||
# print('forward {} {}'.format(batch_size, seq_len)) | # print('forward {} {}'.format(batch_size, seq_len)) | ||||
# get sequence mask | # get sequence mask | ||||
mask = seq_mask(seq_lens, seq_len).long() | mask = seq_mask(seq_lens, seq_len).long() | ||||
word = self.normal_dropout(self.word_embedding(word_seq)) # [N,L] -> [N,L,C_0] | |||||
pos = self.normal_dropout(self.pos_embedding(pos_seq)) # [N,L] -> [N,L,C_1] | |||||
word = self.word_embedding(word_seq) # [N,L] -> [N,L,C_0] | |||||
pos = self.pos_embedding(pos_seq) # [N,L] -> [N,L,C_1] | |||||
word, pos = self.word_fc(word), self.pos_fc(pos) | word, pos = self.word_fc(word), self.pos_fc(pos) | ||||
word, pos = self.word_norm(word), self.pos_norm(pos) | word, pos = self.word_norm(word), self.pos_norm(pos) | ||||
x = torch.cat([word, pos], dim=2) # -> [N,L,C] | x = torch.cat([word, pos], dim=2) # -> [N,L,C] | ||||
@@ -325,7 +318,7 @@ class BiaffineParser(GraphParser): | |||||
head_pred = heads | head_pred = heads | ||||
else: | else: | ||||
assert self.training # must be training mode | assert self.training # must be training mode | ||||
if torch.rand(1).item() < self.explore_p: | |||||
if gold_heads is None: | |||||
heads = self._greedy_decoder(arc_pred, mask) | heads = self._greedy_decoder(arc_pred, mask) | ||||
head_pred = heads | head_pred = heads | ||||
else: | else: | ||||
@@ -355,7 +348,7 @@ class BiaffineParser(GraphParser): | |||||
batch_size, seq_len, _ = arc_pred.shape | batch_size, seq_len, _ = arc_pred.shape | ||||
flip_mask = (mask == 0) | flip_mask = (mask == 0) | ||||
_arc_pred = arc_pred.new_empty((batch_size, seq_len, seq_len)).copy_(arc_pred) | |||||
_arc_pred = arc_pred.clone() | |||||
_arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf) | _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -np.inf) | ||||
arc_logits = F.log_softmax(_arc_pred, dim=2) | arc_logits = F.log_softmax(_arc_pred, dim=2) | ||||
label_logits = F.log_softmax(label_pred, dim=2) | label_logits = F.log_softmax(label_pred, dim=2) | ||||
@@ -421,7 +414,9 @@ class ParserMetric(MetricBase): | |||||
if seq_lens is None: | if seq_lens is None: | ||||
seq_mask = arc_pred.new_ones(arc_pred.size(), dtype=torch.long) | seq_mask = arc_pred.new_ones(arc_pred.size(), dtype=torch.long) | ||||
else: | else: | ||||
seq_mask = seq_lens_to_masks(seq_lens, float=False).long() | |||||
seq_mask = seq_lens_to_masks(seq_lens.long(), float=False).long() | |||||
# mask out <root> tag | |||||
seq_mask[:,0] = 0 | |||||
head_pred_correct = (arc_pred == arc_true).long() * seq_mask | head_pred_correct = (arc_pred == arc_true).long() * seq_mask | ||||
label_pred_correct = (label_pred == label_true).long() * head_pred_correct | label_pred_correct = (label_pred == label_true).long() * head_pred_correct | ||||
self.num_arc += head_pred_correct.sum().item() | self.num_arc += head_pred_correct.sum().item() | ||||
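The effect of zeroing the `<root>` column before counting can be traced on a toy batch; this is standalone arithmetic mirroring the lines above, not a call into ParserMetric itself.

```python
import torch

# Toy trace of the UAS/LAS counting above with the <root> position masked out.
arc_true, arc_pred = torch.tensor([[0, 2, 0, 2]]), torch.tensor([[0, 2, 0, 1]])
label_true, label_pred = torch.tensor([[0, 3, 1, 3]]), torch.tensor([[0, 3, 2, 3]])
seq_mask = torch.ones(1, 4, dtype=torch.long)
seq_mask[:, 0] = 0  # never score the artificial <root> token at position 0

head_pred_correct = (arc_pred == arc_true).long() * seq_mask
label_pred_correct = (label_pred == label_true).long() * head_pred_correct
print(head_pred_correct.sum().item(),   # 2 correct heads
      label_pred_correct.sum().item(),  # 1 correct labeled head
      seq_mask.sum().item())            # 3 scored tokens -> UAS=2/3, LAS=1/3
```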
@@ -1,8 +1,8 @@ | |||||
import torch | import torch | ||||
import numpy as np | |||||
from fastNLP.models.base_model import BaseModel | from fastNLP.models.base_model import BaseModel | ||||
from fastNLP.modules import decoder, encoder | from fastNLP.modules import decoder, encoder | ||||
from fastNLP.modules.decoder.CRF import allowed_transitions | |||||
from fastNLP.modules.utils import seq_mask | from fastNLP.modules.utils import seq_mask | ||||
@@ -93,7 +93,7 @@ class AdvSeqLabel(SeqLabeling): | |||||
Advanced Sequence Labeling Model | Advanced Sequence Labeling Model | ||||
""" | """ | ||||
def __init__(self, args, emb=None): | |||||
def __init__(self, args, emb=None, id2words=None): | |||||
super(AdvSeqLabel, self).__init__(args) | super(AdvSeqLabel, self).__init__(args) | ||||
vocab_size = args["vocab_size"] | vocab_size = args["vocab_size"] | ||||
@@ -105,7 +105,8 @@ class AdvSeqLabel(SeqLabeling): | |||||
self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb) | self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb) | ||||
self.norm1 = torch.nn.LayerNorm(word_emb_dim) | self.norm1 = torch.nn.LayerNorm(word_emb_dim) | ||||
# self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=2, dropout=dropout, bidirectional=True) | # self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=2, dropout=dropout, bidirectional=True) | ||||
self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout, bidirectional=True, batch_first=True) | |||||
self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout, | |||||
bidirectional=True, batch_first=True) | |||||
self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3) | self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3) | ||||
self.norm2 = torch.nn.LayerNorm(hidden_dim * 2 // 3) | self.norm2 = torch.nn.LayerNorm(hidden_dim * 2 // 3) | ||||
# self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3) | # self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3) | ||||
@@ -113,7 +114,12 @@ class AdvSeqLabel(SeqLabeling): | |||||
self.drop = torch.nn.Dropout(dropout) | self.drop = torch.nn.Dropout(dropout) | ||||
self.Linear2 = encoder.Linear(hidden_dim * 2 // 3, num_classes) | self.Linear2 = encoder.Linear(hidden_dim * 2 // 3, num_classes) | ||||
self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False) | |||||
if id2words is None: | |||||
self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False) | |||||
else: | |||||
self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False, | |||||
allowed_transitions=allowed_transitions(id2words, | |||||
encoding_type="bmes")) | |||||
def forward(self, word_seq, word_seq_origin_len, truth=None): | def forward(self, word_seq, word_seq_origin_len, truth=None): | ||||
""" | """ | ||||
@@ -178,6 +184,7 @@ class AdvSeqLabel(SeqLabeling): | |||||
assert 'loss' in kwargs | assert 'loss' in kwargs | ||||
return kwargs['loss'] | return kwargs['loss'] | ||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
args = { | args = { | ||||
'vocab_size': 20, | 'vocab_size': 20, | ||||
@@ -208,11 +215,11 @@ if __name__ == '__main__': | |||||
res = model(word_seq, word_seq_len, truth) | res = model(word_seq, word_seq_len, truth) | ||||
loss = res['loss'] | loss = res['loss'] | ||||
pred = res['predict'] | pred = res['predict'] | ||||
print('loss: {} acc {}'.format(loss.item(), ((pred.data == truth).long().sum().float() / word_seq_len.sum().float()))) | |||||
print('loss: {} acc {}'.format(loss.item(), | |||||
((pred.data == truth).long().sum().float() / word_seq_len.sum().float()))) | |||||
optimizer.zero_grad() | optimizer.zero_grad() | ||||
loss.backward() | loss.backward() | ||||
optimizer.step() | optimizer.step() | ||||
curidx = endidx | curidx = endidx | ||||
if curidx == len(data): | if curidx == len(data): | ||||
curidx = 0 | curidx = 0 | ||||
@@ -3,29 +3,35 @@ import torch.nn as nn | |||||
import torch.nn.functional as F | import torch.nn.functional as F | ||||
from fastNLP.models.base_model import BaseModel | from fastNLP.models.base_model import BaseModel | ||||
from fastNLP.modules import decoder as Decoder, encoder as Encoder | |||||
from fastNLP.modules import decoder as Decoder | |||||
from fastNLP.modules import encoder as Encoder | |||||
from fastNLP.modules import aggregator as Aggregator | |||||
my_inf = 10e12 | my_inf = 10e12 | ||||
class SNLI(BaseModel): | |||||
class ESIM(BaseModel): | |||||
""" | """ | ||||
PyTorch Network for SNLI. | |||||
    PyTorch network for the SNLI task using the ESIM model. | |||||
""" | """ | ||||
def __init__(self, args, init_embedding=None): | |||||
super(SNLI, self).__init__() | |||||
self.vocab_size = args["vocab_size"] | |||||
self.embed_dim = args["embed_dim"] | |||||
self.hidden_size = args["hidden_size"] | |||||
self.batch_first = args["batch_first"] | |||||
self.dropout = args["dropout"] | |||||
self.n_labels = args["num_classes"] | |||||
self.gpu = args["gpu"] and torch.cuda.is_available() | |||||
self.embedding = Encoder.embedding.Embedding(self.vocab_size, self.embed_dim, init_emb=init_embedding, | |||||
dropout=self.dropout) | |||||
def __init__(self, **kwargs): | |||||
super(ESIM, self).__init__() | |||||
self.vocab_size = kwargs["vocab_size"] | |||||
self.embed_dim = kwargs["embed_dim"] | |||||
self.hidden_size = kwargs["hidden_size"] | |||||
self.batch_first = kwargs["batch_first"] | |||||
self.dropout = kwargs["dropout"] | |||||
self.n_labels = kwargs["num_classes"] | |||||
self.gpu = kwargs["gpu"] and torch.cuda.is_available() | |||||
self.drop = nn.Dropout(self.dropout) | |||||
self.embedding = Encoder.Embedding( | |||||
self.vocab_size, self.embed_dim, dropout=self.dropout, | |||||
            init_emb=kwargs["init_embedding"] if "init_embedding" in kwargs.keys() else None, | |||||
) | |||||
self.embedding_layer = Encoder.Linear(self.embed_dim, self.hidden_size) | self.embedding_layer = Encoder.Linear(self.embed_dim, self.hidden_size) | ||||
@@ -34,6 +40,10 @@ class SNLI(BaseModel): | |||||
batch_first=self.batch_first, bidirectional=True | batch_first=self.batch_first, bidirectional=True | ||||
) | ) | ||||
self.bi_attention = Aggregator.Bi_Attention() | |||||
self.mean_pooling = Aggregator.MeanPoolWithMask() | |||||
self.max_pooling = Aggregator.MaxPoolWithMask() | |||||
self.inference_layer = Encoder.Linear(self.hidden_size * 4, self.hidden_size) | self.inference_layer = Encoder.Linear(self.hidden_size * 4, self.hidden_size) | ||||
self.decoder = Encoder.LSTM( | self.decoder = Encoder.LSTM( | ||||
@@ -41,16 +51,16 @@ class SNLI(BaseModel): | |||||
batch_first=self.batch_first, bidirectional=True | batch_first=self.batch_first, bidirectional=True | ||||
) | ) | ||||
self.output = Decoder.MLP([4 * self.hidden_size, self.hidden_size, self.n_labels], 'tanh') | |||||
self.output = Decoder.MLP([4 * self.hidden_size, self.hidden_size, self.n_labels], 'tanh', dropout=self.dropout) | |||||
def forward(self, premise, hypothesis, premise_len, hypothesis_len): | def forward(self, premise, hypothesis, premise_len, hypothesis_len): | ||||
""" Forward function | """ Forward function | ||||
:param premise: A Tensor represents premise: [batch size(B), premise seq len(PL), hidden size(H)]. | |||||
:param hypothesis: A Tensor represents hypothesis: [B, hypothesis seq len(HL), H]. | |||||
        :param premise: A Tensor representing the premise: [batch size(B), premise seq len(PL)]. | |||||
        :param hypothesis: A Tensor representing the hypothesis: [B, hypothesis seq len(HL)]. | |||||
:param premise_len: A Tensor record which is a real word and which is a padding word in premise: [B, PL]. | :param premise_len: A Tensor record which is a real word and which is a padding word in premise: [B, PL]. | ||||
:param hypothesis_len: A Tensor record which is a real word and which is a padding word in hypothesis: [B, HL]. | :param hypothesis_len: A Tensor record which is a real word and which is a padding word in hypothesis: [B, HL]. | ||||
:return: prediction: A Tensor of classification result: [B, n_labels(N)]. | |||||
        :return: prediction: A dict containing the classification result Tensor: [B, n_labels(N)]. | |||||
""" | """ | ||||
premise0 = self.embedding_layer(self.embedding(premise)) | premise0 = self.embedding_layer(self.embedding(premise)) | ||||
@@ -68,16 +78,13 @@ class SNLI(BaseModel): | |||||
B, PL, H = premise0.size() | B, PL, H = premise0.size() | ||||
B, HL, H = hypothesis0.size() | B, HL, H = hypothesis0.size() | ||||
# a0, (ah0, ac0) = self.encoder(premise) # a0: [B, PL, H * 2], ah0: [2, B, H] | |||||
# b0, (bh0, bc0) = self.encoder(hypothesis) # b0: [B, HL, H * 2] | |||||
a0 = self.encoder(premise0) # a0: [B, PL, H * 2] | |||||
b0 = self.encoder(hypothesis0) # b0: [B, HL, H * 2] | |||||
a0 = self.encoder(self.drop(premise0)) # a0: [B, PL, H * 2] | |||||
b0 = self.encoder(self.drop(hypothesis0)) # b0: [B, HL, H * 2] | |||||
a = torch.mean(a0.view(B, PL, -1, H), dim=2) # a: [B, PL, H] | a = torch.mean(a0.view(B, PL, -1, H), dim=2) # a: [B, PL, H] | ||||
b = torch.mean(b0.view(B, HL, -1, H), dim=2) # b: [B, HL, H] | b = torch.mean(b0.view(B, HL, -1, H), dim=2) # b: [B, HL, H] | ||||
ai, bi = self.calc_bi_attention(a, b, premise_len, hypothesis_len) | |||||
ai, bi = self.bi_attention(a, b, premise_len, hypothesis_len) | |||||
ma = torch.cat((a, ai, a - ai, a * ai), dim=2) # ma: [B, PL, 4 * H] | ma = torch.cat((a, ai, a - ai, a * ai), dim=2) # ma: [B, PL, 4 * H] | ||||
mb = torch.cat((b, bi, b - bi, b * bi), dim=2) # mb: [B, HL, 4 * H] | mb = torch.cat((b, bi, b - bi, b * bi), dim=2) # mb: [B, HL, 4 * H] | ||||
@@ -85,17 +92,12 @@ class SNLI(BaseModel): | |||||
f_ma = self.inference_layer(ma) | f_ma = self.inference_layer(ma) | ||||
f_mb = self.inference_layer(mb) | f_mb = self.inference_layer(mb) | ||||
vat = self.decoder(f_ma) | |||||
vbt = self.decoder(f_mb) | |||||
vat = self.decoder(self.drop(f_ma)) | |||||
vbt = self.decoder(self.drop(f_mb)) | |||||
va = torch.mean(vat.view(B, PL, -1, H), dim=2) # va: [B, PL, H] | va = torch.mean(vat.view(B, PL, -1, H), dim=2) # va: [B, PL, H] | ||||
vb = torch.mean(vbt.view(B, HL, -1, H), dim=2) # vb: [B, HL, H] | vb = torch.mean(vbt.view(B, HL, -1, H), dim=2) # vb: [B, HL, H] | ||||
# va_ave = torch.mean(va, dim=1) # va_ave: [B, H] | |||||
# va_max, va_arg_max = torch.max(va, dim=1) # va_max: [B, H] | |||||
# vb_ave = torch.mean(vb, dim=1) # vb_ave: [B, H] | |||||
# vb_max, vb_arg_max = torch.max(vb, dim=1) # vb_max: [B, H] | |||||
va_ave = self.mean_pooling(va, premise_len, dim=1) # va_ave: [B, H] | va_ave = self.mean_pooling(va, premise_len, dim=1) # va_ave: [B, H] | ||||
va_max, va_arg_max = self.max_pooling(va, premise_len, dim=1) # va_max: [B, H] | va_max, va_arg_max = self.max_pooling(va, premise_len, dim=1) # va_max: [B, H] | ||||
vb_ave = self.mean_pooling(vb, hypothesis_len, dim=1) # vb_ave: [B, H] | vb_ave = self.mean_pooling(vb, hypothesis_len, dim=1) # vb_ave: [B, H] | ||||
@@ -103,59 +105,10 @@ class SNLI(BaseModel): | |||||
v = torch.cat((va_ave, va_max, vb_ave, vb_max), dim=1) # v: [B, 4 * H] | v = torch.cat((va_ave, va_max, vb_ave, vb_max), dim=1) # v: [B, 4 * H] | ||||
# v_mlp = F.tanh(self.mlp_layer1(v)) # v_mlp: [B, H] | |||||
# prediction = self.mlp_layer2(v_mlp) # prediction: [B, N] | |||||
prediction = F.tanh(self.output(v)) # prediction: [B, N] | prediction = F.tanh(self.output(v)) # prediction: [B, N] | ||||
return prediction | |||||
@staticmethod | |||||
def calc_bi_attention(in_x1, in_x2, x1_len, x2_len): | |||||
# in_x1: [batch_size, x1_seq_len, hidden_size] | |||||
# in_x2: [batch_size, x2_seq_len, hidden_size] | |||||
# x1_len: [batch_size, x1_seq_len] | |||||
# x2_len: [batch_size, x2_seq_len] | |||||
assert in_x1.size()[0] == in_x2.size()[0] | |||||
assert in_x1.size()[2] == in_x2.size()[2] | |||||
# The batch size and hidden size must be equal. | |||||
assert in_x1.size()[1] == x1_len.size()[1] and in_x2.size()[1] == x2_len.size()[1] | |||||
# The seq len in in_x and x_len must be equal. | |||||
assert in_x1.size()[0] == x1_len.size()[0] and x1_len.size()[0] == x2_len.size()[0] | |||||
batch_size = in_x1.size()[0] | |||||
x1_max_len = in_x1.size()[1] | |||||
x2_max_len = in_x2.size()[1] | |||||
in_x2_t = torch.transpose(in_x2, 1, 2) # [batch_size, hidden_size, x2_seq_len] | |||||
attention_matrix = torch.bmm(in_x1, in_x2_t) # [batch_size, x1_seq_len, x2_seq_len] | |||||
a_mask = x1_len.le(0.5).float() * -my_inf # [batch_size, x1_seq_len] | |||||
a_mask = a_mask.view(batch_size, x1_max_len, -1) | |||||
a_mask = a_mask.expand(-1, -1, x2_max_len) # [batch_size, x1_seq_len, x2_seq_len] | |||||
b_mask = x2_len.le(0.5).float() * -my_inf | |||||
b_mask = b_mask.view(batch_size, -1, x2_max_len) | |||||
b_mask = b_mask.expand(-1, x1_max_len, -1) # [batch_size, x1_seq_len, x2_seq_len] | |||||
attention_a = F.softmax(attention_matrix + a_mask, dim=2) # [batch_size, x1_seq_len, x2_seq_len] | |||||
attention_b = F.softmax(attention_matrix + b_mask, dim=1) # [batch_size, x1_seq_len, x2_seq_len] | |||||
out_x1 = torch.bmm(attention_a, in_x2) # [batch_size, x1_seq_len, hidden_size] | |||||
attention_b_t = torch.transpose(attention_b, 1, 2) | |||||
out_x2 = torch.bmm(attention_b_t, in_x1) # [batch_size, x2_seq_len, hidden_size] | |||||
return out_x1, out_x2 | |||||
return {'pred': prediction} | |||||
@staticmethod | |||||
def mean_pooling(tensor, mask, dim=0): | |||||
masks = mask.view(mask.size(0), mask.size(1), -1).float() | |||||
return torch.sum(tensor * masks, dim=dim) / torch.sum(masks, dim=1) | |||||
def predict(self, premise, hypothesis, premise_len, hypothesis_len): | |||||
return self.forward(premise, hypothesis, premise_len, hypothesis_len) | |||||
@staticmethod | |||||
def max_pooling(tensor, mask, dim=0): | |||||
masks = mask.view(mask.size(0), mask.size(1), -1) | |||||
masks = masks.expand(-1, -1, tensor.size(2)).float() | |||||
return torch.max(tensor + masks.le(0.5).float() * -my_inf, dim=dim) |
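A minimal construction sketch for the refactored model, using only the kwargs read in `__init__` above; the hyper-parameter values are illustrative, not recommended settings.

```python
# Minimal construction sketch for the refactored ESIM (values are illustrative only).
from fastNLP.models import ESIM

model = ESIM(vocab_size=10000, embed_dim=300, hidden_size=300,
             batch_first=True, dropout=0.5, num_classes=3, gpu=False)
# forward(premise, hypothesis, premise_len, hypothesis_len) now returns {'pred': ...},
# and predict() simply forwards to it.
```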
@@ -3,12 +3,11 @@ from . import decoder | |||||
from . import encoder | from . import encoder | ||||
from .aggregator import * | from .aggregator import * | ||||
from .decoder import * | from .decoder import * | ||||
from .encoder import * | |||||
from .dropout import TimestepDropout | from .dropout import TimestepDropout | ||||
from .encoder import * | |||||
__version__ = '0.0.0' | __version__ = '0.0.0' | ||||
__all__ = ['encoder', | __all__ = ['encoder', | ||||
'decoder', | 'decoder', | ||||
'aggregator', | |||||
'TimestepDropout'] | |||||
'aggregator'] |
@@ -1,7 +1,10 @@ | |||||
from .max_pool import MaxPool | from .max_pool import MaxPool | ||||
from .max_pool import MaxPoolWithMask | |||||
from .avg_pool import AvgPool | from .avg_pool import AvgPool | ||||
from .avg_pool import MeanPoolWithMask | |||||
from .kmax_pool import KMaxPool | from .kmax_pool import KMaxPool | ||||
from .attention import Attention | from .attention import Attention | ||||
from .attention import Bi_Attention | |||||
from .self_attention import SelfAttention | from .self_attention import SelfAttention | ||||
@@ -1,6 +1,7 @@ | |||||
import math | import math | ||||
import torch | import torch | ||||
import torch.nn.functional as F | |||||
from torch import nn | from torch import nn | ||||
from fastNLP.modules.utils import mask_softmax | from fastNLP.modules.utils import mask_softmax | ||||
@@ -62,3 +63,46 @@ class MultiHeadAtte(nn.Module): | |||||
heads.append(headi) | heads.append(headi) | ||||
output = torch.cat(heads, dim=2) | output = torch.cat(heads, dim=2) | ||||
return self.out_linear(output) | return self.out_linear(output) | ||||
class Bi_Attention(nn.Module): | |||||
def __init__(self): | |||||
super(Bi_Attention, self).__init__() | |||||
self.inf = 10e12 | |||||
def forward(self, in_x1, in_x2, x1_len, x2_len): | |||||
# in_x1: [batch_size, x1_seq_len, hidden_size] | |||||
# in_x2: [batch_size, x2_seq_len, hidden_size] | |||||
# x1_len: [batch_size, x1_seq_len] | |||||
# x2_len: [batch_size, x2_seq_len] | |||||
assert in_x1.size()[0] == in_x2.size()[0] | |||||
assert in_x1.size()[2] == in_x2.size()[2] | |||||
# The batch size and hidden size must be equal. | |||||
assert in_x1.size()[1] == x1_len.size()[1] and in_x2.size()[1] == x2_len.size()[1] | |||||
# The seq len in in_x and x_len must be equal. | |||||
assert in_x1.size()[0] == x1_len.size()[0] and x1_len.size()[0] == x2_len.size()[0] | |||||
batch_size = in_x1.size()[0] | |||||
x1_max_len = in_x1.size()[1] | |||||
x2_max_len = in_x2.size()[1] | |||||
in_x2_t = torch.transpose(in_x2, 1, 2) # [batch_size, hidden_size, x2_seq_len] | |||||
attention_matrix = torch.bmm(in_x1, in_x2_t) # [batch_size, x1_seq_len, x2_seq_len] | |||||
a_mask = x1_len.le(0.5).float() * -self.inf # [batch_size, x1_seq_len] | |||||
a_mask = a_mask.view(batch_size, x1_max_len, -1) | |||||
a_mask = a_mask.expand(-1, -1, x2_max_len) # [batch_size, x1_seq_len, x2_seq_len] | |||||
b_mask = x2_len.le(0.5).float() * -self.inf | |||||
b_mask = b_mask.view(batch_size, -1, x2_max_len) | |||||
b_mask = b_mask.expand(-1, x1_max_len, -1) # [batch_size, x1_seq_len, x2_seq_len] | |||||
attention_a = F.softmax(attention_matrix + a_mask, dim=2) # [batch_size, x1_seq_len, x2_seq_len] | |||||
attention_b = F.softmax(attention_matrix + b_mask, dim=1) # [batch_size, x1_seq_len, x2_seq_len] | |||||
out_x1 = torch.bmm(attention_a, in_x2) # [batch_size, x1_seq_len, hidden_size] | |||||
attention_b_t = torch.transpose(attention_b, 1, 2) | |||||
out_x2 = torch.bmm(attention_b_t, in_x1) # [batch_size, x2_seq_len, hidden_size] | |||||
return out_x1, out_x2 |
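The additive-mask-then-softmax pattern used above can be checked in isolation: padded positions receive a large negative offset, so their softmax weight collapses to (almost) zero. The toy trace below illustrates the masking trick itself rather than the full attention_a/attention_b computation.

```python
import torch
import torch.nn.functional as F

# Toy check of the additive masking trick used in Bi_Attention.
scores = torch.tensor([[[1.0, 2.0, 3.0]]])    # [batch, x1_seq_len=1, x2_seq_len=3]
x2_len = torch.tensor([[1.0, 1.0, 0.0]])      # last x2 position is padding
mask = x2_len.le(0.5).float() * -10e12        # same constant as self.inf above
print(F.softmax(scores + mask.view(1, 1, -1), dim=2))
# ~[[[0.269, 0.731, 0.000]]] -- the padded position contributes nothing
```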
@@ -1,6 +1,7 @@ | |||||
# python: 3.6 | # python: 3.6 | ||||
# encoding: utf-8 | # encoding: utf-8 | ||||
import torch | |||||
import torch.nn as nn | import torch.nn as nn | ||||
import torch.nn.functional as F | import torch.nn.functional as F | ||||
@@ -22,3 +23,14 @@ class AvgPool(nn.Module): | |||||
stride=self.stride, | stride=self.stride, | ||||
padding=self.padding) | padding=self.padding) | ||||
return x.squeeze(dim=-1) | return x.squeeze(dim=-1) | ||||
class MeanPoolWithMask(nn.Module): | |||||
def __init__(self): | |||||
super(MeanPoolWithMask, self).__init__() | |||||
self.inf = 10e12 | |||||
def forward(self, tensor, mask, dim=0): | |||||
masks = mask.view(mask.size(0), mask.size(1), -1).float() | |||||
return torch.sum(tensor * masks, dim=dim) / torch.sum(masks, dim=1) | |||||
@@ -25,3 +25,14 @@ class MaxPool(nn.Module): | |||||
padding=self.padding, | padding=self.padding, | ||||
dilation=self.dilation) | dilation=self.dilation) | ||||
return x.squeeze(dim=-1) # [N,C,1] -> [N,C] | return x.squeeze(dim=-1) # [N,C,1] -> [N,C] | ||||
class MaxPoolWithMask(nn.Module): | |||||
def __init__(self): | |||||
super(MaxPoolWithMask, self).__init__() | |||||
self.inf = 10e12 | |||||
def forward(self, tensor, mask, dim=0): | |||||
masks = mask.view(mask.size(0), mask.size(1), -1) | |||||
masks = masks.expand(-1, -1, tensor.size(2)).float() | |||||
return torch.max(tensor + masks.le(0.5).float() * -self.inf, dim=dim) |
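On a toy tensor, the two masked pooling layers behave as follows: padded timesteps are excluded from the mean and pushed to a very small value before the max, and `MaxPoolWithMask` returns the usual (values, indices) pair from `torch.max`.

```python
import torch
from fastNLP.modules.aggregator import MeanPoolWithMask, MaxPoolWithMask

# Toy check of the masked pooling layers above: the padded timestep is ignored.
x = torch.tensor([[[1.0, 4.0],
                   [3.0, 2.0],
                   [9.0, 9.0]]])      # [batch=1, seq_len=3, hidden=2], last step is padding
mask = torch.tensor([[1, 1, 0]])      # [batch, seq_len]

print(MeanPoolWithMask()(x, mask, dim=1))   # tensor([[2., 3.]])
values, _ = MaxPoolWithMask()(x, mask, dim=1)
print(values)                               # tensor([[3., 4.]])
```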
@@ -15,33 +15,154 @@ def seq_len_to_byte_mask(seq_lens): | |||||
# return value: ByteTensor, batch_size x max_len | # return value: ByteTensor, batch_size x max_len | ||||
batch_size = seq_lens.size(0) | batch_size = seq_lens.size(0) | ||||
max_len = seq_lens.max() | max_len = seq_lens.max() | ||||
broadcast_arange = torch.arange(max_len).view(1, -1).repeat(batch_size, 1) | |||||
mask = broadcast_arange.lt(seq_lens.float().view(-1, 1)) | |||||
broadcast_arange = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device) | |||||
mask = broadcast_arange.float().lt(seq_lens.float().view(-1, 1)) | |||||
return mask | return mask | ||||
def allowed_transitions(id2label, encoding_type='bio'): | |||||
""" | |||||
:param dict id2label: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | |||||
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()id2label。 | |||||
:param encoding_type: str, 支持"bio", "bmes"。 | |||||
    :return: List[Tuple(int, int)], 内部的Tuple是(from_tag_id, to_tag_id)。 返回的结果考虑了start和end,比如"BIO"中,B、O可以 | |||||
位于序列的开端,而I不行。所以返回的结果中会包含(start_idx, B_idx), (start_idx, O_idx), 但是不包含(start_idx, I_idx). | |||||
start_idx=len(id2label), end_idx=len(id2label)+1。 | |||||
""" | |||||
num_tags = len(id2label) | |||||
start_idx = num_tags | |||||
end_idx = num_tags + 1 | |||||
encoding_type = encoding_type.lower() | |||||
allowed_trans = [] | |||||
id_label_lst = list(id2label.items()) + [(start_idx, 'start'), (end_idx, 'end')] | |||||
def split_tag_label(from_label): | |||||
from_label = from_label.lower() | |||||
if from_label in ['start', 'end']: | |||||
from_tag = from_label | |||||
from_label = '' | |||||
else: | |||||
from_tag = from_label[:1] | |||||
from_label = from_label[2:] | |||||
return from_tag, from_label | |||||
for from_id, from_label in id_label_lst: | |||||
if from_label in ['<pad>', '<unk>']: | |||||
continue | |||||
from_tag, from_label = split_tag_label(from_label) | |||||
for to_id, to_label in id_label_lst: | |||||
if to_label in ['<pad>', '<unk>']: | |||||
continue | |||||
to_tag, to_label = split_tag_label(to_label) | |||||
if is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||||
allowed_trans.append((from_id, to_id)) | |||||
return allowed_trans | |||||
def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||||
""" | |||||
:param encoding_type: str, 支持"BIO", "BMES"。 | |||||
:param from_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | |||||
:param from_label: str, 比如"PER", "LOC"等label | |||||
:param to_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | |||||
:param to_label: str, 比如"PER", "LOC"等label | |||||
:return: bool,能否跃迁 | |||||
""" | |||||
if to_tag=='start' or from_tag=='end': | |||||
return False | |||||
encoding_type = encoding_type.lower() | |||||
if encoding_type == 'bio': | |||||
""" | |||||
第一行是to_tag, 第一列是from_tag. y任意条件下可转,-只有在label相同时可转,n不可转 | |||||
+-------+---+---+---+-------+-----+ | |||||
| | B | I | O | start | end | | |||||
+-------+---+---+---+-------+-----+ | |||||
| B | y | - | y | n | y | | |||||
+-------+---+---+---+-------+-----+ | |||||
| I | y | - | y | n | y | | |||||
+-------+---+---+---+-------+-----+ | |||||
| O | y | n | y | n | y | | |||||
+-------+---+---+---+-------+-----+ | |||||
| start | y | n | y | n | n | | |||||
+-------+---+---+---+-------+-----+ | |||||
| end | n | n | n | n | n | | |||||
+-------+---+---+---+-------+-----+ | |||||
""" | |||||
if from_tag == 'start': | |||||
return to_tag in ('b', 'o') | |||||
elif from_tag in ['b', 'i']: | |||||
return any([to_tag in ['end', 'b', 'o'], to_tag=='i' and from_label==to_label]) | |||||
elif from_tag == 'o': | |||||
return to_tag in ['end', 'b', 'o'] | |||||
else: | |||||
raise ValueError("Unexpect tag {}. Expect only 'B', 'I', 'O'.".format(from_tag)) | |||||
elif encoding_type == 'bmes': | |||||
""" | |||||
第一行是to_tag, 第一列是from_tag,y任意条件下可转,-只有在label相同时可转,n不可转 | |||||
+-------+---+---+---+---+-------+-----+ | |||||
| | B | M | E | S | start | end | | |||||
+-------+---+---+---+---+-------+-----+ | |||||
| B | n | - | - | n | n | n | | |||||
+-------+---+---+---+---+-------+-----+ | |||||
| M | n | - | - | n | n | n | | |||||
+-------+---+---+---+---+-------+-----+ | |||||
| E | y | n | n | y | n | y | | |||||
+-------+---+---+---+---+-------+-----+ | |||||
| S | y | n | n | y | n | y | | |||||
+-------+---+---+---+---+-------+-----+ | |||||
| start | y | n | n | y | n | n | | |||||
+-------+---+---+---+---+-------+-----+ | |||||
| end | n | n | n | n | n | n | | |||||
+-------+---+---+---+---+-------+-----+ | |||||
""" | |||||
if from_tag == 'start': | |||||
return to_tag in ['b', 's'] | |||||
elif from_tag == 'b': | |||||
return to_tag in ['m', 'e'] and from_label==to_label | |||||
elif from_tag == 'm': | |||||
return to_tag in ['m', 'e'] and from_label==to_label | |||||
elif from_tag in ['e', 's']: | |||||
return to_tag in ['b', 's', 'end'] | |||||
else: | |||||
raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S'.".format(from_tag)) | |||||
else: | |||||
raise ValueError("Only support BIO, BMES encoding type, got {}.".format(encoding_type)) | |||||
class ConditionalRandomField(nn.Module): | class ConditionalRandomField(nn.Module): | ||||
""" | """ | ||||
:param int tag_size: num of tags | |||||
:param bool include_start_end_trans: whether to include start/end tag | |||||
:param str initial_method: method for initialization | |||||
:param int num_tags: 标签的数量。 | |||||
:param bool include_start_end_trans: 是否包含起始tag | |||||
:param list allowed_transitions: ``List[Tuple[from_tag_id(int), to_tag_id(int)]]``. 允许的跃迁,可以通过allowed_transitions()得到。 | |||||
如果为None,则所有跃迁均为合法 | |||||
    :param str initial_method: 参数初始化方法。 | |||||
""" | """ | ||||
def __init__(self, tag_size, include_start_end_trans=False, initial_method=None): | |||||
def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None, initial_method=None): | |||||
super(ConditionalRandomField, self).__init__() | super(ConditionalRandomField, self).__init__() | ||||
self.include_start_end_trans = include_start_end_trans | self.include_start_end_trans = include_start_end_trans | ||||
self.tag_size = tag_size | |||||
self.num_tags = num_tags | |||||
# the meaning of entry in this matrix is (from_tag_id, to_tag_id) score | # the meaning of entry in this matrix is (from_tag_id, to_tag_id) score | ||||
self.trans_m = nn.Parameter(torch.randn(tag_size, tag_size)) | |||||
self.trans_m = nn.Parameter(torch.randn(num_tags, num_tags)) | |||||
if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
self.start_scores = nn.Parameter(torch.randn(tag_size)) | |||||
self.end_scores = nn.Parameter(torch.randn(tag_size)) | |||||
self.start_scores = nn.Parameter(torch.randn(num_tags)) | |||||
self.end_scores = nn.Parameter(torch.randn(num_tags)) | |||||
if allowed_transitions is None: | |||||
constrain = torch.zeros(num_tags + 2, num_tags + 2) | |||||
else: | |||||
constrain = torch.ones(num_tags + 2, num_tags + 2) * -1000 | |||||
for from_tag_id, to_tag_id in allowed_transitions: | |||||
constrain[from_tag_id, to_tag_id] = 0 | |||||
self._constrain = nn.Parameter(constrain, requires_grad=False) | |||||
# self.reset_parameter() | # self.reset_parameter() | ||||
initial_parameter(self, initial_method) | initial_parameter(self, initial_method) | ||||
def reset_parameter(self): | def reset_parameter(self): | ||||
nn.init.xavier_normal_(self.trans_m) | nn.init.xavier_normal_(self.trans_m) | ||||
if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
@@ -52,9 +173,9 @@ class ConditionalRandomField(nn.Module): | |||||
"""Computes the (batch_size,) denominator term for the log-likelihood, which is the | """Computes the (batch_size,) denominator term for the log-likelihood, which is the | ||||
sum of the likelihoods across all possible state sequences. | sum of the likelihoods across all possible state sequences. | ||||
:param FloatTensor logits: [max_len, batch_size, tag_size] | |||||
:param ByteTensor mask: [max_len, batch_size] | |||||
:return: FloatTensor, [batch_size,] | |||||
:param logits:FloatTensor, max_len x batch_size x num_tags | |||||
:param mask:ByteTensor, max_len x batch_size | |||||
:return:FloatTensor, batch_size | |||||
""" | """ | ||||
seq_len, batch_size, n_tags = logits.size() | seq_len, batch_size, n_tags = logits.size() | ||||
alpha = logits[0] | alpha = logits[0] | ||||
@@ -73,9 +194,9 @@ class ConditionalRandomField(nn.Module): | |||||
return log_sum_exp(alpha, 1) | return log_sum_exp(alpha, 1) | ||||
def _glod_score(self, logits, tags, mask): | def _glod_score(self, logits, tags, mask): | ||||
"""Compute the score for the gold path. | |||||
:param logits: FloatTensor, max_len x batch_size x tag_size | |||||
""" | |||||
Compute the score for the gold path. | |||||
:param logits: FloatTensor, max_len x batch_size x num_tags | |||||
:param tags: LongTensor, max_len x batch_size | :param tags: LongTensor, max_len x batch_size | ||||
:param mask: ByteTensor, max_len x batch_size | :param mask: ByteTensor, max_len x batch_size | ||||
:return:FloatTensor, batch_size | :return:FloatTensor, batch_size | ||||
@@ -100,12 +221,12 @@ class ConditionalRandomField(nn.Module): | |||||
return score | return score | ||||
def forward(self, feats, tags, mask): | def forward(self, feats, tags, mask): | ||||
"""Calculate the neg log likelihood | |||||
:param FloatTensor feats: [batch_size, max_len, tag_size] | |||||
:param LongTensor tags: [batch_size, max_len] | |||||
:param ByteTensor mask: [batch_size, max_len] | |||||
:return: FloatTensor, [batch_size,] | |||||
""" | |||||
Calculate the negative log-likelihood | |||||
:param feats:FloatTensor, batch_size x max_len x num_tags | |||||
:param tags:LongTensor, batch_size x max_len | |||||
:param mask:ByteTensor batch_size x max_len | |||||
:return:FloatTensor, batch_size | |||||
""" | """ | ||||
feats = feats.transpose(0, 1) | feats = feats.transpose(0, 1) | ||||
tags = tags.transpose(0, 1).long() | tags = tags.transpose(0, 1).long() | ||||
@@ -115,13 +236,20 @@ class ConditionalRandomField(nn.Module): | |||||
return all_path_score - gold_path_score | return all_path_score - gold_path_score | ||||
def viterbi_decode(self, data, mask, get_score=False): | |||||
def viterbi_decode(self, data, mask, get_score=False, unpad=False): | |||||
"""Given a feats matrix, return best decode path and best score. | """Given a feats matrix, return best decode path and best score. | ||||
:param FloatTensor data: [batch_size, max_len, tag_size] | |||||
:param ByteTensor mask: [batch_size, max_len] | |||||
:param bool get_score: whether to output the decode score. | |||||
:return: scores, paths | |||||
:param data:FloatTensor, batch_size x max_len x num_tags | |||||
:param mask:ByteTensor batch_size x max_len | |||||
:param get_score: bool, whether to output the decode score. | |||||
:param unpad: bool, whether to unpad the result. | |||||
If False, a batch_size x max_len tensor is returned. | |||||
If True, a List[List[int]] is returned, where each List[int] contains the labels of one sequence and | |||||
has already been unpadded, i.e. its length equals the valid length of that sample. | |||||
:return: if get_score is False, the return value follows the unpad setting described above. | |||||
If get_score is True, (paths, List[float]) is returned: the first element is still the decoded paths (shaped according to unpad), | |||||
and the second is the decoding score of each sequence. | |||||
""" | """ | ||||
batch_size, seq_len, n_tags = data.size() | batch_size, seq_len, n_tags = data.size() | ||||
data = data.transpose(0, 1).data # L, B, H | data = data.transpose(0, 1).data # L, B, H | ||||
@@ -130,19 +258,23 @@ class ConditionalRandomField(nn.Module): | |||||
# dp | # dp | ||||
vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | ||||
vscore = data[0] | vscore = data[0] | ||||
transitions = self._constrain.data.clone() | |||||
transitions[:n_tags, :n_tags] += self.trans_m.data | |||||
if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
vscore += self.start_scores.view(1, -1) | |||||
transitions[n_tags, :n_tags] += self.start_scores.data | |||||
transitions[:n_tags, n_tags+1] += self.end_scores.data | |||||
vscore += transitions[n_tags, :n_tags] | |||||
trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data | |||||
for i in range(1, seq_len): | for i in range(1, seq_len): | ||||
prev_score = vscore.view(batch_size, n_tags, 1) | prev_score = vscore.view(batch_size, n_tags, 1) | ||||
cur_score = data[i].view(batch_size, 1, n_tags) | cur_score = data[i].view(batch_size, 1, n_tags) | ||||
trans_score = self.trans_m.view(1, n_tags, n_tags).data | |||||
score = prev_score + trans_score + cur_score | score = prev_score + trans_score + cur_score | ||||
best_score, best_dst = score.max(1) | best_score, best_dst = score.max(1) | ||||
vpath[i] = best_dst | vpath[i] = best_dst | ||||
vscore = best_score * mask[i].view(batch_size, 1) + vscore * (1 - mask[i]).view(batch_size, 1) | vscore = best_score * mask[i].view(batch_size, 1) + vscore * (1 - mask[i]).view(batch_size, 1) | ||||
if self.include_start_end_trans: | |||||
vscore += self.end_scores.view(1, -1) | |||||
vscore += transitions[:n_tags, n_tags+1].view(1, -1) | |||||
# backtrace | # backtrace | ||||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=data.device) | batch_idx = torch.arange(batch_size, dtype=torch.long, device=data.device) | ||||
@@ -157,7 +289,13 @@ class ConditionalRandomField(nn.Module): | |||||
for i in range(seq_len - 1): | for i in range(seq_len - 1): | ||||
last_tags = vpath[idxes[i], batch_idx, last_tags] | last_tags = vpath[idxes[i], batch_idx, last_tags] | ||||
ans[idxes[i+1], batch_idx] = last_tags | ans[idxes[i+1], batch_idx] = last_tags | ||||
ans = ans.transpose(0, 1) | |||||
if unpad: | |||||
paths = [] | |||||
for idx, seq_len in enumerate(lens): | |||||
paths.append(ans[idx, :seq_len+1].tolist()) | |||||
else: | |||||
paths = ans | |||||
if get_score: | if get_score: | ||||
return ans_score, ans.transpose(0, 1) | |||||
return ans.transpose(0, 1) | |||||
return paths, ans_score.tolist() | |||||
return paths |
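Putting the pieces above together, a minimal, hedged sketch of the new interface (shapes follow the docstrings; the emission scores are random placeholders, so the decoded paths are meaningless):

```python
import torch
from fastNLP.modules.decoder.CRF import ConditionalRandomField

batch_size, max_len, num_tags = 2, 5, 4
crf = ConditionalRandomField(num_tags=num_tags)

feats = torch.randn(batch_size, max_len, num_tags)        # emission scores
tags = torch.randint(0, num_tags, (batch_size, max_len))  # gold tag ids
mask = torch.ones(batch_size, max_len).byte()             # 1 = real token, 0 = padding

loss = crf(feats, tags, mask).mean()                 # per-sample negative log-likelihood, averaged
paths = crf.viterbi_decode(feats, mask, unpad=True)  # List[List[int]], one unpadded path per sample
```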
@@ -35,7 +35,7 @@ class MLP(nn.Module): | |||||
} | } | ||||
if activation in actives: | if activation in actives: | ||||
self.hidden_active = actives[activation] | self.hidden_active = actives[activation] | ||||
elif isinstance(activation, callable): | |||||
elif callable(activation): | |||||
self.hidden_active = activation | self.hidden_active = activation | ||||
else: | else: | ||||
raise ValueError("should set activation correctly: {}".format(activation)) | raise ValueError("should set activation correctly: {}".format(activation)) | ||||
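The fix above means a user-supplied callable is now accepted as the activation. A hedged sketch follows; the layer-size list mirrors the `MLP(size_layer)` call in the CWS model later in this diff, and the `activation` keyword name is assumed from the parameter checked above.

```python
import torch
from fastNLP.modules.decoder.MLP import MLP  # import path taken from the decoder __init__ below

mlp_relu = MLP([100, 200, 4], activation='relu')        # built-in activation looked up by name
mlp_custom = MLP([100, 200, 4], activation=torch.tanh)  # any callable now works after this fix

x = torch.randn(8, 100)
print(mlp_custom(x).size())   # expected: torch.Size([8, 4])
```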
@@ -1,4 +1,2 @@ | |||||
from .CRF import ConditionalRandomField | from .CRF import ConditionalRandomField | ||||
from .MLP import MLP | from .MLP import MLP | ||||
__all__ = ["ConditionalRandomField", "MLP"] |
@@ -1,7 +1,6 @@ | |||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
from torch.nn.utils.rnn import PackedSequence | |||||
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence | |||||
from fastNLP.modules.utils import initial_parameter | from fastNLP.modules.utils import initial_parameter | ||||
try: | try: | ||||
@@ -25,30 +24,63 @@ class VarRnnCellWrapper(nn.Module): | |||||
self.input_p = input_p | self.input_p = input_p | ||||
self.hidden_p = hidden_p | self.hidden_p = hidden_p | ||||
def forward(self, input, hidden, mask_x=None, mask_h=None): | |||||
def forward(self, input_x, hidden, mask_x, mask_h, is_reversed=False): | |||||
""" | """ | ||||
:param input: [seq_len, batch_size, input_size] | |||||
:param PackedSequence input_x: [seq_len, batch_size, input_size] | |||||
:param hidden: for LSTM, tuple of (h_0, c_0), [batch_size, hidden_size] | :param hidden: for LSTM, tuple of (h_0, c_0), [batch_size, hidden_size] | ||||
for other RNN, h_0, [batch_size, hidden_size] | for other RNN, h_0, [batch_size, hidden_size] | ||||
:param mask_x: [batch_size, input_size] dropout mask for input | :param mask_x: [batch_size, input_size] dropout mask for input | ||||
:param mask_h: [batch_size, hidden_size] dropout mask for hidden | :param mask_h: [batch_size, hidden_size] dropout mask for hidden | ||||
:return: (output, hidden) | |||||
**output**: [seq_len, bacth_size, hidden_size]. | |||||
**hidden**: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size]; For other RNN, h_n, [batch_size, hidden_size]. | |||||
:return PackedSequence output: [seq_len, batch_size, hidden_size] | |||||
hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size] | |||||
for other RNN, h_n, [batch_size, hidden_size] | |||||
""" | """ | ||||
def get_hi(hi, h0, size): | |||||
h0_size = size - hi.size(0) | |||||
if h0_size > 0: | |||||
return torch.cat([hi, h0[:h0_size]], dim=0) | |||||
return hi[:size] | |||||
is_lstm = isinstance(hidden, tuple) | is_lstm = isinstance(hidden, tuple) | ||||
input = input * mask_x.unsqueeze(0) if mask_x is not None else input | |||||
output_list = [] | |||||
for x in input: | |||||
input, batch_sizes = input_x | |||||
output = [] | |||||
cell = self.cell | |||||
if is_reversed: | |||||
batch_iter = flip(batch_sizes, [0]) | |||||
idx = input.size(0) | |||||
else: | |||||
batch_iter = batch_sizes | |||||
idx = 0 | |||||
if is_lstm: | |||||
hn = (hidden[0].clone(), hidden[1].clone()) | |||||
else: | |||||
hn = hidden.clone() | |||||
hi = hidden | |||||
for size in batch_iter: | |||||
if is_reversed: | |||||
input_i = input[idx-size: idx] * mask_x[:size] | |||||
idx -= size | |||||
else: | |||||
input_i = input[idx: idx+size] * mask_x[:size] | |||||
idx += size | |||||
mask_hi = mask_h[:size] | |||||
if is_lstm: | if is_lstm: | ||||
hx, cx = hidden | |||||
hidden = (hx * mask_h, cx) if mask_h is not None else (hx, cx) | |||||
hx, cx = hi | |||||
hi = (get_hi(hx, hidden[0], size) * mask_hi, get_hi(cx, hidden[1], size)) | |||||
hi = cell(input_i, hi) | |||||
hn[0][:size] = hi[0] | |||||
hn[1][:size] = hi[1] | |||||
output.append(hi[0]) | |||||
else: | else: | ||||
hidden *= mask_h if mask_h is not None else hidden | |||||
hidden = self.cell(x, hidden) | |||||
output_list.append(hidden[0] if is_lstm else hidden) | |||||
output = torch.stack(output_list, dim=0) | |||||
return output, hidden | |||||
hi = get_hi(hi, hidden, size) * mask_hi | |||||
hi = cell(input_i, hi) | |||||
hn[:size] = hi | |||||
output.append(hi) | |||||
if is_reversed: | |||||
output = list(reversed(output)) | |||||
output = torch.cat(output, dim=0) | |||||
return PackedSequence(output, batch_sizes), hn | |||||
class VarRNNBase(nn.Module): | class VarRNNBase(nn.Module): | ||||
@@ -78,60 +110,67 @@ class VarRNNBase(nn.Module): | |||||
cell = Cell(input_size, self.hidden_size, bias) | cell = Cell(input_size, self.hidden_size, bias) | ||||
self._all_cells.append(VarRnnCellWrapper(cell, self.hidden_size, input_dropout, hidden_dropout)) | self._all_cells.append(VarRnnCellWrapper(cell, self.hidden_size, input_dropout, hidden_dropout)) | ||||
initial_parameter(self) | initial_parameter(self) | ||||
self.is_lstm = (self.mode == "LSTM") | |||||
def _forward_one(self, n_layer, n_direction, input, hx, mask_x, mask_h): | |||||
is_lstm = self.is_lstm | |||||
idx = self.num_directions * n_layer + n_direction | |||||
cell = self._all_cells[idx] | |||||
hi = (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx] | |||||
output_x, hidden_x = cell(input, hi, mask_x, mask_h, is_reversed=(n_direction == 1)) | |||||
return output_x, hidden_x | |||||
def forward(self, input, hx=None): | def forward(self, input, hx=None): | ||||
is_lstm = self.is_lstm | |||||
is_packed = isinstance(input, PackedSequence) | is_packed = isinstance(input, PackedSequence) | ||||
is_lstm = (self.mode == "LSTM") | |||||
if is_packed: | |||||
input, batch_sizes = input | |||||
max_batch_size = int(batch_sizes[0]) | |||||
else: | |||||
batch_sizes = None | |||||
if not is_packed: | |||||
seq_len = input.size(1) if self.batch_first else input.size(0) | |||||
max_batch_size = input.size(0) if self.batch_first else input.size(1) | max_batch_size = input.size(0) if self.batch_first else input.size(1) | ||||
seq_lens = torch.LongTensor([seq_len for _ in range(max_batch_size)]) | |||||
input, batch_sizes = pack_padded_sequence(input, seq_lens, batch_first=self.batch_first) | |||||
else: | |||||
max_batch_size = int(input.batch_sizes[0]) | |||||
input, batch_sizes = input | |||||
if hx is None: | if hx is None: | ||||
hx = input.new_zeros(self.num_layers * self.num_directions, | hx = input.new_zeros(self.num_layers * self.num_directions, | ||||
max_batch_size, self.hidden_size, | |||||
requires_grad=False) | |||||
max_batch_size, self.hidden_size, requires_grad=True) | |||||
if is_lstm: | if is_lstm: | ||||
hx = (hx, hx) | |||||
if self.batch_first: | |||||
input = input.transpose(0, 1) | |||||
batch_size = input.shape[1] | |||||
hx = (hx, hx.new_zeros(hx.size(), requires_grad=True)) | |||||
mask_x = input.new_ones((batch_size, self.input_size)) | |||||
mask_out = input.new_ones((batch_size, self.hidden_size * self.num_directions)) | |||||
mask_h_ones = input.new_ones((batch_size, self.hidden_size)) | |||||
mask_x = input.new_ones((max_batch_size, self.input_size)) | |||||
mask_out = input.new_ones((max_batch_size, self.hidden_size * self.num_directions)) | |||||
mask_h_ones = input.new_ones((max_batch_size, self.hidden_size)) | |||||
nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True) | nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True) | ||||
nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True) | nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True) | ||||
hidden_list = [] | |||||
hidden = input.new_zeros((self.num_layers*self.num_directions, max_batch_size, self.hidden_size)) | |||||
if is_lstm: | |||||
cellstate = input.new_zeros((self.num_layers*self.num_directions, max_batch_size, self.hidden_size)) | |||||
for layer in range(self.num_layers): | for layer in range(self.num_layers): | ||||
output_list = [] | output_list = [] | ||||
input_seq = PackedSequence(input, batch_sizes) | |||||
mask_h = nn.functional.dropout(mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False) | mask_h = nn.functional.dropout(mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False) | ||||
for direction in range(self.num_directions): | for direction in range(self.num_directions): | ||||
input_x = input if direction == 0 else flip(input, [0]) | |||||
output_x, hidden_x = self._forward_one(layer, direction, input_seq, hx, | |||||
mask_x if layer == 0 else mask_out, mask_h) | |||||
output_list.append(output_x.data) | |||||
idx = self.num_directions * layer + direction | idx = self.num_directions * layer + direction | ||||
cell = self._all_cells[idx] | |||||
hi = (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx] | |||||
mask_xi = mask_x if layer == 0 else mask_out | |||||
output_x, hidden_x = cell(input_x, hi, mask_xi, mask_h) | |||||
output_list.append(output_x if direction == 0 else flip(output_x, [0])) | |||||
hidden_list.append(hidden_x) | |||||
if is_lstm: | |||||
hidden[idx] = hidden_x[0] | |||||
cellstate[idx] = hidden_x[1] | |||||
else: | |||||
hidden[idx] = hidden_x | |||||
input = torch.cat(output_list, dim=-1) | input = torch.cat(output_list, dim=-1) | ||||
output = input.transpose(0, 1) if self.batch_first else input | |||||
if is_lstm: | if is_lstm: | ||||
h_list, c_list = zip(*hidden_list) | |||||
hn = torch.stack(h_list, dim=0) | |||||
cn = torch.stack(c_list, dim=0) | |||||
hidden = (hn, cn) | |||||
else: | |||||
hidden = torch.stack(hidden_list, dim=0) | |||||
hidden = (hidden, cellstate) | |||||
if is_packed: | if is_packed: | ||||
output = PackedSequence(output, batch_sizes) | |||||
output = PackedSequence(input, batch_sizes) | |||||
else: | |||||
input = PackedSequence(input, batch_sizes) | |||||
output, _ = pad_packed_sequence(input, batch_first=self.batch_first) | |||||
return output, hidden | return output, hidden | ||||
@@ -158,3 +197,36 @@ class VarGRU(VarRNNBase): | |||||
def __init__(self, *args, **kwargs): | def __init__(self, *args, **kwargs): | ||||
super(VarGRU, self).__init__(mode="GRU", Cell=nn.GRUCell, *args, **kwargs) | super(VarGRU, self).__init__(mode="GRU", Cell=nn.GRUCell, *args, **kwargs) | ||||
# if __name__ == '__main__': | |||||
# x = torch.Tensor([[1,2,3], [4,5,0], [6,0,0]])[:,:,None] * 0.1 | |||||
# mask = (x != 0).float().view(3, -1) | |||||
# seq_lens = torch.LongTensor([3,2,1]) | |||||
# y = torch.Tensor([[0,1,1], [1,1,0], [0,0,0]]) | |||||
# # rev = _reverse_packed_sequence(pack) | |||||
# # # print(rev) | |||||
# lstm = VarLSTM(input_size=1, num_layers=2, hidden_size=2, | |||||
# batch_first=True, bidirectional=True, | |||||
# input_dropout=0.0, hidden_dropout=0.0,) | |||||
# # lstm = nn.LSTM(input_size=1, num_layers=2, hidden_size=2, | |||||
# # batch_first=True, bidirectional=True,) | |||||
# loss = nn.BCELoss() | |||||
# m = nn.Sigmoid() | |||||
# optim = torch.optim.SGD(lstm.parameters(), lr=1e-3) | |||||
# for i in range(2000): | |||||
# optim.zero_grad() | |||||
# pack = pack_padded_sequence(x, seq_lens, batch_first=True) | |||||
# out, hidden = lstm(pack) | |||||
# out, lens = pad_packed_sequence(out, batch_first=True) | |||||
# # print(lens) | |||||
# # print(out) | |||||
# # print(hidden[0]) | |||||
# # print(hidden[0].size()) | |||||
# # print(hidden[1]) | |||||
# out = out.sum(-1) | |||||
# out = m(out) * mask | |||||
# l = loss(out, y) | |||||
# l.backward() | |||||
# optim.step() | |||||
# if i % 50 == 0: | |||||
# print(out) |
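For readers skimming the commented-out smoke test above, here is a hedged, minimal usage sketch of the reworked module: the constructor arguments are taken from that commented code, the import path is an assumption, and a plain padded batch is passed so that the new forward() packs and unpads internally.

```python
import torch
from fastNLP.modules.encoder.variational_rnn import VarLSTM  # module path assumed

lstm = VarLSTM(input_size=1, hidden_size=2, num_layers=2,
               batch_first=True, bidirectional=True,
               input_dropout=0.3, hidden_dropout=0.3)

x = torch.Tensor([[1, 2, 3], [4, 5, 0], [6, 0, 0]]).unsqueeze(-1) * 0.1  # [batch=3, seq=3, 1]
out, (h_n, c_n) = lstm(x)   # out: [3, 3, 2 * hidden_size] after the internal pack/pad round-trip
print(out.size(), h_n.size())
```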
@@ -1,13 +1,8 @@ | |||||
[train] | [train] | ||||
epochs = -1 | |||||
batch_size = 16 | |||||
pickle_path = "./save/" | |||||
validate = true | |||||
save_best_dev = true | |||||
eval_sort_key = "UAS" | |||||
n_epochs = 40 | |||||
batch_size = 32 | |||||
use_cuda = true | use_cuda = true | ||||
model_saved_path = "./save/" | |||||
print_every_step = 20 | |||||
validate_every = 500 | |||||
use_golden_train=true | use_golden_train=true | ||||
[test] | [test] | ||||
@@ -32,9 +27,9 @@ arc_mlp_size = 500 | |||||
label_mlp_size = 100 | label_mlp_size = 100 | ||||
num_label = -1 | num_label = -1 | ||||
dropout = 0.33 | dropout = 0.33 | ||||
use_var_lstm=false | |||||
use_var_lstm=true | |||||
use_greedy_infer=false | use_greedy_infer=false | ||||
[optim] | [optim] | ||||
lr = 2e-3 | |||||
weight_decay = 5e-5 | |||||
lr = 3e-4 | |||||
;weight_decay = 3e-5 |
@@ -3,24 +3,26 @@ import sys | |||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | ||||
import fastNLP | |||||
import torch | import torch | ||||
import re | |||||
from fastNLP.core.trainer import Trainer | from fastNLP.core.trainer import Trainer | ||||
from fastNLP.core.metrics import Evaluator | |||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.api.pipeline import Pipeline | |||||
from fastNLP.models.biaffine_parser import BiaffineParser, ParserMetric, ParserLoss | |||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.field import TextField, SeqLabelField | |||||
from fastNLP.core.tester import Tester | from fastNLP.core.tester import Tester | ||||
from fastNLP.io.config_io import ConfigLoader, ConfigSection | from fastNLP.io.config_io import ConfigLoader, ConfigSection | ||||
from fastNLP.io.model_io import ModelLoader, ModelSaver | |||||
from fastNLP.io.model_io import ModelLoader | |||||
from fastNLP.io.embed_loader import EmbedLoader | from fastNLP.io.embed_loader import EmbedLoader | ||||
from fastNLP.models.biaffine_parser import BiaffineParser | |||||
from fastNLP.io.model_io import ModelSaver | |||||
from reproduction.Biaffine_parser.util import ConllxDataLoader, MyDataloader | |||||
from fastNLP.api.processor import * | |||||
BOS = '<BOS>' | BOS = '<BOS>' | ||||
EOS = '<EOS>' | EOS = '<EOS>' | ||||
UNK = '<OOV>' | |||||
UNK = '<UNK>' | |||||
NUM = '<NUM>' | NUM = '<NUM>' | ||||
ENG = '<ENG>' | ENG = '<ENG>' | ||||
@@ -28,85 +30,25 @@ ENG = '<ENG>' | |||||
if len(os.path.dirname(__file__)) != 0: | if len(os.path.dirname(__file__)) != 0: | ||||
os.chdir(os.path.dirname(__file__)) | os.chdir(os.path.dirname(__file__)) | ||||
class ConlluDataLoader(object): | |||||
def load(self, path): | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
ds = DataSet(name='conll') | |||||
for sample in datalist: | |||||
# print(sample) | |||||
res = self.get_one(sample) | |||||
ds.append(Instance(word_seq=TextField(res[0], is_target=False), | |||||
pos_seq=TextField(res[1], is_target=False), | |||||
head_indices=SeqLabelField(res[2], is_target=True), | |||||
head_labels=TextField(res[3], is_target=True))) | |||||
return ds | |||||
def get_one(self, sample): | |||||
text = [] | |||||
pos_tags = [] | |||||
heads = [] | |||||
head_tags = [] | |||||
for w in sample: | |||||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | |||||
if t3 == '_': | |||||
continue | |||||
text.append(t1) | |||||
pos_tags.append(t2) | |||||
heads.append(int(t3)) | |||||
head_tags.append(t4) | |||||
return (text, pos_tags, heads, head_tags) | |||||
class CTBDataLoader(object): | |||||
def load(self, data_path): | |||||
with open(data_path, "r", encoding="utf-8") as f: | |||||
lines = f.readlines() | |||||
data = self.parse(lines) | |||||
return self.convert(data) | |||||
def parse(self, lines): | |||||
""" | |||||
[ | |||||
[word], [pos], [head_index], [head_tag] | |||||
] | |||||
""" | |||||
sample = [] | |||||
data = [] | |||||
for i, line in enumerate(lines): | |||||
line = line.strip() | |||||
if len(line) == 0 or i+1 == len(lines): | |||||
data.append(list(map(list, zip(*sample)))) | |||||
sample = [] | |||||
else: | |||||
sample.append(line.split()) | |||||
return data | |||||
def convert(self, data): | |||||
dataset = DataSet() | |||||
for sample in data: | |||||
word_seq = [BOS] + sample[0] + [EOS] | |||||
pos_seq = [BOS] + sample[1] + [EOS] | |||||
heads = [0] + list(map(int, sample[2])) + [0] | |||||
head_tags = [BOS] + sample[3] + [EOS] | |||||
dataset.append(Instance(word_seq=TextField(word_seq, is_target=False), | |||||
pos_seq=TextField(pos_seq, is_target=False), | |||||
gold_heads=SeqLabelField(heads, is_target=False), | |||||
head_indices=SeqLabelField(heads, is_target=True), | |||||
head_labels=TextField(head_tags, is_target=True))) | |||||
return dataset | |||||
def convert(data): | |||||
dataset = DataSet() | |||||
for sample in data: | |||||
word_seq = [BOS] + sample[0] | |||||
pos_seq = [BOS] + sample[1] | |||||
heads = [0] + list(map(int, sample[2])) | |||||
head_tags = [BOS] + sample[3] | |||||
dataset.append(Instance(words=word_seq, | |||||
pos=pos_seq, | |||||
gold_heads=heads, | |||||
arc_true=heads, | |||||
tags=head_tags)) | |||||
return dataset | |||||
def load(path): | |||||
data = ConllxDataLoader().load(path) | |||||
return convert(data) | |||||
# datadir = "/mnt/c/Me/Dev/release-2.2-st-train-dev-data/ud-treebanks-v2.2/UD_English-EWT" | # datadir = "/mnt/c/Me/Dev/release-2.2-st-train-dev-data/ud-treebanks-v2.2/UD_English-EWT" | ||||
# datadir = "/home/yfshao/UD_English-EWT" | # datadir = "/home/yfshao/UD_English-EWT" | ||||
@@ -115,26 +57,29 @@ class CTBDataLoader(object): | |||||
# emb_file_name = '/home/yfshao/glove.6B.100d.txt' | # emb_file_name = '/home/yfshao/glove.6B.100d.txt' | ||||
# loader = ConlluDataLoader() | # loader = ConlluDataLoader() | ||||
datadir = '/home/yfshao/workdir/parser-data/' | |||||
train_data_name = "train_ctb5.txt" | |||||
dev_data_name = "dev_ctb5.txt" | |||||
test_data_name = "test_ctb5.txt" | |||||
emb_file_name = "/home/yfshao/workdir/parser-data/word_OOVthr_30_100v.txt" | |||||
# emb_file_name = "/home/yfshao/workdir/word_vector/cc.zh.300.vec" | |||||
loader = CTBDataLoader() | |||||
# datadir = '/home/yfshao/workdir/parser-data/' | |||||
# train_data_name = "train_ctb5.txt" | |||||
# dev_data_name = "dev_ctb5.txt" | |||||
# test_data_name = "test_ctb5.txt" | |||||
datadir = "/home/yfshao/workdir/ctb7.0/" | |||||
train_data_name = "train.conllx" | |||||
dev_data_name = "dev.conllx" | |||||
test_data_name = "test.conllx" | |||||
# emb_file_name = "/home/yfshao/workdir/parser-data/word_OOVthr_30_100v.txt" | |||||
emb_file_name = "/home/yfshao/workdir/word_vector/cc.zh.300.vec" | |||||
cfgfile = './cfg.cfg' | cfgfile = './cfg.cfg' | ||||
processed_datadir = './save' | processed_datadir = './save' | ||||
# Config Loader | # Config Loader | ||||
train_args = ConfigSection() | train_args = ConfigSection() | ||||
test_args = ConfigSection() | |||||
model_args = ConfigSection() | model_args = ConfigSection() | ||||
optim_args = ConfigSection() | optim_args = ConfigSection() | ||||
ConfigLoader.load_config(cfgfile, {"train": train_args, "test": test_args, "model": model_args, "optim": optim_args}) | |||||
ConfigLoader.load_config(cfgfile, {"train": train_args, "model": model_args, "optim": optim_args}) | |||||
print('train Args:', train_args.data) | print('train Args:', train_args.data) | ||||
print('test Args:', test_args.data) | |||||
print('optim Args:', optim_args.data) | |||||
print('model Args:', model_args.data) | |||||
print('optim_args', optim_args.data) | |||||
# Pickle Loader | # Pickle Loader | ||||
@@ -159,84 +104,36 @@ def load_data(dirpath): | |||||
return datas | return datas | ||||
def P2(data, field, length): | def P2(data, field, length): | ||||
ds = [ins for ins in data if ins[field].get_length() >= length] | |||||
ds = [ins for ins in data if len(ins[field]) >= length] | |||||
data.clear() | data.clear() | ||||
data.extend(ds) | data.extend(ds) | ||||
return ds | return ds | ||||
def P1(data, field): | |||||
def reeng(w): | |||||
return w if w == BOS or w == EOS or re.search(r'^([a-zA-Z]+[\.\-]*)+$', w) is None else ENG | |||||
def renum(w): | |||||
return w if re.search(r'^[0-9]+\.?[0-9]*$', w) is None else NUM | |||||
for ins in data: | |||||
ori = ins[field].contents() | |||||
s = list(map(renum, map(reeng, ori))) | |||||
if s != ori: | |||||
# print(ori) | |||||
# print(s) | |||||
# print() | |||||
ins[field] = ins[field].new(s) | |||||
return data | |||||
class ParserEvaluator(Evaluator): | |||||
def __init__(self, ignore_label): | |||||
super(ParserEvaluator, self).__init__() | |||||
self.ignore = ignore_label | |||||
def __call__(self, predict_list, truth_list): | |||||
head_all, label_all, total_all = 0, 0, 0 | |||||
for pred, truth in zip(predict_list, truth_list): | |||||
head, label, total = self.evaluate(**pred, **truth) | |||||
head_all += head | |||||
label_all += label | |||||
total_all += total | |||||
return {'UAS': head_all*1.0 / total_all, 'LAS': label_all*1.0 / total_all} | |||||
def evaluate(self, head_pred, label_pred, head_indices, head_labels, seq_mask, **_): | |||||
""" | |||||
Evaluate the performance of prediction. | |||||
:return : performance results. | |||||
head_pred_corrct: number of correct predicted heads. | |||||
label_pred_correct: number of correct predicted labels. | |||||
total_tokens: number of predicted tokens | |||||
""" | |||||
seq_mask *= (head_labels != self.ignore).long() | |||||
head_pred_correct = (head_pred == head_indices).long() * seq_mask | |||||
_, label_preds = torch.max(label_pred, dim=2) | |||||
label_pred_correct = (label_preds == head_labels).long() * head_pred_correct | |||||
return head_pred_correct.sum().item(), label_pred_correct.sum().item(), seq_mask.sum().item() | |||||
try: | |||||
data_dict = load_data(processed_datadir) | |||||
word_v = data_dict['word_v'] | |||||
pos_v = data_dict['pos_v'] | |||||
tag_v = data_dict['tag_v'] | |||||
train_data = data_dict['train_data'] | |||||
dev_data = data_dict['dev_data'] | |||||
test_data = data_dict['test_data'] | |||||
print('use saved pickles') | |||||
except Exception as _: | |||||
print('load raw data and preprocess') | |||||
# use pretrain embedding | |||||
word_v = Vocabulary(need_default=True, min_freq=2) | |||||
word_v.unknown_label = UNK | |||||
pos_v = Vocabulary(need_default=True) | |||||
tag_v = Vocabulary(need_default=False) | |||||
train_data = loader.load(os.path.join(datadir, train_data_name)) | |||||
dev_data = loader.load(os.path.join(datadir, dev_data_name)) | |||||
test_data = loader.load(os.path.join(datadir, test_data_name)) | |||||
train_data.update_vocab(word_seq=word_v, pos_seq=pos_v, head_labels=tag_v) | |||||
datasets = (train_data, dev_data, test_data) | |||||
save_data(processed_datadir, word_v=word_v, pos_v=pos_v, tag_v=tag_v, train_data=train_data, dev_data=dev_data, test_data=test_data) | |||||
embed, _ = EmbedLoader.load_embedding(model_args['word_emb_dim'], emb_file_name, 'glove', word_v, os.path.join(processed_datadir, 'word_emb.pkl')) | |||||
print(len(word_v)) | |||||
print(embed.size()) | |||||
def update_v(vocab, data, field): | |||||
data.apply(lambda x: vocab.add_word_lst(x[field]), new_field_name=None) | |||||
print('load raw data and preprocess') | |||||
# use pretrain embedding | |||||
word_v = Vocabulary() | |||||
word_v.unknown_label = UNK | |||||
pos_v = Vocabulary() | |||||
tag_v = Vocabulary(unknown=None, padding=None) | |||||
train_data = load(os.path.join(datadir, train_data_name)) | |||||
dev_data = load(os.path.join(datadir, dev_data_name)) | |||||
test_data = load(os.path.join(datadir, test_data_name)) | |||||
print(train_data[0]) | |||||
num_p = Num2TagProcessor('words', 'words') | |||||
for ds in (train_data, dev_data, test_data): | |||||
num_p(ds) | |||||
update_v(word_v, train_data, 'words') | |||||
update_v(pos_v, train_data, 'pos') | |||||
update_v(tag_v, train_data, 'tags') | |||||
print('vocab built successfully: {}, {}, {}'.format(len(word_v), len(pos_v), len(tag_v))) | |||||
# embed, _ = EmbedLoader.fast_load_embedding(model_args['word_emb_dim'], emb_file_name, word_v) | |||||
# print(embed.size()) | |||||
# Model | # Model | ||||
model_args['word_vocab_size'] = len(word_v) | model_args['word_vocab_size'] = len(word_v) | ||||
@@ -245,50 +142,49 @@ model_args['num_label'] = len(tag_v) | |||||
model = BiaffineParser(**model_args.data) | model = BiaffineParser(**model_args.data) | ||||
model.reset_parameters() | model.reset_parameters() | ||||
datasets = (train_data, dev_data, test_data) | |||||
for ds in datasets: | |||||
ds.index_field("word_seq", word_v).index_field("pos_seq", pos_v).index_field("head_labels", tag_v) | |||||
ds.set_origin_len('word_seq') | |||||
word_idxp = IndexerProcessor(word_v, 'words', 'word_seq') | |||||
pos_idxp = IndexerProcessor(pos_v, 'pos', 'pos_seq') | |||||
tag_idxp = IndexerProcessor(tag_v, 'tags', 'label_true') | |||||
seq_p = SeqLenProcessor('word_seq', 'seq_lens') | |||||
set_input_p = SetInputProcessor('word_seq', 'pos_seq', 'seq_lens', flag=True) | |||||
set_target_p = SetTargetProcessor('arc_true', 'label_true', 'seq_lens', flag=True) | |||||
label_toword_p = Index2WordProcessor(vocab=tag_v, field_name='label_pred', new_added_field_name='label_pred_seq') | |||||
for ds in (train_data, dev_data, test_data): | |||||
word_idxp(ds) | |||||
pos_idxp(ds) | |||||
tag_idxp(ds) | |||||
seq_p(ds) | |||||
set_input_p(ds) | |||||
set_target_p(ds) | |||||
if train_args['use_golden_train']: | if train_args['use_golden_train']: | ||||
train_data.set_target(gold_heads=False) | |||||
else: | |||||
train_data.set_target(gold_heads=None) | |||||
train_data.set_input('gold_heads', flag=True) | |||||
train_args.data.pop('use_golden_train') | train_args.data.pop('use_golden_train') | ||||
ignore_label = pos_v['P'] | |||||
ignore_label = pos_v['punct'] | |||||
print(test_data[0]) | print(test_data[0]) | ||||
print(len(train_data)) | |||||
print(len(dev_data)) | |||||
print(len(test_data)) | |||||
print('train len {}'.format(len(train_data))) | |||||
print('dev len {}'.format(len(dev_data))) | |||||
print('test len {}'.format(len(test_data))) | |||||
def train(path): | def train(path): | ||||
# test saving pipeline | |||||
save_pipe(path) | |||||
# Trainer | # Trainer | ||||
trainer = Trainer(**train_args.data) | |||||
def _define_optim(obj): | |||||
lr = optim_args.data['lr'] | |||||
embed_params = set(obj._model.word_embedding.parameters()) | |||||
decay_params = set(obj._model.arc_predictor.parameters()) | set(obj._model.label_predictor.parameters()) | |||||
params = [p for p in obj._model.parameters() if p not in decay_params and p not in embed_params] | |||||
obj._optimizer = torch.optim.Adam([ | |||||
{'params': list(embed_params), 'lr':lr*0.1}, | |||||
{'params': list(decay_params), **optim_args.data}, | |||||
{'params': params} | |||||
], lr=lr, betas=(0.9, 0.9)) | |||||
obj._scheduler = torch.optim.lr_scheduler.LambdaLR(obj._optimizer, lambda ep: max(.75 ** (ep / 5e4), 0.05)) | |||||
def _update(obj): | |||||
# torch.nn.utils.clip_grad_norm_(obj._model.parameters(), 5.0) | |||||
obj._scheduler.step() | |||||
obj._optimizer.step() | |||||
trainer.define_optimizer = lambda: _define_optim(trainer) | |||||
trainer.update = lambda: _update(trainer) | |||||
trainer.set_validator(Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label))) | |||||
model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False) | |||||
trainer = Trainer(model=model, train_data=train_data, dev_data=dev_data, | |||||
loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS', | |||||
**train_args.data, | |||||
optimizer=fastNLP.Adam(**optim_args.data), | |||||
save_path=path) | |||||
# model.word_embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False) | |||||
model.word_embedding.padding_idx = word_v.padding_idx | model.word_embedding.padding_idx = word_v.padding_idx | ||||
model.word_embedding.weight.data[word_v.padding_idx].fill_(0) | model.word_embedding.weight.data[word_v.padding_idx].fill_(0) | ||||
model.pos_embedding.padding_idx = pos_v.padding_idx | model.pos_embedding.padding_idx = pos_v.padding_idx | ||||
@@ -302,18 +198,23 @@ def train(path): | |||||
# pass | # pass | ||||
# Start training | # Start training | ||||
trainer.train(model, train_data, dev_data) | |||||
trainer.train() | |||||
print("Training finished!") | print("Training finished!") | ||||
# Saver | |||||
saver = ModelSaver("./save/saved_model.pkl") | |||||
saver.save_pytorch(model) | |||||
print("Model saved!") | |||||
# save pipeline | |||||
save_pipe(path) | |||||
print('pipe saved') | |||||
def save_pipe(path): | |||||
pipe = Pipeline(processors=[num_p, word_idxp, pos_idxp, seq_p, set_input_p]) | |||||
pipe.add_processor(ModelProcessor(model=model, batch_size=32)) | |||||
pipe.add_processor(label_toword_p) | |||||
torch.save(pipe, os.path.join(path, 'pipe.pkl')) | |||||
def test(path): | def test(path): | ||||
# Tester | # Tester | ||||
tester = Tester(**test_args.data, evaluator=ParserEvaluator(ignore_label)) | |||||
tester = Tester(**test_args.data) | |||||
# Model | # Model | ||||
model = BiaffineParser(**model_args.data) | model = BiaffineParser(**model_args.data) | ||||
@@ -333,13 +234,18 @@ def test(path): | |||||
print("Testing Test data") | print("Testing Test data") | ||||
tester.test(model, test_data) | tester.test(model, test_data) | ||||
def build_pipe(parser_pipe_path): | |||||
parser_pipe = torch.load(parser_pipe_path) | |||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
import argparse | import argparse | ||||
parser = argparse.ArgumentParser(description='Run a chinese word segmentation model') | parser = argparse.ArgumentParser(description='Run a chinese word segmentation model') | ||||
parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer']) | |||||
parser.add_argument('--mode', help='set the mode to run', choices=['train', 'test', 'infer', 'save']) | |||||
parser.add_argument('--path', type=str, default='') | parser.add_argument('--path', type=str, default='') | ||||
# parser.add_argument('--dst', type=str, default='') | |||||
args = parser.parse_args() | args = parser.parse_args() | ||||
if args.mode == 'train': | if args.mode == 'train': | ||||
train(args.path) | train(args.path) | ||||
@@ -347,6 +253,12 @@ if __name__ == "__main__": | |||||
test(args.path) | test(args.path) | ||||
elif args.mode == 'infer': | elif args.mode == 'infer': | ||||
pass | pass | ||||
# elif args.mode == 'save': | |||||
# print(f'save model from {args.path} to {args.dst}') | |||||
# save_model(args.path, args.dst) | |||||
# load_path = os.path.dirname(args.dst) | |||||
# print(f'save pipeline in {load_path}') | |||||
# build(load_path) | |||||
else: | else: | ||||
print('no mode specified for model!') | print('no mode specified for model!') | ||||
parser.print_help() | parser.print_help() |
@@ -6,6 +6,13 @@ from fastNLP.io.dataset_loader import DataSetLoader | |||||
def cut_long_sentence(sent, max_sample_length=200): | def cut_long_sentence(sent, max_sample_length=200): | ||||
""" | |||||
Split a sentence longer than max_sample_length into several pieces. Cuts only happen at whitespace, so the resulting pieces may be longer or shorter than max_sample_length. | |||||
:param sent: str. | |||||
:param max_sample_length: int. | |||||
:return: list of str. | |||||
""" | |||||
sent_no_space = sent.replace(' ', '') | sent_no_space = sent.replace(' ', '') | ||||
cutted_sentence = [] | cutted_sentence = [] | ||||
if len(sent_no_space) > max_sample_length: | if len(sent_no_space) > max_sample_length: | ||||
@@ -127,12 +134,26 @@ class POSCWSReader(DataSetLoader): | |||||
return dataset | return dataset | ||||
class ConlluCWSReader(object): | |||||
# The returned DataSet contains two fields: words (list of list, the inner list holds characters) and tag (list of str, each str being a BMES tag). | |||||
class ConllCWSReader(object): | |||||
def __init__(self): | def __init__(self): | ||||
pass | pass | ||||
def load(self, path, cut_long_sent=False): | def load(self, path, cut_long_sent=False): | ||||
""" | |||||
The returned DataSet contains only the raw_sentence field, whose content is a str. | |||||
The input is assumed to be in CoNLL format, with sentences separated by blank lines and seven columns per line, e.g. | |||||
1 编者按 编者按 NN O 11 nmod:topic | |||||
2 : : PU O 11 punct | |||||
3 7月 7月 NT DATE 4 compound:nn | |||||
4 12日 12日 NT DATE 11 nmod:tmod | |||||
5 , , PU O 11 punct | |||||
1 这 这 DT O 3 det | |||||
2 款 款 M O 1 mark:clf | |||||
3 飞行 飞行 NN O 8 nsubj | |||||
4 从 从 P O 5 case | |||||
5 外型 外型 NN O 8 nmod:prep | |||||
""" | |||||
datalist = [] | datalist = [] | ||||
with open(path, 'r', encoding='utf-8') as f: | with open(path, 'r', encoding='utf-8') as f: | ||||
sample = [] | sample = [] | ||||
@@ -150,10 +171,10 @@ class ConlluCWSReader(object): | |||||
ds = DataSet() | ds = DataSet() | ||||
for sample in datalist: | for sample in datalist: | ||||
# print(sample) | # print(sample) | ||||
res = self.get_one(sample) | |||||
res = self.get_char_lst(sample) | |||||
if res is None: | if res is None: | ||||
continue | continue | ||||
line = ' '.join(res) | |||||
line = ' '.join(res) | |||||
if cut_long_sent: | if cut_long_sent: | ||||
sents = cut_long_sentence(line) | sents = cut_long_sentence(line) | ||||
else: | else: | ||||
@@ -163,7 +184,7 @@ class ConlluCWSReader(object): | |||||
return ds | return ds | ||||
def get_one(self, sample): | |||||
def get_char_lst(self, sample): | |||||
if len(sample)==0: | if len(sample)==0: | ||||
return None | return None | ||||
text = [] | text = [] | ||||
@@ -9,7 +9,7 @@ from reproduction.chinese_word_segment.utils import seq_lens_to_mask | |||||
class CWSBiLSTMEncoder(BaseModel): | class CWSBiLSTMEncoder(BaseModel): | ||||
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, | def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, | ||||
hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1): | |||||
hidden_size=200, bidirectional=True, embed_drop_p=0.2, num_layers=1): | |||||
super().__init__() | super().__init__() | ||||
self.input_size = 0 | self.input_size = 0 | ||||
@@ -65,9 +65,10 @@ class CWSBiLSTMEncoder(BaseModel): | |||||
x_tensor = self.char_embedding(chars) | x_tensor = self.char_embedding(chars) | ||||
if not bigrams is None: | |||||
if hasattr(self, 'bigram_embedding'): | |||||
bigram_tensor = self.bigram_embedding(bigrams).view(batch_size, max_len, -1) | bigram_tensor = self.bigram_embedding(bigrams).view(batch_size, max_len, -1) | ||||
x_tensor = torch.cat([x_tensor, bigram_tensor], dim=2) | x_tensor = torch.cat([x_tensor, bigram_tensor], dim=2) | ||||
x_tensor = self.embedding_drop(x_tensor) | |||||
sorted_lens, sorted_indices = torch.sort(seq_lens, descending=True) | sorted_lens, sorted_indices = torch.sort(seq_lens, descending=True) | ||||
packed_x = nn.utils.rnn.pack_padded_sequence(x_tensor[sorted_indices], sorted_lens, batch_first=True) | packed_x = nn.utils.rnn.pack_padded_sequence(x_tensor[sorted_indices], sorted_lens, batch_first=True) | ||||
@@ -120,10 +121,24 @@ class CWSBiLSTMSegApp(BaseModel): | |||||
from fastNLP.modules.decoder.CRF import ConditionalRandomField | from fastNLP.modules.decoder.CRF import ConditionalRandomField | ||||
from fastNLP.modules.decoder.CRF import allowed_transitions | |||||
class CWSBiLSTMCRF(BaseModel): | class CWSBiLSTMCRF(BaseModel): | ||||
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, | def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, | ||||
hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1, tag_size=4): | |||||
hidden_size=200, bidirectional=True, embed_drop_p=0.2, num_layers=1, tag_size=4): | |||||
""" | |||||
Uses the BMES tagging scheme by default. | |||||
:param vocab_num: | |||||
:param embed_dim: | |||||
:param bigram_vocab_num: | |||||
:param bigram_embed_dim: | |||||
:param num_bigram_per_char: | |||||
:param hidden_size: | |||||
:param bidirectional: | |||||
:param embed_drop_p: | |||||
:param num_layers: | |||||
:param tag_size: | |||||
""" | |||||
super(CWSBiLSTMCRF, self).__init__() | super(CWSBiLSTMCRF, self).__init__() | ||||
self.tag_size = tag_size | self.tag_size = tag_size | ||||
@@ -133,10 +148,12 @@ class CWSBiLSTMCRF(BaseModel): | |||||
size_layer = [hidden_size, 200, tag_size] | size_layer = [hidden_size, 200, tag_size] | ||||
self.decoder_model = MLP(size_layer) | self.decoder_model = MLP(size_layer) | ||||
self.crf = ConditionalRandomField(tag_size=tag_size, include_start_end_trans=False) | |||||
allowed_trans = allowed_transitions({0:'b', 1:'m', 2:'e', 3:'s'}, encoding_type='bmes') | |||||
self.crf = ConditionalRandomField(num_tags=tag_size, include_start_end_trans=False, | |||||
allowed_transitions=allowed_trans) | |||||
def forward(self, chars, tags, seq_lens, bigrams=None): | |||||
def forward(self, chars, target, seq_lens, bigrams=None): | |||||
device = self.parameters().__next__().device | device = self.parameters().__next__().device | ||||
chars = chars.to(device).long() | chars = chars.to(device).long() | ||||
if not bigrams is None: | if not bigrams is None: | ||||
@@ -147,7 +164,7 @@ class CWSBiLSTMCRF(BaseModel): | |||||
masks = seq_lens_to_mask(seq_lens) | masks = seq_lens_to_mask(seq_lens) | ||||
feats = self.encoder_model(chars, bigrams, seq_lens) | feats = self.encoder_model(chars, bigrams, seq_lens) | ||||
feats = self.decoder_model(feats) | feats = self.decoder_model(feats) | ||||
losses = self.crf(feats, tags, masks) | |||||
losses = self.crf(feats, target, masks) | |||||
pred_dict = {} | pred_dict = {} | ||||
pred_dict['seq_lens'] = seq_lens | pred_dict['seq_lens'] = seq_lens | ||||
@@ -168,5 +185,5 @@ class CWSBiLSTMCRF(BaseModel): | |||||
feats = self.decoder_model(feats) | feats = self.decoder_model(feats) | ||||
probs = self.crf.viterbi_decode(feats, masks, get_score=False) | probs = self.crf.viterbi_decode(feats, masks, get_score=False) | ||||
return {'pred_tags': probs} | |||||
return {'pred': probs, 'seq_lens':seq_lens} | |||||
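A hedged sketch of how the renamed arguments fit together; the import path and tensor contents are placeholders, and the exact keys of the returned dict beyond `seq_lens` follow the surrounding (partly elided) code, so only the call pattern is asserted here.

```python
import torch
from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMCRF  # path assumed

model = CWSBiLSTMCRF(vocab_num=5000, embed_dim=100, hidden_size=200, tag_size=4)

chars = torch.randint(0, 5000, (2, 7)).long()   # [batch, max_len] character ids
target = torch.randint(0, 4, (2, 7)).long()     # BMES tag ids; the argument is now named 'target'
seq_lens = torch.LongTensor([7, 5])

out = model(chars, target, seq_lens)  # dict that includes seq_lens (and, per the hunk above, the CRF loss)
```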
@@ -1,17 +1,18 @@ | |||||
import re | import re | ||||
from fastNLP.core.field import SeqLabelField | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.api.processor import Processor | from fastNLP.api.processor import Processor | ||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from reproduction.chinese_word_segment.process.span_converter import SpanConverter | from reproduction.chinese_word_segment.process.span_converter import SpanConverter | ||||
_SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>' | _SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>' | ||||
class SpeicalSpanProcessor(Processor): | class SpeicalSpanProcessor(Processor): | ||||
# 这个类会将句子中的special span转换为对应的内容。 | |||||
""" | |||||
Use the registered span converters to rewrite the field_name field of the DataSet. | |||||
""" | |||||
def __init__(self, field_name, new_added_field_name=None): | def __init__(self, field_name, new_added_field_name=None): | ||||
super(SpeicalSpanProcessor, self).__init__(field_name, new_added_field_name) | super(SpeicalSpanProcessor, self).__init__(field_name, new_added_field_name) | ||||
@@ -20,11 +21,12 @@ class SpeicalSpanProcessor(Processor): | |||||
def process(self, dataset): | def process(self, dataset): | ||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | ||||
for ins in dataset: | |||||
def inner_proc(ins): | |||||
sentence = ins[self.field_name] | sentence = ins[self.field_name] | ||||
for span_converter in self.span_converters: | for span_converter in self.span_converters: | ||||
sentence = span_converter.find_certain_span_and_replace(sentence) | sentence = span_converter.find_certain_span_and_replace(sentence) | ||||
ins[self.new_added_field_name] = sentence | |||||
return sentence | |||||
dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name) | |||||
return dataset | return dataset | ||||
@@ -34,17 +36,22 @@ class SpeicalSpanProcessor(Processor): | |||||
self.span_converters.append(converter) | self.span_converters.append(converter) | ||||
class CWSCharSegProcessor(Processor): | class CWSCharSegProcessor(Processor): | ||||
""" | |||||
Split the field_name field of the DataSet into individual characters, e.g. "复旦大学 fudan" becomes ['复', '旦', '大', '学', | |||||
' ', 'f', 'u', ...] | |||||
""" | |||||
def __init__(self, field_name, new_added_field_name): | def __init__(self, field_name, new_added_field_name): | ||||
super(CWSCharSegProcessor, self).__init__(field_name, new_added_field_name) | super(CWSCharSegProcessor, self).__init__(field_name, new_added_field_name) | ||||
def process(self, dataset): | def process(self, dataset): | ||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | ||||
for ins in dataset: | |||||
def inner_proc(ins): | |||||
sentence = ins[self.field_name] | sentence = ins[self.field_name] | ||||
chars = self._split_sent_into_chars(sentence) | chars = self._split_sent_into_chars(sentence) | ||||
ins[self.new_added_field_name] = chars | |||||
return chars | |||||
dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name) | |||||
return dataset | return dataset | ||||
@@ -73,6 +80,10 @@ class CWSCharSegProcessor(Processor): | |||||
class CWSTagProcessor(Processor): | class CWSTagProcessor(Processor): | ||||
""" | |||||
Generate tags for word segmentation. This is the base class. | |||||
""" | |||||
def __init__(self, field_name, new_added_field_name=None): | def __init__(self, field_name, new_added_field_name=None): | ||||
super(CWSTagProcessor, self).__init__(field_name, new_added_field_name) | super(CWSTagProcessor, self).__init__(field_name, new_added_field_name) | ||||
@@ -107,18 +118,22 @@ class CWSTagProcessor(Processor): | |||||
def process(self, dataset): | def process(self, dataset): | ||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | ||||
for ins in dataset: | |||||
def inner_proc(ins): | |||||
sentence = ins[self.field_name] | sentence = ins[self.field_name] | ||||
tag_list = self._generate_tag(sentence) | tag_list = self._generate_tag(sentence) | ||||
ins[self.new_added_field_name] = tag_list | |||||
dataset.set_target(**{self.new_added_field_name:True}) | |||||
dataset._set_need_tensor(**{self.new_added_field_name:True}) | |||||
return tag_list | |||||
dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name) | |||||
dataset.set_target(self.new_added_field_name) | |||||
return dataset | return dataset | ||||
def _tags_from_word_len(self, word_len): | def _tags_from_word_len(self, word_len): | ||||
raise NotImplementedError | raise NotImplementedError | ||||
class CWSBMESTagProcessor(CWSTagProcessor): | class CWSBMESTagProcessor(CWSTagProcessor): | ||||
""" | |||||
Generate the corresponding BMES tags from the field_name field of the DataSet. | |||||
""" | |||||
def __init__(self, field_name, new_added_field_name=None): | def __init__(self, field_name, new_added_field_name=None): | ||||
super(CWSBMESTagProcessor, self).__init__(field_name, new_added_field_name) | super(CWSBMESTagProcessor, self).__init__(field_name, new_added_field_name) | ||||
@@ -137,6 +152,10 @@ class CWSBMESTagProcessor(CWSTagProcessor): | |||||
return tag_list | return tag_list | ||||
class CWSSegAppTagProcessor(CWSTagProcessor): | class CWSSegAppTagProcessor(CWSTagProcessor): | ||||
""" | |||||
Generate the corresponding SegApp tags from the field_name field of the DataSet. | |||||
""" | |||||
def __init__(self, field_name, new_added_field_name=None): | def __init__(self, field_name, new_added_field_name=None): | ||||
super(CWSSegAppTagProcessor, self).__init__(field_name, new_added_field_name) | super(CWSSegAppTagProcessor, self).__init__(field_name, new_added_field_name) | ||||
@@ -151,6 +170,10 @@ class CWSSegAppTagProcessor(CWSTagProcessor): | |||||
class BigramProcessor(Processor): | class BigramProcessor(Processor): | ||||
""" | |||||
Base class for generating bigrams. | |||||
""" | |||||
def __init__(self, field_name, new_added_fielf_name=None): | def __init__(self, field_name, new_added_fielf_name=None): | ||||
super(BigramProcessor, self).__init__(field_name, new_added_fielf_name) | super(BigramProcessor, self).__init__(field_name, new_added_fielf_name) | ||||
@@ -158,22 +181,31 @@ class BigramProcessor(Processor): | |||||
def process(self, dataset): | def process(self, dataset): | ||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | ||||
for ins in dataset: | |||||
def inner_proc(ins): | |||||
characters = ins[self.field_name] | characters = ins[self.field_name] | ||||
bigrams = self._generate_bigram(characters) | bigrams = self._generate_bigram(characters) | ||||
ins[self.new_added_field_name] = bigrams | |||||
return bigrams | |||||
dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name) | |||||
return dataset | return dataset | ||||
def _generate_bigram(self, characters): | def _generate_bigram(self, characters): | ||||
pass | pass | ||||
class Pre2Post2BigramProcessor(BigramProcessor): | class Pre2Post2BigramProcessor(BigramProcessor): | ||||
def __init__(self, field_name, new_added_fielf_name=None): | |||||
""" | |||||
This bigram processor generates bigrams as follows: | |||||
an original character list l = ['a', 'b', 'c'] is padded to L = ['SOS', 'SOS', 'a', 'b', 'c', 'EOS', 'EOS'], and the generated bigram list is | |||||
[L[idx-2], L[idx-1], L[idx+1], L[idx+2], L[idx-2]L[idx], L[idx-1]L[idx], L[idx]L[idx+1], L[idx]L[idx+2], ....] | |||||
i.e. every character gets eight bigrams; for 'a' in the example above they are | |||||
['SOS', 'SOS', 'b', 'c', 'SOSa', 'SOSa', 'ab', 'ac'] | |||||
The returned bigram list is flat, but every group of eight elements holds the bigram information of one character. | |||||
""" | |||||
def __init__(self, field_name, new_added_field_name=None): | |||||
super(BigramProcessor, self).__init__(field_name, new_added_fielf_name) | |||||
super(Pre2Post2BigramProcessor, self).__init__(field_name, new_added_field_name) | |||||
def _generate_bigram(self, characters): | def _generate_bigram(self, characters): | ||||
bigrams = [] | bigrams = [] | ||||
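To make the indexing scheme described in the docstring above concrete, here is a small self-contained sketch of the same construction (plain Python, independent of the Processor classes; the function and padding-symbol names are illustrative only):

```python
def pre2post2_bigrams(chars, sos='SOS', eos='EOS'):
    """Return a flat list with eight context features per original character."""
    L = [sos, sos] + list(chars) + [eos, eos]
    bigrams = []
    for idx in range(2, len(L) - 2):              # positions of the original characters
        bigrams.extend([
            L[idx - 2], L[idx - 1], L[idx + 1], L[idx + 2],   # surrounding unigrams
            L[idx - 2] + L[idx], L[idx - 1] + L[idx],         # left-side bigrams
            L[idx] + L[idx + 1], L[idx] + L[idx + 2],         # right-side bigrams
        ])
    return bigrams

print(pre2post2_bigrams(['a', 'b', 'c'])[:8])
# ['SOS', 'SOS', 'b', 'c', 'SOSa', 'SOSa', 'ab', 'ac'] -- matches the docstring example
```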
@@ -197,20 +229,116 @@ class Pre2Post2BigramProcessor(BigramProcessor): | |||||
# A vocabulary needs to be built here, but the following problem came up: | # A vocabulary needs to be built here, but the following problem came up: | ||||
# (1) with the Processor approach, what gets returned in this case is not a dataset, so vocabulary construction is implemented | # (1) with the Processor approach, what gets returned in this case is not a dataset, so vocabulary construction is implemented | ||||
# in another way instead of through a Processor | # in another way instead of through a Processor | ||||
# TODO how to unify the vocab-building and indexing steps? | |||||
class VocabIndexerProcessor(Processor): | |||||
""" | |||||
Build a Vocabulary from DataSets and use it to index the field numerically. The newly generated index field is stored | |||||
under new_added_filed_name; if it is not provided, the original field_name is overwritten. | |||||
""" | |||||
def __init__(self, field_name, new_added_filed_name=None, min_freq=1, max_size=None, | |||||
verbose=0, is_input=True): | |||||
""" | |||||
:param field_name: the field to build the vocabulary from, and the field to index. | |||||
:param new_added_filed_name: name of the generated index field; if not given, field_name is overwritten. | |||||
:param min_freq: minimum number of occurrences a token needs to be kept in the Vocabulary. | |||||
:param max_size: maximum number of tokens the Vocabulary may contain. | |||||
:param verbose: 0, print nothing; 1, print information. | |||||
:param bool is_input: whether the indexed field is marked as an input field of the DataSet. | |||||
""" | |||||
super(VocabIndexerProcessor, self).__init__(field_name, new_added_filed_name) | |||||
self.min_freq = min_freq | |||||
self.max_size = max_size | |||||
self.verbose =verbose | |||||
self.is_input = is_input | |||||
def construct_vocab(self, *datasets): | |||||
""" | |||||
Build the vocabulary from the DataSets passed in. | |||||
:param datasets: DataSet objects used to build the vocabulary. | |||||
:return: | |||||
""" | |||||
self.vocab = Vocabulary(min_freq=self.min_freq, max_size=self.max_size) | |||||
for dataset in datasets: | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
dataset.apply(lambda ins: self.vocab.update(ins[self.field_name])) | |||||
self.vocab.build_vocab() | |||||
if self.verbose: | |||||
print("Vocabulary Constructed, has {} items.".format(len(self.vocab))) | |||||
def process(self, *datasets, only_index_dataset=None): | |||||
""" | |||||
If no Vocabulary has been built yet, build one from the DataSets in datasets; otherwise the existing vocabulary is used. | |||||
Once the vocabulary is available, both datasets and only_index_dataset are indexed. | |||||
:param datasets: DataSet objects. | |||||
:param only_index_dataset: DataSet, or list of DataSet. These are only indexed; they are not used to build the vocabulary. | |||||
:return: | |||||
""" | |||||
if len(datasets)==0 and not hasattr(self,'vocab'): | |||||
raise RuntimeError("You have to construct vocabulary first. Or you have to pass datasets to construct it.") | |||||
if not hasattr(self, 'vocab'): | |||||
self.construct_vocab(*datasets) | |||||
else: | |||||
if self.verbose: | |||||
print("Using constructed vocabulary with {} items.".format(len(self.vocab))) | |||||
to_index_datasets = [] | |||||
if len(datasets)!=0: | |||||
for dataset in datasets: | |||||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||||
to_index_datasets.append(dataset) | |||||
if only_index_dataset is not None: | |||||
if isinstance(only_index_dataset, list): | |||||
for dataset in only_index_dataset: | |||||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||||
to_index_datasets.append(dataset) | |||||
elif isinstance(only_index_dataset, DataSet): | |||||
to_index_datasets.append(only_index_dataset) | |||||
else: | |||||
raise TypeError('Only DataSet or list of DataSet is allowed, not {}.'.format(type(only_index_dataset))) | |||||
for dataset in to_index_datasets: | |||||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||||
dataset.apply(lambda ins: [self.vocab.to_index(token) for token in ins[self.field_name]], | |||||
new_field_name=self.new_added_field_name, is_input=self.is_input) | |||||
# 只返回一个,infer时为了跟其他processor保持一致 | |||||
if len(to_index_datasets) == 1: | |||||
return to_index_datasets[0] | |||||
def set_vocab(self, vocab): | |||||
assert isinstance(vocab, Vocabulary), "Only fastNLP.core.Vocabulary is allowed, not {}.".format(type(vocab)) | |||||
self.vocab = vocab | |||||
def delete_vocab(self): | |||||
del self.vocab | |||||
def get_vocab_size(self): | |||||
return len(self.vocab) | |||||
def set_verbose(self, verbose): | |||||
""" | |||||
设置processor verbose状态。 | |||||
:param verbose: int, 0,不输出任何信息;1,输出vocab 信息。 | |||||
:return: | |||||
""" | |||||
self.verbose = verbose | |||||
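A hedged usage sketch of `VocabIndexerProcessor`, assuming it is used together with `DataSet` as elsewhere in this file; the toy data, field names, and variable names are illustrative only:

```python
# Illustration only: build the vocabulary from the training set and
# reuse it to index a test set (which must not contribute new words).
from fastNLP.core.dataset import DataSet

train = DataSet({'chars': [['编', '者', '按'], ['这', '款', '飞', '行']]})
test = DataSet({'chars': [['无', '人', '机']]})

proc = VocabIndexerProcessor('chars', 'char_ids', min_freq=1)
proc.process(train, only_index_dataset=test)   # vocab built from `train` only
print(proc.get_vocab_size())
```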
class VocabProcessor(Processor): | class VocabProcessor(Processor): | ||||
def __init__(self, field_name, min_count=1, max_vocab_size=None): | |||||
def __init__(self, field_name, min_freq=1, max_size=None): | |||||
super(VocabProcessor, self).__init__(field_name, None) | super(VocabProcessor, self).__init__(field_name, None) | ||||
self.vocab = Vocabulary(min_freq=min_count, max_size=max_vocab_size) | |||||
self.vocab = Vocabulary(min_freq=min_freq, max_size=max_size) | |||||
def process(self, *datasets): | def process(self, *datasets): | ||||
for dataset in datasets: | for dataset in datasets: | ||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | ||||
for ins in dataset: | |||||
tokens = ins[self.field_name] | |||||
self.vocab.update(tokens) | |||||
dataset.apply(lambda ins: self.vocab.update(ins[self.field_name])) | |||||
def get_vocab(self): | def get_vocab(self): | ||||
self.vocab.build_vocab() | self.vocab.build_vocab() | ||||
@@ -220,19 +348,6 @@ class VocabProcessor(Processor): | |||||
return len(self.vocab) | return len(self.vocab) | ||||
class SeqLenProcessor(Processor): | |||||
def __init__(self, field_name, new_added_field_name='seq_lens'): | |||||
super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
for ins in dataset: | |||||
length = len(ins[self.field_name]) | |||||
ins[self.new_added_field_name] = length | |||||
dataset._set_need_tensor(**{self.new_added_field_name:True}) | |||||
return dataset | |||||
class SegApp2OutputProcessor(Processor): | class SegApp2OutputProcessor(Processor): | ||||
def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output'): | def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output'): | ||||
super(SegApp2OutputProcessor, self).__init__(None, None) | super(SegApp2OutputProcessor, self).__init__(None, None) | ||||
@@ -258,7 +373,32 @@ class SegApp2OutputProcessor(Processor): | |||||
class BMES2OutputProcessor(Processor): | class BMES2OutputProcessor(Processor): | ||||
def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output'): | |||||
""" | |||||
按照BMES标注方式推测生成的tag。由于可能存在非法tag,比如"BS",所以需要用以下的表格做转换,cur_B意思是当前tag是B, | |||||
next_B意思是后一个tag是B。则cur_B=S,即将当前被predict是B的tag标为S;next_M=B, 即将后一个被predict是M的tag标为B | |||||
| | next_B | next_M | next_E | next_S | end | | |||||
|:-----:|:-------:|:--------:|:--------:|:-------:|:-------:| | |||||
| start | 合法 | next_M=B | next_E=S | 合法 | - | | |||||
| cur_B | cur_B=S | 合法 | 合法 | cur_B=S | cur_B=S | | |||||
| cur_M | cur_M=E | 合法 | 合法 | cur_M=E | cur_M=E | | |||||
| cur_E | 合法 | next_M=B | next_E=S | 合法 | 合法 | | |||||
| cur_S | 合法 | next_M=B | next_E=S | 合法 | 合法 | | |||||
举例: | |||||
prediction为BSEMS,会被认为是SSSSS. | |||||
""" | |||||
def __init__(self, chars_field_name='chars_list', tag_field_name='pred', new_added_field_name='output', | |||||
b_idx=0, m_idx=1, e_idx=2, s_idx=3): | |||||
""" | |||||
:param chars_field_name: character所对应的field | |||||
:param tag_field_name: 预测对应的field | |||||
:param new_added_field_name: 转换后的内容所在field | |||||
:param b_idx: int, Begin标签所对应的tag idx. | |||||
:param m_idx: int, Middle标签所对应的tag idx. | |||||
:param e_idx: int, End标签所对应的tag idx. | |||||
:param s_idx: int, Single标签所对应的tag idx | |||||
""" | |||||
super(BMES2OutputProcessor, self).__init__(None, None) | super(BMES2OutputProcessor, self).__init__(None, None) | ||||
self.chars_field_name = chars_field_name | self.chars_field_name = chars_field_name | ||||
@@ -266,19 +406,84 @@ class BMES2OutputProcessor(Processor): | |||||
self.new_added_field_name = new_added_field_name | self.new_added_field_name = new_added_field_name | ||||
self.b_idx = b_idx | |||||
self.m_idx = m_idx | |||||
self.e_idx = e_idx | |||||
self.s_idx = s_idx | |||||
# 还原init处介绍的矩阵 | |||||
self._valida_matrix = { | |||||
-1: [(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1)], # magic start idx | |||||
self.b_idx:[(0, self.s_idx), (-1, -1), (-1, -1), (0, self.s_idx), (0, self.s_idx)], | |||||
self.m_idx:[(0, self.e_idx), (-1, -1), (-1, -1), (0, self.e_idx), (0, self.e_idx)], | |||||
self.e_idx:[(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)], | |||||
self.s_idx:[(-1, -1), (1, self.b_idx), (1, self.s_idx), (-1, -1), (-1, -1)], | |||||
} | |||||
def _validate_tags(self, tags): | |||||
""" | |||||
给定一个tag的List,返回合法tag | |||||
:param tags: Tensor, shape: (seq_len, ) | |||||
:return: 返回修改为合法tag的list | |||||
""" | |||||
assert len(tags)!=0 | |||||
padded_tags = [-1, *tags, -1] | |||||
for idx in range(len(padded_tags)-1): | |||||
cur_tag = padded_tags[idx] | |||||
if cur_tag not in self._valida_matrix: | |||||
cur_tag = self.s_idx | |||||
if padded_tags[idx+1] not in self._valida_matrix: | |||||
padded_tags[idx+1] = self.s_idx | |||||
next_tag = padded_tags[idx+1] | |||||
shift_tag = self._valida_matrix[cur_tag][next_tag] | |||||
if shift_tag[0]!=-1: | |||||
padded_tags[idx+shift_tag[0]] = shift_tag[1] | |||||
return padded_tags[1:-1] | |||||
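As a quick sanity check of the conversion table in the docstring above, the 'BSEMS -> SSSSS' example can be traced with the default tag indices (B=0, M=1, E=2, S=3); this is only an illustration, not part of the test suite:

```python
# Illustration only: the invalid sequence B S E M S is rewritten to S S S S S.
proc = BMES2OutputProcessor()
print(proc._validate_tags([0, 3, 2, 1, 3]))   # [3, 3, 3, 3, 3]
```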
def process(self, dataset): | def process(self, dataset): | ||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | ||||
for ins in dataset: | |||||
def inner_proc(ins): | |||||
pred_tags = ins[self.tag_field_name] | pred_tags = ins[self.tag_field_name] | ||||
pred_tags = self._validate_tags(pred_tags) | |||||
chars = ins[self.chars_field_name] | chars = ins[self.chars_field_name] | ||||
words = [] | words = [] | ||||
start_idx = 0 | start_idx = 0 | ||||
for idx, tag in enumerate(pred_tags): | for idx, tag in enumerate(pred_tags): | ||||
if tag==3: | |||||
# 当前没有考虑将原文替换回去 | |||||
if tag==self.s_idx: | |||||
words.extend(chars[start_idx:idx+1]) | words.extend(chars[start_idx:idx+1]) | ||||
start_idx = idx + 1 | start_idx = idx + 1 | ||||
elif tag==2: | |||||
elif tag==self.e_idx: | |||||
words.append(''.join(chars[start_idx:idx+1])) | words.append(''.join(chars[start_idx:idx+1])) | ||||
start_idx = idx + 1 | start_idx = idx + 1 | ||||
ins[self.new_added_field_name] = ' '.join(words) | |||||
return ' '.join(words) | |||||
dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name) | |||||
class InputTargetProcessor(Processor): | |||||
def __init__(self, input_fields, target_fields): | |||||
""" | |||||
对DataSet操作,将input_fields中的field设置为input,target_fields的中field设置为target | |||||
:param input_fields: List[str], 设置为input_field的field_name。如果为None,则不将任何field设置为input。 | |||||
:param target_fields: List[str], 设置为target_field的field_name。 如果为None,则不将任何field设置为target。 | |||||
""" | |||||
super(InputTargetProcessor, self).__init__(None, None) | |||||
if input_fields is not None and not isinstance(input_fields, list): | |||||
raise TypeError("input_fields should be List[str], not {}.".format(type(input_fields))) | |||||
else: | |||||
self.input_fields = input_fields | |||||
if target_fields is not None and not isinstance(target_fields, list): | |||||
raise TypeError("target_fiels should be List[str], not{}.".format(type(target_fields))) | |||||
else: | |||||
self.target_fields = target_fields | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
if self.input_fields is not None: | |||||
for field in self.input_fields: | |||||
dataset.set_input(field) | |||||
if self.target_fields is not None: | |||||
for field in self.target_fields: | |||||
dataset.set_target(field) |
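A minimal usage sketch for `InputTargetProcessor`; the toy dataset and field names below are made up for illustration:

```python
# Illustration only: mark which fields feed the model and which are targets.
from fastNLP.core.dataset import DataSet

ds = DataSet({'word_seq': [[1, 2, 3]], 'truth': [[0, 1, 0]]})
io_proc = InputTargetProcessor(input_fields=['word_seq'], target_fields=['truth'])
io_proc.process(ds)   # 'word_seq' -> input field, 'truth' -> target field
```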
@@ -4,6 +4,7 @@ from collections import Counter | |||||
from fastNLP.api.processor import Processor | from fastNLP.api.processor import Processor | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
class CombineWordAndPosProcessor(Processor): | class CombineWordAndPosProcessor(Processor): | ||||
def __init__(self, word_field_name, pos_field_name): | def __init__(self, word_field_name, pos_field_name): | ||||
super(CombineWordAndPosProcessor, self).__init__(None, None) | super(CombineWordAndPosProcessor, self).__init__(None, None) | ||||
@@ -60,6 +61,7 @@ class CombineWordAndPosProcessor(Processor): | |||||
return dataset | return dataset | ||||
class PosOutputStrProcessor(Processor): | class PosOutputStrProcessor(Processor): | ||||
def __init__(self, word_field_name, pos_field_name): | def __init__(self, word_field_name, pos_field_name): | ||||
super(PosOutputStrProcessor, self).__init__(None, None) | super(PosOutputStrProcessor, self).__init__(None, None) |
@@ -24,8 +24,8 @@ def cut_long_sentence(sent, max_sample_length=200): | |||||
return cutted_sentence | return cutted_sentence | ||||
class ConlluPOSReader(object): | |||||
# 返回的Dataset包含words(list of list, 里层的list是character), tag两个field(list of str, str是标有BMES的tag)。 | |||||
class ConllPOSReader(object): | |||||
# 返回的Dataset包含words(list of list, 里层的list是character), tag两个field(list of str, str是标有BIO的tag)。 | |||||
def __init__(self): | def __init__(self): | ||||
pass | pass | ||||
@@ -70,6 +70,70 @@ class ConlluPOSReader(object): | |||||
return ds | return ds | ||||
class ZhConllPOSReader(object): | |||||
# 中文conll格式reader | |||||
def __init__(self): | |||||
pass | |||||
def load(self, path): | |||||
""" | |||||
返回的DataSet, 包含以下的field | |||||
words:list of str, | |||||
tag: list of str, 被加入了BMES tag, 比如原来的序列为['VP', 'NN', 'NN', ..],会被认为是["S-VP", "B-NN", "M-NN",..] | |||||
假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 | |||||
1 编者按 编者按 NN O 11 nmod:topic | |||||
2 : : PU O 11 punct | |||||
3 7月 7月 NT DATE 4 compound:nn | |||||
4 12日 12日 NT DATE 11 nmod:tmod | |||||
5 , , PU O 11 punct | |||||
1 这 这 DT O 3 det | |||||
2 款 款 M O 1 mark:clf | |||||
3 飞行 飞行 NN O 8 nsubj | |||||
4 从 从 P O 5 case | |||||
5 外型 外型 NN O 8 nmod:prep | |||||
""" | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
ds = DataSet() | |||||
for sample in datalist: | |||||
# print(sample) | |||||
res = self.get_one(sample) | |||||
if res is None: | |||||
continue | |||||
char_seq = [] | |||||
pos_seq = [] | |||||
for word, tag in zip(res[0], res[1]): | |||||
char_seq.extend(list(word)) | |||||
if len(word)==1: | |||||
pos_seq.append('S-{}'.format(tag)) | |||||
elif len(word)>1: | |||||
pos_seq.append('B-{}'.format(tag)) | |||||
for _ in range(len(word)-2): | |||||
pos_seq.append('M-{}'.format(tag)) | |||||
pos_seq.append('E-{}'.format(tag)) | |||||
else: | |||||
raise ValueError("Zero length of word detected.") | |||||
ds.append(Instance(words=char_seq, | |||||
tag=pos_seq)) | |||||
return ds | |||||
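The word-to-character BMES expansion performed in `load()` can be summarized by the standalone sketch below; the function name is illustrative and not part of the reader's API:

```python
# Illustration of the per-character BMES tagging used by ZhConllPOSReader.load().
def to_char_bmes(words, tags):
    chars, char_tags = [], []
    for word, tag in zip(words, tags):
        chars.extend(list(word))
        if len(word) == 1:
            char_tags.append('S-{}'.format(tag))
        else:
            char_tags.append('B-{}'.format(tag))
            char_tags.extend('M-{}'.format(tag) for _ in range(len(word) - 2))
            char_tags.append('E-{}'.format(tag))
    return chars, char_tags

print(to_char_bmes(['编者按', ':'], ['NN', 'PU']))
# (['编', '者', '按', ':'], ['B-NN', 'M-NN', 'E-NN', 'S-PU'])
```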
def get_one(self, sample): | def get_one(self, sample): | ||||
if len(sample)==0: | if len(sample)==0: | ||||
return None | return None | ||||
@@ -84,6 +148,6 @@ class ConlluPOSReader(object): | |||||
return text, pos_tags | return text, pos_tags | ||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
reader = ConlluPOSReader() | |||||
reader = ZhConllPOSReader() | |||||
d = reader.load('/home/hyan/train.conllx') | d = reader.load('/home/hyan/train.conllx') | ||||
print('reader') | |||||
print(d) |
@@ -10,7 +10,7 @@ eval_sort_key = 'accuracy' | |||||
[model] | [model] | ||||
rnn_hidden_units = 300 | rnn_hidden_units = 300 | ||||
word_emb_dim = 300 | |||||
word_emb_dim = 100 | |||||
dropout = 0.5 | dropout = 0.5 | ||||
use_crf = true | use_crf = true | ||||
print_every_step = 10 | print_every_step = 10 | ||||
@@ -0,0 +1,113 @@ | |||||
import argparse | |||||
import os | |||||
import pickle | |||||
import sys | |||||
import torch | |||||
# in order to run fastNLP without installation | |||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||||
from fastNLP.api.pipeline import Pipeline | |||||
from fastNLP.api.processor import SeqLenProcessor | |||||
from fastNLP.core.metrics import SpanFPreRecMetric | |||||
from fastNLP.core.trainer import Trainer | |||||
from fastNLP.io.config_io import ConfigLoader, ConfigSection | |||||
from fastNLP.models.sequence_modeling import AdvSeqLabel | |||||
from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor | |||||
from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader | |||||
from fastNLP.api.processor import ModelProcessor, Index2WordProcessor | |||||
cfgfile = './pos_tag.cfg' | |||||
pickle_path = "save" | |||||
def load_tencent_embed(embed_path, word2id): | |||||
hit = 0 | |||||
with open(embed_path, "rb") as f: | |||||
embed_dict = pickle.load(f) | |||||
embedding_tensor = torch.randn(len(word2id), 200) | |||||
for key in word2id: | |||||
if key in embed_dict: | |||||
embedding_tensor[word2id[key]] = torch.Tensor(embed_dict[key]) | |||||
hit += 1 | |||||
print("vocab_size={} hit={} hit/vocab_size={}".format(len(word2id), hit, hit / len(word2id))) | |||||
return embedding_tensor | |||||
def train(checkpoint=None): | |||||
# load config | |||||
train_param = ConfigSection() | |||||
model_param = ConfigSection() | |||||
ConfigLoader().load_config(cfgfile, {"train": train_param, "model": model_param}) | |||||
print("config loaded") | |||||
# Data Loader | |||||
dataset = ZhConllPOSReader().load("/home/hyan/train.conllx") | |||||
print(dataset) | |||||
print("dataset transformed") | |||||
dataset.rename_field("tag", "truth") | |||||
vocab_proc = VocabIndexerProcessor("words", new_added_filed_name="word_seq") | |||||
tag_proc = VocabIndexerProcessor("truth") | |||||
seq_len_proc = SeqLenProcessor(field_name="word_seq", new_added_field_name="word_seq_origin_len", is_input=True) | |||||
vocab_proc(dataset) | |||||
tag_proc(dataset) | |||||
seq_len_proc(dataset) | |||||
dataset.set_input("word_seq", "word_seq_origin_len", "truth") | |||||
dataset.set_target("truth", "word_seq_origin_len") | |||||
print("processors defined") | |||||
# dataset.set_is_target(tag_ids=True) | |||||
model_param["vocab_size"] = vocab_proc.get_vocab_size() | |||||
model_param["num_classes"] = tag_proc.get_vocab_size() | |||||
print("vocab_size={} num_classes={}".format(model_param["vocab_size"], model_param["num_classes"])) | |||||
# define a model | |||||
if checkpoint is None: | |||||
# pre_trained = load_tencent_embed("/home/zyfeng/data/char_tencent_embedding.pkl", vocab_proc.vocab.word2idx) | |||||
pre_trained = None | |||||
model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word, emb=pre_trained) | |||||
print(model) | |||||
else: | |||||
model = torch.load(checkpoint) | |||||
# call trainer to train | |||||
trainer = Trainer(dataset, model, loss=None, metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict", | |||||
target="truth", | |||||
seq_lens="word_seq_origin_len"), | |||||
dev_data=dataset, metric_key="f", | |||||
use_tqdm=True, use_cuda=True, print_every=5, n_epochs=6, save_path="./save") | |||||
trainer.train(load_best_model=True) | |||||
# save model & pipeline | |||||
model_proc = ModelProcessor(model, seq_len_field_name="word_seq_origin_len") | |||||
id2tag = Index2WordProcessor(tag_proc.vocab, "predict", "tag") | |||||
pp = Pipeline([vocab_proc, seq_len_proc, model_proc, id2tag]) | |||||
save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab} | |||||
torch.save(save_dict, "model_pp.pkl") | |||||
print("pipeline saved") | |||||
torch.save(model, "./save/best_model.pkl") | |||||
if __name__ == "__main__": | |||||
parser = argparse.ArgumentParser() | |||||
parser.add_argument("-c", "--restart", action="store_true", help="whether to continue training") | |||||
parser.add_argument("-cp", "--checkpoint", type=str, help="checkpoint of the trained model") | |||||
args = parser.parse_args() | |||||
if args.restart is True: | |||||
# 继续训练 python train_pos_tag.py -c -cp ./save/best_model.pkl | |||||
if args.checkpoint is None: | |||||
raise RuntimeError("Please provide the checkpoint. -cp ") | |||||
train(args.checkpoint) | |||||
else: | |||||
# 一次训练 python train_pos_tag.py | |||||
train() |
@@ -0,0 +1,25 @@ | |||||
import pickle | |||||
def load_embed(embed_path): | |||||
embed_dict = {} | |||||
with open(embed_path, "r", encoding="utf-8") as f: | |||||
for line in f: | |||||
tokens = line.split(" ") | |||||
if len(tokens) <= 5: | |||||
continue | |||||
key = tokens[0] | |||||
if len(key) == 1: | |||||
value = [float(x) for x in tokens[1:]] | |||||
embed_dict[key] = value | |||||
return embed_dict | |||||
if __name__ == "__main__": | |||||
embed_dict = load_embed("/home/zyfeng/data/small.txt") | |||||
print(embed_dict.keys()) | |||||
with open("./char_tencent_embedding.pkl", "wb") as f: | |||||
pickle.dump(embed_dict, f) | |||||
print("finished") |
@@ -0,0 +1,44 @@ | |||||
import unittest | |||||
import numpy as np | |||||
from fastNLP.core.callback import EchoCallback | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.instance import Instance | |||||
from fastNLP.core.losses import BCELoss | |||||
from fastNLP.core.optimizer import SGD | |||||
from fastNLP.core.trainer import Trainer | |||||
from fastNLP.models.base_model import NaiveClassifier | |||||
class TestCallback(unittest.TestCase): | |||||
def test_case(self): | |||||
def prepare_fake_dataset(): | |||||
mean = np.array([-3, -3]) | |||||
cov = np.array([[1, 0], [0, 1]]) | |||||
class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||||
mean = np.array([3, 3]) | |||||
cov = np.array([[1, 0], [0, 1]]) | |||||
class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||||
data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | |||||
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | |||||
return data_set | |||||
data_set = prepare_fake_dataset() | |||||
data_set.set_input("x") | |||||
data_set.set_target("y") | |||||
model = NaiveClassifier(2, 1) | |||||
trainer = Trainer(data_set, model, | |||||
loss=BCELoss(pred="predict", target="y"), | |||||
n_epochs=1, | |||||
batch_size=32, | |||||
print_every=50, | |||||
optimizer=SGD(lr=0.1), | |||||
check_code_level=2, | |||||
use_tqdm=False, | |||||
callbacks=[EchoCallback()]) | |||||
trainer.train() |
@@ -197,4 +197,4 @@ class TestDataSetIter(unittest.TestCase): | |||||
def test__repr__(self): | def test__repr__(self): | ||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | ||||
for iter in ds: | for iter in ds: | ||||
self.assertEqual(iter.__repr__(), "{'x': [1, 2, 3, 4],\n'y': [5, 6]}") | |||||
self.assertEqual(iter.__repr__(), "{'x': [1, 2, 3, 4] type=list,\n'y': [5, 6] type=list}") |
@@ -4,6 +4,7 @@ import numpy as np | |||||
import torch | import torch | ||||
from fastNLP.core.metrics import AccuracyMetric | from fastNLP.core.metrics import AccuracyMetric | ||||
from fastNLP.core.metrics import BMESF1PreRecMetric | |||||
from fastNLP.core.metrics import pred_topk, accuracy_topk | from fastNLP.core.metrics import pred_topk, accuracy_topk | ||||
@@ -132,6 +133,235 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
return | return | ||||
self.assertTrue(True, False), "No exception catches." | self.assertTrue(True, False), "No exception catches." | ||||
class SpanF1PreRecMetric(unittest.TestCase): | |||||
def test_case1(self): | |||||
from fastNLP.core.metrics import bmes_tag_to_spans | |||||
from fastNLP.core.metrics import bio_tag_to_spans | |||||
bmes_lst = ['M-8', 'S-2', 'S-0', 'B-9', 'B-6', 'E-5', 'B-7', 'S-2', 'E-7', 'S-8'] | |||||
bio_lst = ['O-8', 'O-2', 'B-0', 'O-9', 'I-6', 'I-5', 'I-7', 'I-2', 'I-7', 'O-8'] | |||||
expect_bmes_res = set() | |||||
expect_bmes_res.update([('8', (0, 0)), ('2', (1, 1)), ('0', (2, 2)), ('9', (3, 3)), ('6', (4, 4)), | |||||
('5', (5, 5)), ('7', (6, 6)), ('2', (7, 7)), ('7', (8, 8)), ('8', (9, 9))]) | |||||
expect_bio_res = set() | |||||
expect_bio_res.update([('7', (8, 8)), ('0', (2, 2)), ('2', (7, 7)), ('5', (5, 5)), | |||||
('6', (4, 4)), ('7', (6, 6))]) | |||||
self.assertSetEqual(expect_bmes_res,set(bmes_tag_to_spans(bmes_lst))) | |||||
self.assertSetEqual(expect_bio_res, set(bio_tag_to_spans(bio_lst))) | |||||
# 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试 | |||||
# from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans | |||||
# from allennlp.data.dataset_readers.dataset_utils import bmes_tags_to_spans as allen_bmes_tags_to_spans | |||||
# for i in range(1000): | |||||
# strs = list(map(str, np.random.randint(100, size=1000))) | |||||
# bmes = list('bmes'.upper()) | |||||
# bmes_strs = [str_ + '-' + tag for tag, str_ in zip(strs, np.random.choice(bmes, size=len(strs)))] | |||||
# bio = list('bio'.upper()) | |||||
# bio_strs = [str_ + '-' + tag for tag, str_ in zip(strs, np.random.choice(bio, size=len(strs)))] | |||||
# self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs))) | |||||
# self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs))) | |||||
def test_case2(self): | |||||
# 测试不带label的 | |||||
from fastNLP.core.metrics import bmes_tag_to_spans | |||||
from fastNLP.core.metrics import bio_tag_to_spans | |||||
bmes_lst = ['B', 'E', 'B', 'S', 'B', 'M', 'E', 'M', 'B', 'E'] | |||||
bio_lst = ['I', 'B', 'O', 'O', 'I', 'O', 'I', 'B', 'O', 'O'] | |||||
expect_bmes_res = set() | |||||
expect_bmes_res.update([('', (0, 1)), ('', (2, 2)), ('', (3, 3)), ('', (4, 6)), ('', (7, 7)), ('', (8, 9))]) | |||||
expect_bio_res = set() | |||||
expect_bio_res.update([('', (7, 7)), ('', (6, 6)), ('', (4, 4)), ('', (0, 0)), ('', (1, 1))]) | |||||
self.assertSetEqual(expect_bmes_res,set(bmes_tag_to_spans(bmes_lst))) | |||||
self.assertSetEqual(expect_bio_res, set(bio_tag_to_spans(bio_lst))) | |||||
# 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试 | |||||
# from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans | |||||
# from allennlp.data.dataset_readers.dataset_utils import bmes_tags_to_spans as allen_bmes_tags_to_spans | |||||
# for i in range(1000): | |||||
# bmes = list('bmes'.upper()) | |||||
# bmes_strs = np.random.choice(bmes, size=1000) | |||||
# bio = list('bio'.upper()) | |||||
# bio_strs = np.random.choice(bio, size=100) | |||||
# self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs))) | |||||
# self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs))) | |||||
def test_case3(self): | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from collections import Counter | |||||
from fastNLP.core.metrics import SpanFPreRecMetric | |||||
# 与allennlp测试能否正确计算f metric | |||||
# | |||||
def generate_allen_tags(encoding_type, number_labels=4): | |||||
vocab = {} | |||||
for i in range(number_labels): | |||||
label = str(i) | |||||
for tag in encoding_type: | |||||
if tag == 'O': | |||||
if tag not in vocab: | |||||
vocab['O'] = len(vocab) + 1 | |||||
continue | |||||
vocab['{}-{}'.format(tag, label)] = len(vocab) + 1  # 这里的数值随后会被当作该tag的count使用 | |||||
return vocab | |||||
number_labels = 4 | |||||
# bio tag | |||||
fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) | |||||
fastnlp_bio_vocab.word_count = Counter(generate_allen_tags('BIO', number_labels)) | |||||
fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) | |||||
bio_sequence = torch.FloatTensor( | |||||
[[[-0.9543, -1.4357, -0.2365, 0.2438, 1.0312, -1.4302, 0.3011, | |||||
0.0470, 0.0971], | |||||
[-0.6638, -0.7116, -1.9804, 0.2787, -0.2732, -0.9501, -1.4523, | |||||
0.7987, -0.3970], | |||||
[0.2939, 0.8132, -0.0903, -2.8296, 0.2080, -0.9823, -0.1898, | |||||
0.6880, 1.4348], | |||||
[-0.1886, 0.0067, -0.6862, -0.4635, 2.2776, 0.0710, -1.6793, | |||||
-1.6876, -0.8917], | |||||
[-0.7663, 0.6377, 0.8669, 0.1237, 1.7628, 0.0313, -1.0824, | |||||
1.4217, 0.2622]], | |||||
[[0.1529, 0.7474, -0.9037, 1.5287, 0.2771, 0.2223, 0.8136, | |||||
1.3592, -0.8973], | |||||
[0.4515, -0.5235, 0.3265, -1.1947, 0.8308, 1.8754, -0.4887, | |||||
-0.4025, -0.3417], | |||||
[-0.7855, 0.1615, -0.1272, -1.9289, -0.5181, 1.9742, -0.9698, | |||||
0.2861, -0.3966], | |||||
[-0.8291, -0.8823, -1.1496, 0.2164, 1.3390, -0.3964, -0.5275, | |||||
0.0213, 1.4777], | |||||
[-1.1299, 0.0627, -0.1358, -1.5951, 0.4484, -0.6081, -1.9566, | |||||
1.3024, 0.2001]]] | |||||
) | |||||
bio_target = torch.LongTensor([[5., 0., 3., 3., 3.], | |||||
[5., 6., 8., 6., 0.]]) | |||||
fastnlp_bio_metric({'pred': bio_sequence, 'seq_lens': torch.LongTensor([5, 5])}, {'target': bio_target}) | |||||
expect_bio_res = {'pre-1': 0.24999999999999373, 'rec-1': 0.499999999999975, 'f-1': 0.33333333333327775, | |||||
'pre-2': 0.0, 'rec-2': 0.0, 'f-2': 0.0, 'pre-3': 0.0, 'rec-3': 0.0, 'f-3': 0.0, 'pre-0': 0.0, | |||||
'rec-0': 0.0, 'f-0': 0.0, 'pre': 0.12499999999999845, 'rec': 0.12499999999999845, | |||||
'f': 0.12499999999994846} | |||||
self.assertDictEqual(expect_bio_res, fastnlp_bio_metric.get_metric()) | |||||
#bmes tag | |||||
bmes_sequence = torch.FloatTensor( | |||||
[[[0.6536, -0.7179, 0.6579, 1.2503, 0.4176, 0.6696, 0.2352, | |||||
-0.4085, 0.4084, -0.4185, 1.4172, -0.9162, -0.2679, 0.3332, | |||||
-0.3505, -0.6002], | |||||
[0.3238, -1.2378, -1.3304, -0.4903, 1.4518, -0.1868, -0.7641, | |||||
1.6199, -0.8877, 0.1449, 0.8995, -0.5810, 0.1041, 0.1002, | |||||
0.4439, 0.2514], | |||||
[-0.8362, 2.9526, 0.8008, 0.1193, 1.0488, 0.6670, 1.1696, | |||||
-1.1006, -0.8540, -0.1600, -0.9519, -0.2749, -0.4948, -1.4753, | |||||
0.5802, -0.0516], | |||||
[-0.8383, -1.7292, -1.4079, -1.5023, 0.5383, 0.6653, 0.3121, | |||||
4.1249, -0.4173, -0.2043, 1.7755, 1.1110, -1.7069, -0.0390, | |||||
-0.9242, -0.0333], | |||||
[0.9088, -0.4955, -0.5076, 0.3732, 0.0283, -0.0263, -1.0393, | |||||
0.7734, 1.0968, 0.4132, -1.3647, -0.5762, 0.6678, 0.8809, | |||||
-0.3779, -0.3195]], | |||||
[[-0.4638, -0.5939, -0.1052, -0.5573, 0.4600, -1.3484, 0.1753, | |||||
0.0685, 0.3663, -0.6789, 0.0097, 1.0327, -0.0212, -0.9957, | |||||
-0.1103, 0.4417], | |||||
[-0.2903, 0.9205, -1.5758, -1.0421, 0.2921, -0.2142, -0.3049, | |||||
-0.0879, -0.4412, -1.3195, -0.0657, -0.2986, 0.7214, 0.0631, | |||||
-0.6386, 0.2797], | |||||
[0.6440, -0.3748, 1.2912, -0.0170, 0.7447, 1.4075, -0.4947, | |||||
0.4123, -0.8447, -0.5502, 0.3520, -0.2832, 0.5019, -0.1522, | |||||
1.1237, -1.5385], | |||||
[0.2839, -0.7649, 0.9067, -0.1163, -1.3789, 0.2571, -1.3977, | |||||
-0.3680, -0.8902, -0.6983, -1.1583, 1.2779, 0.2197, 0.1376, | |||||
-0.0591, -0.2461], | |||||
[-0.2977, -1.8564, -0.5347, 1.0011, -1.1260, 0.4252, -2.0097, | |||||
2.6973, -0.8308, -1.4939, 0.9865, -0.3935, 0.2743, 0.1142, | |||||
-0.7344, -1.2046]]] | |||||
) | |||||
bmes_target = torch.LongTensor([[ 9., 6., 1., 9., 15.], | |||||
[ 6., 15., 6., 15., 5.]]) | |||||
fastnlp_bmes_vocab = Vocabulary(unknown=None, padding=None) | |||||
fastnlp_bmes_vocab.word_count = Counter(generate_allen_tags('BMES', number_labels)) | |||||
fastnlp_bmes_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bmes_vocab, only_gross=False, encoding_type='bmes') | |||||
fastnlp_bmes_metric({'pred': bmes_sequence, 'seq_lens': torch.LongTensor([20, 20])}, {'target': bmes_target}) | |||||
expect_bmes_res = {'f-3': 0.6666666666665778, 'pre-3': 0.499999999999975, 'rec-3': 0.9999999999999001, | |||||
'f-0': 0.0, 'pre-0': 0.0, 'rec-0': 0.0, 'f-1': 0.33333333333327775, | |||||
'pre-1': 0.24999999999999373, 'rec-1': 0.499999999999975, 'f-2': 0.7499999999999314, | |||||
'pre-2': 0.7499999999999812, 'rec-2': 0.7499999999999812, 'f': 0.49999999999994504, | |||||
'pre': 0.499999999999995, 'rec': 0.499999999999995} | |||||
self.assertDictEqual(fastnlp_bmes_metric.get_metric(), expect_bmes_res) | |||||
# 已经和allennlp做过验证,但由于不能依赖allennlp,所以注释了以下代码 | |||||
# from allennlp.data.vocabulary import Vocabulary as allen_Vocabulary | |||||
# from allennlp.training.metrics import SpanBasedF1Measure | |||||
# allen_bio_vocab = allen_Vocabulary({"tags": generate_allen_tags('BIO', number_labels)}, | |||||
# non_padded_namespaces=['tags']) | |||||
# allen_bio_metric = SpanBasedF1Measure(allen_bio_vocab, 'tags') | |||||
# bio_sequence = torch.randn(size=(2, 20, 2 * number_labels + 1)) | |||||
# bio_target = torch.randint(2 * number_labels + 1, size=(2, 20)) | |||||
# allen_bio_metric(bio_sequence, bio_target, torch.ones(2, 20)) | |||||
# fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) | |||||
# fastnlp_bio_vocab.word_count = Counter(generate_allen_tags('BIO', number_labels)) | |||||
# fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) | |||||
# | |||||
# def convert_allen_res_to_fastnlp_res(metric_result): | |||||
# allen_result = {} | |||||
# key_map = {'f1-measure-overall': "f", "recall-overall": "rec", "precision-overall": "pre"} | |||||
# for key, value in metric_result.items(): | |||||
# if key in key_map: | |||||
# key = key_map[key] | |||||
# else: | |||||
# label = key.split('-')[-1] | |||||
# if key.startswith('f1'): | |||||
# key = 'f-{}'.format(label) | |||||
# else: | |||||
# key = '{}-{}'.format(key[:3], label) | |||||
# allen_result[key] = value | |||||
# return allen_result | |||||
# | |||||
# # print(convert_allen_res_to_fastnlp_res(allen_bio_metric.get_metric())) | |||||
# # print(fastnlp_bio_metric.get_metric()) | |||||
# self.assertDictEqual(convert_allen_res_to_fastnlp_res(allen_bio_metric.get_metric()), | |||||
# fastnlp_bio_metric.get_metric()) | |||||
# | |||||
# allen_bmes_vocab = allen_Vocabulary({"tags": generate_allen_tags('BMES', number_labels)}) | |||||
# allen_bmes_metric = SpanBasedF1Measure(allen_bmes_vocab, 'tags', label_encoding='BMES') | |||||
# fastnlp_bmes_vocab = Vocabulary(unknown=None, padding=None) | |||||
# fastnlp_bmes_vocab.word_count = Counter(generate_allen_tags('BMES', number_labels)) | |||||
# fastnlp_bmes_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bmes_vocab, only_gross=False, encoding_type='bmes') | |||||
# bmes_sequence = torch.randn(size=(2, 20, 4 * number_labels)) | |||||
# bmes_target = torch.randint(4 * number_labels, size=(2, 20)) | |||||
# allen_bmes_metric(bmes_sequence, bmes_target, torch.ones(2, 20)) | |||||
# fastnlp_bmes_metric({'pred': bmes_sequence, 'seq_lens': torch.LongTensor([20, 20])}, {'target': bmes_target}) | |||||
# | |||||
# # print(convert_allen_res_to_fastnlp_res(allen_bmes_metric.get_metric())) | |||||
# # print(fastnlp_bmes_metric.get_metric()) | |||||
# self.assertDictEqual(convert_allen_res_to_fastnlp_res(allen_bmes_metric.get_metric()), | |||||
# fastnlp_bmes_metric.get_metric()) | |||||
class TestBMESF1PreRecMetric(unittest.TestCase): | |||||
def test_case1(self): | |||||
seq_lens = torch.LongTensor([4, 2]) | |||||
pred = torch.randn(2, 4, 4) | |||||
target = torch.LongTensor([[0, 1, 2, 3], | |||||
[3, 3, 0, 0]]) | |||||
pred_dict = {'pred': pred} | |||||
target_dict = {'target': target, 'seq_lens': seq_lens} | |||||
metric = BMESF1PreRecMetric() | |||||
metric(pred_dict, target_dict) | |||||
metric.get_metric() | |||||
def test_case2(self): | |||||
# 测试相同两个seqence,应该给出{f1: 1, precision:1, recall:1} | |||||
seq_lens = torch.LongTensor([4, 2]) | |||||
target = torch.LongTensor([[0, 1, 2, 3], | |||||
[3, 3, 0, 0]]) | |||||
pred_dict = {'pred': target} | |||||
target_dict = {'target': target, 'seq_lens': seq_lens} | |||||
metric = BMESF1PreRecMetric() | |||||
metric(pred_dict, target_dict) | |||||
self.assertDictEqual(metric.get_metric(), {'f': 1.0, 'pre': 1.0, 'rec': 1.0}) | |||||
class TestUsefulFunctions(unittest.TestCase): | class TestUsefulFunctions(unittest.TestCase): | ||||
# 测试metrics.py中一些看上去挺有用的函数 | # 测试metrics.py中一些看上去挺有用的函数 | ||||
@@ -10,6 +10,8 @@ data_file = """ | |||||
4 will _ AUX MD _ 6 aux _ _ | 4 will _ AUX MD _ 6 aux _ _ | ||||
5 be _ VERB VB _ 6 cop _ _ | 5 be _ VERB VB _ 6 cop _ _ | ||||
6 payable _ ADJ JJ _ 0 root _ _ | 6 payable _ ADJ JJ _ 0 root _ _ | ||||
7 mask _ ADJ JJ _ 6 punct _ _ | |||||
8 mask _ ADJ JJ _ 6 punct _ _ | |||||
9 cents _ NOUN NNS _ 4 nmod _ _ | 9 cents _ NOUN NNS _ 4 nmod _ _ | ||||
10 from _ ADP IN _ 12 case _ _ | 10 from _ ADP IN _ 12 case _ _ | ||||
11 seven _ NUM CD _ 12 nummod _ _ | 11 seven _ NUM CD _ 12 nummod _ _ | ||||
@@ -58,13 +60,13 @@ def init_data(): | |||||
data.append(line) | data.append(line) | ||||
for name in ['word_seq', 'pos_seq', 'label_true']: | for name in ['word_seq', 'pos_seq', 'label_true']: | ||||
ds.apply(lambda x: ['<st>']+list(x[name])+['<ed>'], new_field_name=name) | |||||
ds.apply(lambda x: ['<st>']+list(x[name]), new_field_name=name) | |||||
ds.apply(lambda x: v[name].add_word_lst(x[name])) | ds.apply(lambda x: v[name].add_word_lst(x[name])) | ||||
for name in ['word_seq', 'pos_seq', 'label_true']: | for name in ['word_seq', 'pos_seq', 'label_true']: | ||||
ds.apply(lambda x: [v[name].to_index(w) for w in x[name]], new_field_name=name) | ds.apply(lambda x: [v[name].to_index(w) for w in x[name]], new_field_name=name) | ||||
ds.apply(lambda x: [0]+list(map(int, x['arc_true']))+[1], new_field_name='arc_true') | |||||
ds.apply(lambda x: [0]+list(map(int, x['arc_true'])), new_field_name='arc_true') | |||||
ds.apply(lambda x: len(x['word_seq']), new_field_name='seq_lens') | ds.apply(lambda x: len(x['word_seq']), new_field_name='seq_lens') | ||||
ds.set_input('word_seq', 'pos_seq', 'seq_lens', flag=True) | ds.set_input('word_seq', 'pos_seq', 'seq_lens', flag=True) | ||||
ds.set_target('arc_true', 'label_true', 'seq_lens', flag=True) | ds.set_target('arc_true', 'label_true', 'seq_lens', flag=True) | ||||
@@ -75,8 +77,11 @@ class TestBiaffineParser(unittest.TestCase): | |||||
ds, v1, v2, v3 = init_data() | ds, v1, v2, v3 = init_data() | ||||
model = BiaffineParser(word_vocab_size=len(v1), word_emb_dim=30, | model = BiaffineParser(word_vocab_size=len(v1), word_emb_dim=30, | ||||
pos_vocab_size=len(v2), pos_emb_dim=30, | pos_vocab_size=len(v2), pos_emb_dim=30, | ||||
num_label=len(v3)) | |||||
num_label=len(v3), use_var_lstm=True) | |||||
trainer = fastNLP.Trainer(model=model, train_data=ds, dev_data=ds, | trainer = fastNLP.Trainer(model=model, train_data=ds, dev_data=ds, | ||||
loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS', | loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS', | ||||
n_epochs=10, use_cuda=False, use_tqdm=False) | n_epochs=10, use_cuda=False, use_tqdm=False) | ||||
trainer.train(load_best_model=False) | trainer.train(load_best_model=False) | ||||
if __name__ == '__main__': | |||||
unittest.main() |
@@ -0,0 +1,104 @@ | |||||
import unittest | |||||
class TestCRF(unittest.TestCase): | |||||
def test_case1(self): | |||||
# 检查allowed_transitions()能否正确使用 | |||||
from fastNLP.modules.decoder.CRF import allowed_transitions | |||||
id2label = {0: 'B', 1: 'I', 2:'O'} | |||||
expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), | |||||
(2, 4), (3, 0), (3, 2)} | |||||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label))) | |||||
id2label = {0: 'B', 1:'M', 2:'E', 3:'S'} | |||||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} | |||||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES'))) | |||||
id2label = {0: 'B', 1: 'I', 2:'O', 3: '<pad>', 4:"<unk>"} | |||||
allowed_transitions(id2label) | |||||
labels = ['O'] | |||||
for label in ['X', 'Y']: | |||||
for tag in 'BI': | |||||
labels.append('{}-{}'.format(tag, label)) | |||||
id2label = {idx:label for idx, label in enumerate(labels)} | |||||
expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1), | |||||
(2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3), | |||||
(4, 4), (4, 6), (5, 0), (5, 1), (5, 3)} | |||||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label))) | |||||
labels = [] | |||||
for label in ['X', 'Y']: | |||||
for tag in 'BMES': | |||||
labels.append('{}-{}'.format(tag, label)) | |||||
id2label = {idx: label for idx, label in enumerate(labels)} | |||||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), | |||||
(3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), | |||||
(7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} | |||||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES'))) | |||||
def test_case2(self): | |||||
# 测试CRF能否避免解码出非法跃迁, 使用allennlp做了验证。 | |||||
pass | |||||
# import torch | |||||
# from fastNLP.modules.decoder.CRF import seq_len_to_byte_mask | |||||
# | |||||
# labels = ['O'] | |||||
# for label in ['X', 'Y']: | |||||
# for tag in 'BI': | |||||
# labels.append('{}-{}'.format(tag, label)) | |||||
# id2label = {idx: label for idx, label in enumerate(labels)} | |||||
# num_tags = len(id2label) | |||||
# | |||||
# from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions | |||||
# allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BIO', id2label), | |||||
# include_start_end_transitions=False) | |||||
# batch_size = 3 | |||||
# logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, 20, num_tags))).log() | |||||
# trans_m = allen_CRF.transitions | |||||
# seq_lens = torch.randint(1, 20, size=(batch_size,)) | |||||
# seq_lens[-1] = 20 | |||||
# mask = seq_len_to_byte_mask(seq_lens) | |||||
# allen_res = allen_CRF.viterbi_tags(logits, mask) | |||||
# | |||||
# from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions | |||||
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label)) | |||||
# fast_CRF.trans_m = trans_m | |||||
# fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True) | |||||
# # score equal | |||||
# self.assertListEqual([score for _, score in allen_res], fast_res[1]) | |||||
# # seq equal | |||||
# self.assertListEqual([_ for _, score in allen_res], fast_res[0]) | |||||
# | |||||
# | |||||
# labels = [] | |||||
# for label in ['X', 'Y']: | |||||
# for tag in 'BMES': | |||||
# labels.append('{}-{}'.format(tag, label)) | |||||
# id2label = {idx: label for idx, label in enumerate(labels)} | |||||
# num_tags = len(id2label) | |||||
# | |||||
# from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions | |||||
# allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BMES', id2label), | |||||
# include_start_end_transitions=False) | |||||
# batch_size = 3 | |||||
# logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, 20, num_tags))).log() | |||||
# trans_m = allen_CRF.transitions | |||||
# seq_lens = torch.randint(1, 20, size=(batch_size,)) | |||||
# seq_lens[-1] = 20 | |||||
# mask = seq_len_to_byte_mask(seq_lens) | |||||
# allen_res = allen_CRF.viterbi_tags(logits, mask) | |||||
# | |||||
# from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions | |||||
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, | |||||
# encoding_type='BMES')) | |||||
# fast_CRF.trans_m = trans_m | |||||
# fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True) | |||||
# # score equal | |||||
# self.assertListEqual([score for _, score in allen_res], fast_res[1]) | |||||
# # seq equal | |||||
# self.assertListEqual([_ for _, score in allen_res], fast_res[0]) | |||||
@@ -0,0 +1,12 @@ | |||||
# fastNLP 教程 | |||||
### 上手教程 Quick Start | |||||
- 一分钟上手:`fastnlp_1min_tutorial.ipynb` [Click Here](https://github.com/fastnlp/fastNLP/tree/master/tutorials/fastnlp_1min_tutorial.ipynb) | |||||
- 十分钟上手:`fastnlp_10min_tutorial.ipynb` [Click Here](https://github.com/fastnlp/fastNLP/tree/master/tutorials/fastnlp_10min_tutorial.ipynb) | |||||
### 进阶教程 Advanced Tutorial | |||||
- `fastnlp_advanced_tutorial/advance_tutorial.ipynb` [Click Here](https://github.com/fastnlp/fastNLP/tree/master/tutorials/fastnlp_advanced_tutorial/advance_tutorial.ipynb) | |||||
### 开发者指南 Developer Guide | |||||
- `tutorial_for_developer.md` [Click Here](https://github.com/fastnlp/fastNLP/tree/master/tutorials/tutorial_for_developer.md) |
@@ -4,12 +4,29 @@ | |||||
"cell_type": "markdown", | "cell_type": "markdown", | ||||
"metadata": {}, | "metadata": {}, | ||||
"source": [ | "source": [ | ||||
"fastNLP上手教程\n", | |||||
"fastNLP10 分钟上手教程\n", | |||||
"-------\n", | "-------\n", | ||||
"\n", | "\n", | ||||
"fastNLP提供方便的数据预处理,训练和测试模型的功能" | "fastNLP提供方便的数据预处理,训练和测试模型的功能" | ||||
] | ] | ||||
}, | }, | ||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"如果您还没有通过pip安装fastNLP,可以执行下面的操作加载当前模块" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 4, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"import sys\n", | |||||
"sys.path.append(\"../\")" | |||||
] | |||||
}, | |||||
{ | { | ||||
"cell_type": "markdown", | "cell_type": "markdown", | ||||
"metadata": {}, | "metadata": {}, | ||||
@@ -24,21 +41,14 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 9, | |||||
"execution_count": 6, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [ | "outputs": [ | ||||
{ | { | ||||
"name": "stdout", | "name": "stdout", | ||||
"output_type": "stream", | "output_type": "stream", | ||||
"text": [ | "text": [ | ||||
"8529" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
"77\n" | |||||
] | ] | ||||
} | } | ||||
], | ], | ||||
@@ -47,27 +57,23 @@ | |||||
"from fastNLP import Instance\n", | "from fastNLP import Instance\n", | ||||
"\n", | "\n", | ||||
"# 从csv读取数据到DataSet\n", | "# 从csv读取数据到DataSet\n", | ||||
"dataset = DataSet.read_csv('../sentence.csv', headers=('raw_sentence', 'label'), sep='\\t')\n", | |||||
"dataset = DataSet.read_csv('sample_data/tutorial_sample_dataset.csv', headers=('raw_sentence', 'label'), sep='\\t')\n", | |||||
"print(len(dataset))" | "print(len(dataset))" | ||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 10, | |||||
"execution_count": 7, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [ | "outputs": [ | ||||
{ | { | ||||
"name": "stdout", | "name": "stdout", | ||||
"output_type": "stream", | "output_type": "stream", | ||||
"text": [ | "text": [ | ||||
"{'raw_sentence': A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1}" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
"{'raw_sentence': A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story . type=str,\n", | |||||
"'label': 1 type=str}\n", | |||||
"{'raw_sentence': The plot is romantic comedy boilerplate from start to finish . type=str,\n", | |||||
"'label': 2 type=str}\n" | |||||
] | ] | ||||
} | } | ||||
], | ], | ||||
@@ -91,16 +97,17 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 11, | |||||
"execution_count": 8, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [ | "outputs": [ | ||||
{ | { | ||||
"data": { | "data": { | ||||
"text/plain": [ | "text/plain": [ | ||||
"{'raw_sentence': fake data,\n'label': 0}" | |||||
"{'raw_sentence': fake data type=str,\n", | |||||
"'label': 0 type=str}" | |||||
] | ] | ||||
}, | }, | ||||
"execution_count": 11, | |||||
"execution_count": 8, | |||||
"metadata": {}, | "metadata": {}, | ||||
"output_type": "execute_result" | "output_type": "execute_result" | ||||
} | } | ||||
@@ -121,21 +128,15 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 12, | |||||
"execution_count": 9, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [ | "outputs": [ | ||||
{ | { | ||||
"name": "stdout", | "name": "stdout", | ||||
"output_type": "stream", | "output_type": "stream", | ||||
"text": [ | "text": [ | ||||
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1}" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story . type=str,\n", | |||||
"'label': 1 type=str}\n" | |||||
] | ] | ||||
} | } | ||||
], | ], | ||||
@@ -147,21 +148,15 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 13, | |||||
"execution_count": 10, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [ | "outputs": [ | ||||
{ | { | ||||
"name": "stdout", | "name": "stdout", | ||||
"output_type": "stream", | "output_type": "stream", | ||||
"text": [ | "text": [ | ||||
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1}" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story . type=str,\n", | |||||
"'label': 1 type=int}\n" | |||||
] | ] | ||||
} | } | ||||
], | ], | ||||
@@ -173,21 +168,16 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 14, | |||||
"execution_count": 11, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [ | "outputs": [ | ||||
{ | { | ||||
"name": "stdout", | "name": "stdout", | ||||
"output_type": "stream", | "output_type": "stream", | ||||
"text": [ | "text": [ | ||||
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1,\n'words': ['a', 'series', 'of', 'escapades', 'demonstrating', 'the', 'adage', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gander', ',', 'some', 'of', 'which', 'occasionally', 'amuses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.']}" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story . type=str,\n", | |||||
"'label': 1 type=int,\n", | |||||
"'words': ['a', 'series', 'of', 'escapades', 'demonstrating', 'the', 'adage', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gander', ',', 'some', 'of', 'which', 'occasionally', 'amuses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.'] type=list}\n" | |||||
] | ] | ||||
} | } | ||||
], | ], | ||||
@@ -201,21 +191,17 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 15, | |||||
"execution_count": 12, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [ | "outputs": [ | ||||
{ | { | ||||
"name": "stdout", | "name": "stdout", | ||||
"output_type": "stream", | "output_type": "stream", | ||||
"text": [ | "text": [ | ||||
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1,\n'words': ['a', 'series', 'of', 'escapades', 'demonstrating', 'the', 'adage', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gander', ',', 'some', 'of', 'which', 'occasionally', 'amuses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.'],\n'seq_len': 37}" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story . type=str,\n", | |||||
"'label': 1 type=int,\n", | |||||
"'words': ['a', 'series', 'of', 'escapades', 'demonstrating', 'the', 'adage', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gander', ',', 'some', 'of', 'which', 'occasionally', 'amuses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.'] type=list,\n", | |||||
"'seq_len': 37 type=int}\n" | |||||
] | ] | ||||
} | } | ||||
], | ], | ||||
@@ -235,25 +221,19 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 16, | |||||
"execution_count": 13, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [ | "outputs": [ | ||||
{ | { | ||||
"name": "stdout", | "name": "stdout", | ||||
"output_type": "stream", | "output_type": "stream", | ||||
"text": [ | "text": [ | ||||
"8358" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
"77\n" | |||||
] | ] | ||||
} | } | ||||
], | ], | ||||
"source": [ | "source": [ | ||||
"# 删除低于某个长度的词语\n", | |||||
"dataset.drop(lambda x: x['seq_len'] <= 3)\n", | "dataset.drop(lambda x: x['seq_len'] <= 3)\n", | ||||
"print(len(dataset))" | "print(len(dataset))" | ||||
] | ] | ||||
@@ -269,7 +249,7 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 17, | |||||
"execution_count": 14, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
@@ -283,35 +263,15 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 18, | |||||
"execution_count": 15, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [ | "outputs": [ | ||||
{ | { | ||||
"name": "stdout", | "name": "stdout", | ||||
"output_type": "stream", | "output_type": "stream", | ||||
"text": [ | "text": [ | ||||
"5851" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"2507" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
"54\n", | |||||
"23\n" | |||||
] | ] | ||||
} | } | ||||
], | ], | ||||
@@ -335,21 +295,17 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 19, | |||||
"execution_count": 16, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [ | "outputs": [ | ||||
{ | { | ||||
"name": "stdout", | "name": "stdout", | ||||
"output_type": "stream", | "output_type": "stream", | ||||
"text": [ | "text": [ | ||||
"{'raw_sentence': the project 's filmmakers forgot to include anything even halfway scary as they poorly rejigger fatal attraction into a high school setting .,\n'label': 0,\n'words': [4, 423, 9, 316, 1, 8, 1, 312, 72, 1478, 885, 14, 86, 725, 1, 1913, 1431, 53, 5, 455, 736, 1, 2],\n'seq_len': 23}" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
"{'raw_sentence': a welcome relief from baseball movies that try too hard to be mythic , this one is a sweet and modest and ultimately winning story . type=str,\n", | |||||
"'label': 3 type=int,\n", | |||||
"'words': [4, 1, 1, 18, 1, 1, 13, 1, 1, 1, 8, 26, 1, 5, 35, 1, 11, 4, 1, 10, 1, 10, 1, 1, 1, 2] type=list,\n", | |||||
"'seq_len': 26 type=int}\n" | |||||
] | ] | ||||
} | } | ||||
], | ], | ||||
@@ -369,6 +325,23 @@ | |||||
"print(test_data[0])" | "print(test_data[0])" | ||||
] | ] | ||||
}, | }, | ||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"# 如果你们需要做强化学习或者GAN之类的项目,你们也可以使用这些数据预处理的工具\n", | |||||
"from fastNLP.core.batch import Batch\n", | |||||
"from fastNLP.core.sampler import RandomSampler\n", | |||||
"\n", | |||||
"batch_iterator = Batch(dataset=train_data, batch_size=2, sampler=RandomSampler())\n", | |||||
"for batch_x, batch_y in batch_iterator:\n", | |||||
" print(\"batch_x has: \", batch_x)\n", | |||||
" print(\"batch_y has: \", batch_y)\n", | |||||
" break" | |||||
] | |||||
}, | |||||
{ | { | ||||
"cell_type": "markdown", | "cell_type": "markdown", | ||||
"metadata": {}, | "metadata": {}, | ||||
@@ -379,16 +352,32 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 20, | |||||
"execution_count": 17, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [ | "outputs": [ | ||||
{ | { | ||||
"data": { | "data": { | ||||
"text/plain": [ | "text/plain": [ | ||||
"CNNText(\n (embed): Embedding(\n (embed): Embedding(3459, 50, padding_idx=0)\n (dropout): Dropout(p=0.0)\n )\n (conv_pool): ConvMaxpool(\n (convs): ModuleList(\n (0): Conv1d(50, 3, kernel_size=(3,), stride=(1,), padding=(2,))\n (1): Conv1d(50, 4, kernel_size=(4,), stride=(1,), padding=(2,))\n (2): Conv1d(50, 5, kernel_size=(5,), stride=(1,), padding=(2,))\n )\n )\n (dropout): Dropout(p=0.1)\n (fc): Linear(\n (linear): Linear(in_features=12, out_features=5, bias=True)\n )\n)" | |||||
"CNNText(\n", | |||||
" (embed): Embedding(\n", | |||||
" (embed): Embedding(59, 50, padding_idx=0)\n", | |||||
" (dropout): Dropout(p=0.0)\n", | |||||
" )\n", | |||||
" (conv_pool): ConvMaxpool(\n", | |||||
" (convs): ModuleList(\n", | |||||
" (0): Conv1d(50, 3, kernel_size=(3,), stride=(1,), padding=(2,))\n", | |||||
" (1): Conv1d(50, 4, kernel_size=(4,), stride=(1,), padding=(2,))\n", | |||||
" (2): Conv1d(50, 5, kernel_size=(5,), stride=(1,), padding=(2,))\n", | |||||
" )\n", | |||||
" )\n", | |||||
" (dropout): Dropout(p=0.1)\n", | |||||
" (fc): Linear(\n", | |||||
" (linear): Linear(in_features=12, out_features=5, bias=True)\n", | |||||
" )\n", | |||||
")" | |||||
] | ] | ||||
}, | }, | ||||
"execution_count": 20, | |||||
"execution_count": 17, | |||||
"metadata": {}, | "metadata": {}, | ||||
"output_type": "execute_result" | "output_type": "execute_result" | ||||
} | } | ||||
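
The module printout above (an Embedding feeding three Conv1d branches and a 12-to-5 linear layer) matches a CNNText built roughly as in the earlier revision of this tutorial; the exact keyword names may differ between fastNLP versions.

```python
# Constructor call that produces a CNNText like the one printed above;
# kwargs follow the older tutorial in this repo and may vary by version.
from fastNLP.models import CNNText

model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5,
                padding=2, dropout=0.1)
model
```
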
@@ -459,7 +448,7 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 21, | |||||
"execution_count": 18, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
@@ -496,7 +485,7 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 22, | |||||
"execution_count": 19, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
@@ -519,7 +508,7 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 23, | |||||
"execution_count": 20, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
@@ -528,149 +517,61 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 24, | |||||
"execution_count": 21, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [ | "outputs": [ | ||||
{ | { | ||||
"name": "stdout", | "name": "stdout", | ||||
"output_type": "stream", | "output_type": "stream", | ||||
"text": [ | "text": [ | ||||
"training epochs started 2018-12-07 14:11:31" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
"input fields after batch(if batch size is 2):\n", | |||||
"\tword_seq: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) \n", | |||||
"target fields after batch(if batch size is 2):\n", | |||||
"\tlabel_seq: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", | |||||
"\n", | |||||
"training epochs started 2019-01-12 17-07-51\n" | |||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
"data": { | "data": { | ||||
"text/plain": [ | "text/plain": [ | ||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=915), HTML(value='')), layout=Layout(display=…" | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=10), HTML(value='')), layout=Layout(display='…" | |||||
] | ] | ||||
}, | }, | ||||
"execution_count": 0, | |||||
"metadata": {}, | "metadata": {}, | ||||
"output_type": "execute_result" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 1/5. Step:183/915. AccuracyMetric: acc=0.350367" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
"output_type": "display_data" | |||||
}, | }, | ||||
{ | { | ||||
"name": "stdout", | "name": "stdout", | ||||
"output_type": "stream", | "output_type": "stream", | ||||
"text": [ | "text": [ | ||||
"\r" | |||||
"Evaluation at Epoch 1/5. Step:2/10. AccuracyMetric: acc=0.425926\n", | |||||
"Evaluation at Epoch 2/5. Step:4/10. AccuracyMetric: acc=0.425926\n", | |||||
"Evaluation at Epoch 3/5. Step:6/10. AccuracyMetric: acc=0.611111\n", | |||||
"Evaluation at Epoch 4/5. Step:8/10. AccuracyMetric: acc=0.648148\n", | |||||
"Evaluation at Epoch 5/5. Step:10/10. AccuracyMetric: acc=0.703704\n", | |||||
"\n", | |||||
"In Epoch:5/Step:10, got best dev performance:AccuracyMetric: acc=0.703704\n", | |||||
"Reloaded the best model.\n" | |||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 2/5. Step:366/915. AccuracyMetric: acc=0.409332" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 3/5. Step:549/915. AccuracyMetric: acc=0.572552" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 4/5. Step:732/915. AccuracyMetric: acc=0.711331" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 5/5. Step:915/915. AccuracyMetric: acc=0.801572" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
"data": { | |||||
"text/plain": [ | |||||
"{'best_eval': {'AccuracyMetric': {'acc': 0.703704}},\n", | |||||
" 'best_epoch': 5,\n", | |||||
" 'best_step': 10,\n", | |||||
" 'seconds': 0.62}" | |||||
] | |||||
}, | |||||
"execution_count": 21, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | } | ||||
], | ], | ||||
"source": [ | "source": [ | ||||
"# 实例化Trainer,传入模型和数据,进行训练\n", | "# 实例化Trainer,传入模型和数据,进行训练\n", | ||||
"# 先在test_data拟合\n", | |||||
"# 先在test_data拟合(确保模型的实现是正确的)\n", | |||||
"copy_model = deepcopy(model)\n", | "copy_model = deepcopy(model)\n", | ||||
"overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data,\n", | "overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data,\n", | ||||
" loss=loss,\n", | " loss=loss,\n", | ||||
@@ -683,143 +584,43 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 25, | |||||
"execution_count": 22, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [ | "outputs": [ | ||||
{ | { | ||||
"name": "stdout", | "name": "stdout", | ||||
"output_type": "stream", | "output_type": "stream", | ||||
"text": [ | "text": [ | ||||
"training epochs started 2018-12-07 14:12:21" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
"input fields after batch(if batch size is 2):\n", | |||||
"\tword_seq: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 20]) \n", | |||||
"target fields after batch(if batch size is 2):\n", | |||||
"\tlabel_seq: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", | |||||
"\n", | |||||
"training epochs started 2019-01-12 17-09-05\n" | |||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
"data": { | "data": { | ||||
"text/plain": [ | "text/plain": [ | ||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=395), HTML(value='')), layout=Layout(display=…" | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=5), HTML(value='')), layout=Layout(display='i…" | |||||
] | ] | ||||
}, | }, | ||||
"execution_count": 0, | |||||
"metadata": {}, | "metadata": {}, | ||||
"output_type": "execute_result" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 1/5. Step:79/395. AccuracyMetric: acc=0.250043" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 2/5. Step:158/395. AccuracyMetric: acc=0.280807" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 3/5. Step:237/395. AccuracyMetric: acc=0.280978" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 4/5. Step:316/395. AccuracyMetric: acc=0.285592" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 5/5. Step:395/395. AccuracyMetric: acc=0.278927" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
"output_type": "display_data" | |||||
}, | }, | ||||
{ | { | ||||
"name": "stdout", | "name": "stdout", | ||||
"output_type": "stream", | "output_type": "stream", | ||||
"text": [ | "text": [ | ||||
"\r" | |||||
"Evaluation at Epoch 1/5. Step:1/5. AccuracyMetric: acc=0.37037\n", | |||||
"Evaluation at Epoch 2/5. Step:2/5. AccuracyMetric: acc=0.37037\n", | |||||
"Evaluation at Epoch 3/5. Step:3/5. AccuracyMetric: acc=0.462963\n", | |||||
"Evaluation at Epoch 4/5. Step:4/5. AccuracyMetric: acc=0.425926\n", | |||||
"Evaluation at Epoch 5/5. Step:5/5. AccuracyMetric: acc=0.481481\n", | |||||
"\n", | |||||
"In Epoch:5/Step:5, got best dev performance:AccuracyMetric: acc=0.481481\n", | |||||
"Reloaded the best model.\n", | |||||
"Train finished!\n" | |||||
] | ] | ||||
} | } | ||||
], | ], | ||||
@@ -837,35 +638,16 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 26, | |||||
"execution_count": 23, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [ | "outputs": [ | ||||
{ | { | ||||
"name": "stdout", | "name": "stdout", | ||||
"output_type": "stream", | "output_type": "stream", | ||||
"text": [ | "text": [ | ||||
"[tester] \nAccuracyMetric: acc=0.280636" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"{'AccuracyMetric': {'acc': 0.280636}}" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
"[tester] \n", | |||||
"AccuracyMetric: acc=0.481481\n", | |||||
"{'AccuracyMetric': {'acc': 0.481481}}\n" | |||||
] | ] | ||||
} | } | ||||
], | ], | ||||
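
For reference, the accuracy dict printed above is the return value of a Tester run; the call shape below mirrors the Tester cell from the earlier revision of this tutorial.

```python
# Evaluate the trained model on test_data and collect the metric dict.
from fastNLP import Tester, AccuracyMetric

tester = Tester(data=test_data, model=model, metrics=AccuracyMetric())
acc = tester.test()
print(acc)   # -> {'AccuracyMetric': {'acc': 0.481481}}
```
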
@@ -879,6 +661,75 @@ | |||||
"print(acc)" | "print(acc)" | ||||
] | ] | ||||
}, | }, | ||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"# In summary\n", | |||||
"\n", | |||||
"## fastNLP Trainer的伪代码逻辑\n", | |||||
"### 1. 准备DataSet,假设DataSet中共有如下的fields\n", | |||||
" ['raw_sentence', 'word_seq1', 'word_seq2', 'raw_label','label']\n", | |||||
" 通过\n", | |||||
" DataSet.set_input('word_seq1', word_seq2', flag=True)将'word_seq1', 'word_seq2'设置为input\n", | |||||
" 通过\n", | |||||
" DataSet.set_target('label', flag=True)将'label'设置为target\n", | |||||
"### 2. 初始化模型\n", | |||||
" class Model(nn.Module):\n", | |||||
" def __init__(self):\n", | |||||
" xxx\n", | |||||
" def forward(self, word_seq1, word_seq2):\n", | |||||
" # (1) 这里使用的形参名必须和DataSet中的input field的名称对应。因为我们是通过形参名, 进行赋值的\n", | |||||
" # (2) input field的数量可以多于这里的形参数量。但是不能少于。\n", | |||||
" xxxx\n", | |||||
" # 输出必须是一个dict\n", | |||||
"### 3. Trainer的训练过程\n", | |||||
" (1) 从DataSet中按照batch_size取出一个batch,调用Model.forward\n", | |||||
" (2) 将 Model.forward的结果 与 标记为target的field 传入Losser当中。\n", | |||||
" 由于每个人写的Model.forward的output的dict可能key并不一样,比如有人是{'pred':xxx}, {'output': xxx}; \n", | |||||
" 另外每个人将target可能也会设置为不同的名称, 比如有人是label, 有人设置为target;\n", | |||||
" 为了解决以上的问题,我们的loss提供映射机制\n", | |||||
" 比如CrossEntropyLosser的需要的输入是(prediction, target)。但是forward的output是{'output': xxx}; 'label'是target\n", | |||||
" 那么初始化losser的时候写为CrossEntropyLosser(prediction='output', target='label')即可\n", | |||||
" (3) 对于Metric是同理的\n", | |||||
" Metric计算也是从 forward的结果中取值 与 设置target的field中取值。 也是可以通过映射找到对应的值 \n", | |||||
" \n", | |||||
" \n", | |||||
"\n", | |||||
"## 一些问题.\n", | |||||
"### 1. DataSet中为什么需要设置input和target\n", | |||||
" 只有被设置为input或者target的数据才会在train的过程中被取出来\n", | |||||
" (1.1) 我们只会在设置为input的field中寻找传递给Model.forward的参数。\n", | |||||
" (1.2) 我们在传递值给losser或者metric的时候会使用来自: \n", | |||||
" (a)Model.forward的output\n", | |||||
" (b)被设置为target的field\n", | |||||
" \n", | |||||
"\n", | |||||
"### 2. 我们是通过forwad中的形参名将DataSet中的field赋值给对应的参数\n", | |||||
" (1.1) 构建模型过程中,\n", | |||||
" 例如:\n", | |||||
" DataSet中x,seq_lens是input,那么forward就应该是\n", | |||||
" def forward(self, x, seq_lens):\n", | |||||
" pass\n", | |||||
" 我们是通过形参名称进行匹配的field的\n", | |||||
" \n", | |||||
"\n", | |||||
"\n", | |||||
"### 1. 加载数据到DataSet\n", | |||||
"### 2. 使用apply操作对DataSet进行预处理\n", | |||||
" (2.1) 处理过程中将某些field设置为input,某些field设置为target\n", | |||||
"### 3. 构建模型\n", | |||||
" (3.1) 构建模型过程中,需要注意forward函数的形参名需要和DataSet中设置为input的field名称是一致的。\n", | |||||
" 例如:\n", | |||||
" DataSet中x,seq_lens是input,那么forward就应该是\n", | |||||
" def forward(self, x, seq_lens):\n", | |||||
" pass\n", | |||||
" 我们是通过形参名称进行匹配的field的\n", | |||||
" (3.2) 模型的forward的output需要是dict类型的。\n", | |||||
" 建议将输出设置为{\"pred\": xx}.\n", | |||||
" \n" | |||||
] | |||||
}, | |||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": null, | "execution_count": null, |
@@ -1,860 +0,0 @@ | |||||
{ | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"fastNLP上手教程\n", | |||||
"-------\n", | |||||
"\n", | |||||
"fastNLP提供方便的数据预处理,训练和测试模型的功能" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"DataSet & Instance\n", | |||||
"------\n", | |||||
"\n", | |||||
"fastNLP用DataSet和Instance保存和处理数据。每个DataSet表示一个数据集,每个Instance表示一个数据样本。一个DataSet存有多个Instance,每个Instance可以自定义存哪些内容。\n", | |||||
"\n", | |||||
"有一些read_*方法,可以轻松从文件读取数据,存成DataSet。" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"{'raw_sentence': A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1}" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP import DataSet\n", | |||||
"from fastNLP import Instance\n", | |||||
"\n", | |||||
"# 从csv读取数据到DataSet\n", | |||||
"win_path = \"C:\\\\Users\\zyfeng\\Desktop\\FudanNLP\\\\fastNLP\\\\test\\\\data_for_tests\\\\tutorial_sample_dataset.csv\"\n", | |||||
"dataset = DataSet.read_csv(win_path, headers=('raw_sentence', 'label'), sep='\\t')\n", | |||||
"print(dataset[0])" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 2, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"{'raw_sentence': fake data,\n'label': 0}" | |||||
] | |||||
}, | |||||
"execution_count": 2, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"# DataSet.append(Instance)加入新数据\n", | |||||
"\n", | |||||
"dataset.append(Instance(raw_sentence='fake data', label='0'))\n", | |||||
"dataset[-1]" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 3, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"# DataSet.apply(func, new_field_name)对数据预处理\n", | |||||
"\n", | |||||
"# 将所有数字转为小写\n", | |||||
"dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')\n", | |||||
"# label转int\n", | |||||
"dataset.apply(lambda x: int(x['label']), new_field_name='label_seq', is_target=True)\n", | |||||
"# 使用空格分割句子\n", | |||||
"dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0)\n", | |||||
"def split_sent(ins):\n", | |||||
" return ins['raw_sentence'].split()\n", | |||||
"dataset.apply(split_sent, new_field_name='words', is_input=True)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 4, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"# DataSet.drop(func)筛除数据\n", | |||||
"# 删除低于某个长度的词语\n", | |||||
"dataset.drop(lambda x: len(x['words']) <= 3)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 7, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Train size: " | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
" " | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"54" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Test size: " | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 分出测试集、训练集\n", | |||||
"\n", | |||||
"test_data, train_data = dataset.split(0.3)\n", | |||||
"print(\"Train size: \", len(test_data))\n", | |||||
"print(\"Test size: \", len(train_data))" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"Vocabulary\n", | |||||
"------\n", | |||||
"\n", | |||||
"fastNLP中的Vocabulary轻松构建词表,将词转成数字" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 8, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"{'raw_sentence': the plot is romantic comedy boilerplate from start to finish .,\n'label': 2,\n'label_seq': 2,\n'words': ['the', 'plot', 'is', 'romantic', 'comedy', 'boilerplate', 'from', 'start', 'to', 'finish', '.'],\n'word_seq': [2, 13, 9, 24, 25, 26, 15, 27, 11, 28, 3]}" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP import Vocabulary\n", | |||||
"\n", | |||||
"# 构建词表, Vocabulary.add(word)\n", | |||||
"vocab = Vocabulary(min_freq=2)\n", | |||||
"train_data.apply(lambda x: [vocab.add(word) for word in x['words']])\n", | |||||
"vocab.build_vocab()\n", | |||||
"\n", | |||||
"# index句子, Vocabulary.to_index(word)\n", | |||||
"train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)\n", | |||||
"test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)\n", | |||||
"\n", | |||||
"\n", | |||||
"print(test_data[0])" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 8, | |||||
"metadata": { | |||||
"scrolled": true | |||||
}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"batch_x has: {'words': array([list(['this', 'kind', 'of', 'hands-on', 'storytelling', 'is', 'ultimately', 'what', 'makes', 'shanghai', 'ghetto', 'move', 'beyond', 'a', 'good', ',', 'dry', ',', 'reliable', 'textbook', 'and', 'what', 'allows', 'it', 'to', 'rank', 'with', 'its', 'worthy', 'predecessors', '.']),\n", | |||||
" list(['the', 'entire', 'movie', 'is', 'filled', 'with', 'deja', 'vu', 'moments', '.'])],\n", | |||||
" dtype=object), 'word_seq': tensor([[ 19, 184, 6, 1, 481, 9, 206, 50, 91, 1210, 1609, 1330,\n", | |||||
" 495, 5, 63, 4, 1269, 4, 1, 1184, 7, 50, 1050, 10,\n", | |||||
" 8, 1611, 16, 21, 1039, 1, 2],\n", | |||||
" [ 3, 711, 22, 9, 1282, 16, 2482, 2483, 200, 2, 0, 0,\n", | |||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||||
" 0, 0, 0, 0, 0, 0, 0]])}\n", | |||||
"batch_y has: {'label_seq': tensor([3, 2])}\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 假设你们需要做强化学习或者gan之类的项目,也许你们可以使用这里的dataset\n", | |||||
"from fastNLP.core.batch import Batch\n", | |||||
"from fastNLP.core.sampler import RandomSampler\n", | |||||
"\n", | |||||
"batch_iterator = Batch(dataset=train_data, batch_size=2, sampler=RandomSampler())\n", | |||||
"for batch_x, batch_y in batch_iterator:\n", | |||||
" print(\"batch_x has: \", batch_x)\n", | |||||
" print(\"batch_y has: \", batch_y)\n", | |||||
" break" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"# Model\n" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 9, | |||||
"metadata": { | |||||
"collapsed": false | |||||
}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"CNNText(\n (embed): Embedding(\n (embed): Embedding(77, 50, padding_idx=0)\n (dropout): Dropout(p=0.0)\n )\n (conv_pool): ConvMaxpool(\n (convs): ModuleList(\n (0): Conv1d(50, 3, kernel_size=(3,), stride=(1,), padding=(2,))\n (1): Conv1d(50, 4, kernel_size=(4,), stride=(1,), padding=(2,))\n (2): Conv1d(50, 5, kernel_size=(5,), stride=(1,), padding=(2,))\n )\n )\n (dropout): Dropout(p=0.1)\n (fc): Linear(\n (linear): Linear(in_features=12, out_features=5, bias=True)\n )\n)" | |||||
] | |||||
}, | |||||
"execution_count": 9, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 定义一个简单的Pytorch模型\n", | |||||
"\n", | |||||
"from fastNLP.models import CNNText\n", | |||||
"model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)\n", | |||||
"model" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"Trainer & Tester\n", | |||||
"------\n", | |||||
"\n", | |||||
"使用fastNLP的Trainer训练模型" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 11, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"from fastNLP import Trainer\n", | |||||
"from copy import deepcopy\n", | |||||
"from fastNLP import CrossEntropyLoss\n", | |||||
"from fastNLP import AccuracyMetric" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 12, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"training epochs started 2018-12-07 14:07:20" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=20), HTML(value='')), layout=Layout(display='…" | |||||
] | |||||
}, | |||||
"execution_count": 0, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 1/10. Step:2/20. AccuracyMetric: acc=0.037037" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 2/10. Step:4/20. AccuracyMetric: acc=0.296296" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 3/10. Step:6/20. AccuracyMetric: acc=0.333333" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 4/10. Step:8/20. AccuracyMetric: acc=0.555556" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 5/10. Step:10/20. AccuracyMetric: acc=0.611111" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 6/10. Step:12/20. AccuracyMetric: acc=0.481481" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 7/10. Step:14/20. AccuracyMetric: acc=0.62963" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 8/10. Step:16/20. AccuracyMetric: acc=0.685185" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 9/10. Step:18/20. AccuracyMetric: acc=0.722222" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 10/10. Step:20/20. AccuracyMetric: acc=0.777778" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 进行overfitting测试\n", | |||||
"copy_model = deepcopy(model)\n", | |||||
"overfit_trainer = Trainer(model=copy_model, \n", | |||||
" train_data=test_data, \n", | |||||
" dev_data=test_data,\n", | |||||
" loss=CrossEntropyLoss(pred=\"output\", target=\"label_seq\"),\n", | |||||
" metrics=AccuracyMetric(),\n", | |||||
" n_epochs=10,\n", | |||||
" save_path=None)\n", | |||||
"overfit_trainer.train()" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 14, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"training epochs started 2018-12-07 14:08:10" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=5), HTML(value='')), layout=Layout(display='i…" | |||||
] | |||||
}, | |||||
"execution_count": 0, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 1/5. Step:1/5. AccuracyMetric: acc=0.037037" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 2/5. Step:2/5. AccuracyMetric: acc=0.037037" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 3/5. Step:3/5. AccuracyMetric: acc=0.037037" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 4/5. Step:4/5. AccuracyMetric: acc=0.185185" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 5/5. Step:5/5. AccuracyMetric: acc=0.240741" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Train finished!" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 实例化Trainer,传入模型和数据,进行训练\n", | |||||
"trainer = Trainer(model=model, \n", | |||||
" train_data=train_data, \n", | |||||
" dev_data=test_data,\n", | |||||
" loss=CrossEntropyLoss(pred=\"output\", target=\"label_seq\"),\n", | |||||
" metrics=AccuracyMetric(),\n", | |||||
" n_epochs=5)\n", | |||||
"trainer.train()\n", | |||||
"print('Train finished!')" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 15, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"[tester] \nAccuracyMetric: acc=0.240741" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP import Tester\n", | |||||
"\n", | |||||
"tester = Tester(data=test_data, model=model, metrics=AccuracyMetric())\n", | |||||
"acc = tester.test()" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"# In summary\n", | |||||
"\n", | |||||
"## fastNLP Trainer的伪代码逻辑\n", | |||||
"### 1. 准备DataSet,假设DataSet中共有如下的fields\n", | |||||
" ['raw_sentence', 'word_seq1', 'word_seq2', 'raw_label','label']\n", | |||||
" 通过\n", | |||||
" DataSet.set_input('word_seq1', word_seq2', flag=True)将'word_seq1', 'word_seq2'设置为input\n", | |||||
" 通过\n", | |||||
" DataSet.set_target('label', flag=True)将'label'设置为target\n", | |||||
"### 2. 初始化模型\n", | |||||
" class Model(nn.Module):\n", | |||||
" def __init__(self):\n", | |||||
" xxx\n", | |||||
" def forward(self, word_seq1, word_seq2):\n", | |||||
" # (1) 这里使用的形参名必须和DataSet中的input field的名称对应。因为我们是通过形参名, 进行赋值的\n", | |||||
" # (2) input field的数量可以多于这里的形参数量。但是不能少于。\n", | |||||
" xxxx\n", | |||||
" # 输出必须是一个dict\n", | |||||
"### 3. Trainer的训练过程\n", | |||||
" (1) 从DataSet中按照batch_size取出一个batch,调用Model.forward\n", | |||||
" (2) 将 Model.forward的结果 与 标记为target的field 传入Losser当中。\n", | |||||
" 由于每个人写的Model.forward的output的dict可能key并不一样,比如有人是{'pred':xxx}, {'output': xxx}; \n", | |||||
" 另外每个人将target可能也会设置为不同的名称, 比如有人是label, 有人设置为target;\n", | |||||
" 为了解决以上的问题,我们的loss提供映射机制\n", | |||||
" 比如CrossEntropyLosser的需要的输入是(prediction, target)。但是forward的output是{'output': xxx}; 'label'是target\n", | |||||
" 那么初始化losser的时候写为CrossEntropyLosser(prediction='output', target='label')即可\n", | |||||
" (3) 对于Metric是同理的\n", | |||||
" Metric计算也是从 forward的结果中取值 与 设置target的field中取值。 也是可以通过映射找到对应的值 \n", | |||||
" \n", | |||||
" \n", | |||||
"\n", | |||||
"## 一些问题.\n", | |||||
"### 1. DataSet中为什么需要设置input和target\n", | |||||
" 只有被设置为input或者target的数据才会在train的过程中被取出来\n", | |||||
" (1.1) 我们只会在设置为input的field中寻找传递给Model.forward的参数。\n", | |||||
" (1.2) 我们在传递值给losser或者metric的时候会使用来自: \n", | |||||
" (a)Model.forward的output\n", | |||||
" (b)被设置为target的field\n", | |||||
" \n", | |||||
"\n", | |||||
"### 2. 我们是通过forwad中的形参名将DataSet中的field赋值给对应的参数\n", | |||||
" (1.1) 构建模型过程中,\n", | |||||
" 例如:\n", | |||||
" DataSet中x,seq_lens是input,那么forward就应该是\n", | |||||
" def forward(self, x, seq_lens):\n", | |||||
" pass\n", | |||||
" 我们是通过形参名称进行匹配的field的\n", | |||||
" \n", | |||||
"\n", | |||||
"\n", | |||||
"### 1. 加载数据到DataSet\n", | |||||
"### 2. 使用apply操作对DataSet进行预处理\n", | |||||
" (2.1) 处理过程中将某些field设置为input,某些field设置为target\n", | |||||
"### 3. 构建模型\n", | |||||
" (3.1) 构建模型过程中,需要注意forward函数的形参名需要和DataSet中设置为input的field名称是一致的。\n", | |||||
" 例如:\n", | |||||
" DataSet中x,seq_lens是input,那么forward就应该是\n", | |||||
" def forward(self, x, seq_lens):\n", | |||||
" pass\n", | |||||
" 我们是通过形参名称进行匹配的field的\n", | |||||
" (3.2) 模型的forward的output需要是dict类型的。\n", | |||||
" 建议将输出设置为{\"pred\": xx}.\n", | |||||
" \n" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
} | |||||
], | |||||
"metadata": { | |||||
"kernelspec": { | |||||
"display_name": "Python 3", | |||||
"language": "python", | |||||
"name": "python3" | |||||
}, | |||||
"language_info": { | |||||
"codemirror_mode": { | |||||
"name": "ipython", | |||||
"version": 3 | |||||
}, | |||||
"file_extension": ".py", | |||||
"mimetype": "text/x-python", | |||||
"name": "python", | |||||
"nbconvert_exporter": "python", | |||||
"pygments_lexer": "ipython3", | |||||
"version": "3.6.7" | |||||
} | |||||
}, | |||||
"nbformat": 4, | |||||
"nbformat_minor": 2 | |||||
} |
@@ -6,7 +6,7 @@ | |||||
"collapsed": true | "collapsed": true | ||||
}, | }, | ||||
"source": [ | "source": [ | ||||
"# FastNLP 1分钟上手教程" | |||||
"# fastNLP 1分钟上手教程" | |||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
@@ -19,14 +19,14 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 3, | |||||
"execution_count": 1, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [ | "outputs": [ | ||||
{ | { | ||||
"name": "stderr", | "name": "stderr", | ||||
"output_type": "stream", | "output_type": "stream", | ||||
"text": [ | "text": [ | ||||
"/Users/yh/miniconda2/envs/python3/lib/python3.6/site-packages/tqdm/autonotebook/__init__.py:14: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n", | |||||
"c:\\users\\zyfeng\\miniconda3\\envs\\fastnlp\\lib\\site-packages\\tqdm\\autonotebook\\__init__.py:14: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n", | |||||
" \" (e.g. in jupyter console)\", TqdmExperimentalWarning)\n" | " \" (e.g. in jupyter console)\", TqdmExperimentalWarning)\n" | ||||
] | ] | ||||
} | } | ||||
@@ -37,26 +37,23 @@ | |||||
"\n", | "\n", | ||||
"from fastNLP import DataSet\n", | "from fastNLP import DataSet\n", | ||||
"\n", | "\n", | ||||
"# linux_path = \"../test/data_for_tests/tutorial_sample_dataset.csv\"\n", | |||||
"win_path = \"../test/data_for_tests/tutorial_sample_dataset.csv\"\n", | |||||
"ds = DataSet.read_csv(win_path, headers=('raw_sentence', 'label'), sep='\\t')" | |||||
"data_path = \"./sample_data/tutorial_sample_dataset.csv\"\n", | |||||
"ds = DataSet.read_csv(data_path, headers=('raw_sentence', 'label'), sep='\\t')" | |||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 8, | |||||
"execution_count": 2, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [ | "outputs": [ | ||||
{ | { | ||||
"data": { | "data": { | ||||
"text/plain": [ | "text/plain": [ | ||||
"{'raw_sentence': this quiet , introspective and entertaining independent is worth seeking .,\n", | |||||
"'label': 4,\n", | |||||
"'label_seq': 4,\n", | |||||
"'words': ['this', 'quiet', ',', 'introspective', 'and', 'entertaining', 'independent', 'is', 'worth', 'seeking', '.']}" | |||||
"{'raw_sentence': This quiet , introspective and entertaining independent is worth seeking . type=str,\n", | |||||
"'label': 4 type=str}" | |||||
] | ] | ||||
}, | }, | ||||
"execution_count": 8, | |||||
"execution_count": 2, | |||||
"metadata": {}, | "metadata": {}, | ||||
"output_type": "execute_result" | "output_type": "execute_result" | ||||
} | } | ||||
@@ -78,7 +75,7 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 4, | |||||
"execution_count": 3, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
@@ -94,7 +91,7 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 5, | |||||
"execution_count": 4, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [ | "outputs": [ | ||||
{ | { | ||||
@@ -115,7 +112,7 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 6, | |||||
"execution_count": 5, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
@@ -138,7 +135,7 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 62, | |||||
"execution_count": 6, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
@@ -156,33 +153,46 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 63, | |||||
"execution_count": 7, | |||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [ | "outputs": [ | ||||
{ | { | ||||
"name": "stdout", | "name": "stdout", | ||||
"output_type": "stream", | "output_type": "stream", | ||||
"text": [ | "text": [ | ||||
"training epochs started 2018-12-07 14:03:41\n" | |||||
"input fields after batch(if batch size is 2):\n", | |||||
"\twords: (1)type:numpy.ndarray (2)dtype:object, (3)shape:(2,) \n", | |||||
"\tword_seq: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 25]) \n", | |||||
"target fields after batch(if batch size is 2):\n", | |||||
"\tlabel_seq: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", | |||||
"\n", | |||||
"training epochs started 2019-01-12 17-00-48\n" | |||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
"data": { | "data": { | ||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "23979df0f63e446fbb0406b919b91dd3", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | "text/plain": [ | ||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6), HTML(value='')), layout=Layout(display='i…" | "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6), HTML(value='')), layout=Layout(display='i…" | ||||
] | ] | ||||
}, | }, | ||||
"execution_count": 0, | |||||
"metadata": {}, | "metadata": {}, | ||||
"output_type": "execute_result" | |||||
"output_type": "display_data" | |||||
}, | }, | ||||
{ | { | ||||
"name": "stdout", | "name": "stdout", | ||||
"output_type": "stream", | "output_type": "stream", | ||||
"text": [ | "text": [ | ||||
"Epoch 1/3. Step:2/6. AccuracyMetric: acc=0.26087\n", | |||||
"Epoch 2/3. Step:4/6. AccuracyMetric: acc=0.347826\n", | |||||
"Epoch 3/3. Step:6/6. AccuracyMetric: acc=0.608696\n", | |||||
"Evaluation at Epoch 1/3. Step:2/6. AccuracyMetric: acc=0.173913\n", | |||||
"Evaluation at Epoch 2/3. Step:4/6. AccuracyMetric: acc=0.26087\n", | |||||
"Evaluation at Epoch 3/3. Step:6/6. AccuracyMetric: acc=0.304348\n", | |||||
"\n", | |||||
"In Epoch:3/Step:6, got best dev performance:AccuracyMetric: acc=0.304348\n", | |||||
"Reloaded the best model.\n", | |||||
"Train finished!\n" | "Train finished!\n" | ||||
] | ] | ||||
} | } |
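
The Trainer cell that produces this log is not part of the hunks shown; the sketch below is only a guess at its shape, consistent with the 3 epochs / 6 steps in the output and with the loss/metric usage elsewhere in these tutorials. `model`, `train_data` and `dev_data` are assumed to come from the preceding cells.

```python
# Assumed Trainer call for the 1-minute tutorial run logged above.
from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric

trainer = Trainer(model=model,
                  train_data=train_data,
                  dev_data=dev_data,
                  loss=CrossEntropyLoss(pred="output", target="label_seq"),
                  metrics=AccuracyMetric(),
                  n_epochs=3)
trainer.train()
print('Train finished!')
```
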
@@ -1,101 +0,0 @@ | |||||
{ | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": { | |||||
"collapsed": true | |||||
}, | |||||
"source": [ | |||||
"## FastNLP 进阶教程\n", | |||||
"本教程阅读时间平均30分钟" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## 数据部分\n", | |||||
"### DataSet\n" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### Instance" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### Vocabulary" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## 模型部分\n", | |||||
"### model" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## 训练测试部分\n", | |||||
"### Loss" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### Metric" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### Trainer" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### Tester" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
} | |||||
], | |||||
"metadata": { | |||||
"kernelspec": { | |||||
"display_name": "Python 2", | |||||
"language": "python", | |||||
"name": "python2" | |||||
}, | |||||
"language_info": { | |||||
"codemirror_mode": { | |||||
"name": "ipython", | |||||
"version": 2 | |||||
}, | |||||
"file_extension": ".py", | |||||
"mimetype": "text/x-python", | |||||
"name": "python", | |||||
"nbconvert_exporter": "python", | |||||
"pygments_lexer": "ipython2", | |||||
"version": "2.7.6" | |||||
} | |||||
}, | |||||
"nbformat": 4, | |||||
"nbformat_minor": 0 | |||||
} |
@@ -0,0 +1,8 @@ | |||||
[esim_model]
embed_dim = 300
hidden_size = 300
batch_first = true
dropout = 0.3
num_classes = 3
gpu = true
batch_size = 32
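
The `[esim_model]` section added above is a standard INI-style block. One way to consume it uses only the standard-library `configparser`; the file name passed to `read()` is a placeholder, and fastNLP's own config utilities could be used instead.

```python
# Read the [esim_model] section shown above with configparser.
import configparser

cfg = configparser.ConfigParser()
cfg.read("config")                           # path to the config file (assumed name)
esim = cfg["esim_model"]

embed_dim   = esim.getint("embed_dim")       # 300
hidden_size = esim.getint("hidden_size")     # 300
batch_first = esim.getboolean("batch_first") # True
dropout     = esim.getfloat("dropout")       # 0.3
num_classes = esim.getint("num_classes")     # 3
use_gpu     = esim.getboolean("gpu")         # True
batch_size  = esim.getint("batch_size")      # 32
```
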
@@ -0,0 +1,100 @@ | |||||
A person is training his horse for a competition . | |||||
A person is at a diner , ordering an omelette . | |||||
A person is outdoors , on a horse . | |||||
They are smiling at their parents | |||||
There are children present | |||||
The kids are frowning | |||||
The boy skates down the sidewalk . | |||||
The boy does a skateboarding trick . | |||||
The boy is wearing safety equipment . | |||||
An older man drinks his juice as he waits for his daughter to get off work . | |||||
A boy flips a burger . | |||||
An elderly man sits in a small shop . | |||||
Some women are hugging on vacation . | |||||
The women are sleeping . | |||||
There are women showing affection . | |||||
The people are eating omelettes . | |||||
The people are sitting at desks in school . | |||||
The diners are at a restaurant . | |||||
A man is drinking juice . | |||||
Two women are at a restaurant drinking wine . | |||||
A man in a restaurant is waiting for his meal to arrive . | |||||
A blond man getting a drink of water from a fountain in the park . | |||||
A blond man wearing a brown shirt is reading a book on a bench in the park | |||||
A blond man drinking water from a fountain . | |||||
The friends scowl at each other over a full dinner table . | |||||
There are two woman in this picture . | |||||
The friends have just met for the first time in 20 years , and have had a great time catching up . | |||||
The two sisters saw each other across the crowded diner and shared a hug , both clutching their doggie bags . | |||||
Two groups of rival gang members flipped each other off . | |||||
Two women hug each other . | |||||
A team is trying to score the games winning out . | |||||
A team is trying to tag a runner out . | |||||
A team is playing baseball on Saturn . | |||||
A school hosts a basketball game . | |||||
A high school is hosting an event . | |||||
A school is hosting an event . | |||||
The women do not care what clothes they wear . | |||||
Women are waiting by a tram . | |||||
The women enjoy having a good fashion sense . | |||||
A child with mom and dad , on summer vacation at the beach . | |||||
A family of three is at the beach . | |||||
A family of three is at the mall shopping . | |||||
The people waiting on the train are sitting . | |||||
There are people just getting on a train | |||||
There are people waiting on a train . | |||||
A couple are playing with a young child outside . | |||||
A couple are playing frisbee with a young child at the beach . | |||||
A couple watch a little girl play by herself on the beach . | |||||
The family is sitting down for dinner . | |||||
The family is outside . | |||||
The family is on vacation . | |||||
The people are standing still on the curb . | |||||
Near a couple of restaurants , two people walk across the street . | |||||
The couple are walking across the street together . | |||||
The woman is nake . | |||||
The woman is cold . | |||||
The woman is wearing green . | |||||
The man with the sign is caucasian . | |||||
They are protesting outside the capital . | |||||
A woman in white . | |||||
A man is advertising for a restaurant . | |||||
The woman is wearing black . | |||||
A man and a woman walk down a crowded city street . | |||||
The woman is wearing white . | |||||
They are working for John 's Pizza . | |||||
Olympic swimming . | |||||
A man and a soman are eating together at John 's Pizza and Gyro . | |||||
They are walking with a sign . | |||||
The woman is waiting for a friend . | |||||
The man is sitting down while he has a sign for John 's Pizza and Gyro in his arms . | |||||
The woman and man are outdoors . | |||||
A woman ordering pizza . | |||||
The people are related . | |||||
Two adults run across the street to get away from a red shirted person chasing them . | |||||
The adults are both male and female . | |||||
Two people walk home after a tasty steak dinner . | |||||
Two adults swimming in water | |||||
Two adults walk across a street . | |||||
Two people ride bicycles into a tunnel . | |||||
Two people walk away from a restaurant across a street . | |||||
Two adults walking across a road near the convicted prisoner dressed in red | |||||
Two friends cross a street . | |||||
Some people board a train . | |||||
Two adults walk across the street . | |||||
Two adults walking across a road | |||||
There are no women in the picture . | |||||
Two adults walk across the street to get away from a red shirted person who is chasing them . | |||||
A married couple is sleeping . | |||||
A female is next to a man . | |||||
A married couple is walking next to each other . | |||||
Nobody has food . | |||||
A woman eats a banana and walks across a street , and there is a man trailing behind her . | |||||
The woman and man are playing baseball together . | |||||
two coworkers cross pathes on a street | |||||
A woman eats ice cream walking down the sidewalk , and there is another woman in front of her with a purse . | |||||
The mans briefcase is for work . | |||||
A person eating . | |||||
A person that is hungry . | |||||
An actress and her favorite assistant talk a walk in the city . | |||||
a woman eating a banana crosses a street |
@@ -0,0 +1,100 @@ | |||||
1 | |||||
2 | |||||
0 | |||||
1 | |||||
0 | |||||
2 | |||||
2 | |||||
0 | |||||
1 | |||||
1 | |||||
2 | |||||
1 | |||||
1 | |||||
2 | |||||
0 | |||||
1 | |||||
2 | |||||
0 | |||||
0 | |||||
2 | |||||
1 | |||||
1 | |||||
2 | |||||
0 | |||||
2 | |||||
0 | |||||
1 | |||||
1 | |||||
2 | |||||
0 | |||||
1 | |||||
0 | |||||
2 | |||||
2 | |||||
1 | |||||
0 | |||||
2 | |||||
0 | |||||
1 | |||||
1 | |||||
0 | |||||
2 | |||||
1 | |||||
0 | |||||
0 | |||||
0 | |||||
1 | |||||
2 | |||||
2 | |||||
0 | |||||
1 | |||||
2 | |||||
0 | |||||
1 | |||||
2 | |||||
1 | |||||
0 | |||||
1 | |||||
2 | |||||
0 | |||||
0 | |||||
2 | |||||
1 | |||||
0 | |||||
1 | |||||
2 | |||||
2 | |||||
0 | |||||
1 | |||||
2 | |||||
0 | |||||
1 | |||||
1 | |||||
2 | |||||
0 | |||||
1 | |||||
2 | |||||
0 | |||||
2 | |||||
0 | |||||
1 | |||||
1 | |||||
2 | |||||
0 | |||||
0 | |||||
2 | |||||
1 | |||||
2 | |||||
0 | |||||
1 | |||||
2 | |||||
0 | |||||
2 | |||||
1 | |||||
2 | |||||
1 | |||||
0 | |||||
1 | |||||
1 | |||||
0 |
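
The sentence file and the label file added above are line-aligned, one example per line. A minimal sketch of zipping such parallel files into a DataSet follows; the file names are placeholders for whatever these sample files are called in the repo.

```python
# Load line-aligned sentence/label files into a fastNLP DataSet.
from fastNLP import DataSet, Instance

def load_parallel(sentence_path, label_path):
    ds = DataSet()
    with open(sentence_path, encoding="utf-8") as fs, \
         open(label_path, encoding="utf-8") as fl:
        for sent, lab in zip(fs, fl):
            ds.append(Instance(raw_sentence=sent.strip(), label=int(lab.strip())))
    return ds

dataset = load_parallel("hypothesis.txt", "label.txt")  # placeholder file names
print(len(dataset), dataset[0])
```
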
@@ -0,0 +1,100 @@ | |||||
A person on a horse jumps over a broken down airplane . | |||||
A person on a horse jumps over a broken down airplane . | |||||
A person on a horse jumps over a broken down airplane . | |||||
Children smiling and waving at camera | |||||
Children smiling and waving at camera | |||||
Children smiling and waving at camera | |||||
A boy is jumping on skateboard in the middle of a red bridge . | |||||
A boy is jumping on skateboard in the middle of a red bridge . | |||||
A boy is jumping on skateboard in the middle of a red bridge . | |||||
An older man sits with his orange juice at a small table in a coffee shop while employees in bright colored shirts smile in the background . | |||||
An older man sits with his orange juice at a small table in a coffee shop while employees in bright colored shirts smile in the background . | |||||
An older man sits with his orange juice at a small table in a coffee shop while employees in bright colored shirts smile in the background . | |||||
Two blond women are hugging one another . | |||||
Two blond women are hugging one another . | |||||
Two blond women are hugging one another . | |||||
A few people in a restaurant setting , one of them is drinking orange juice . | |||||
A few people in a restaurant setting , one of them is drinking orange juice . | |||||
A few people in a restaurant setting , one of them is drinking orange juice . | |||||
An older man is drinking orange juice at a restaurant . | |||||
An older man is drinking orange juice at a restaurant . | |||||
An older man is drinking orange juice at a restaurant . | |||||
A man with blond-hair , and a brown shirt drinking out of a public water fountain . | |||||
A man with blond-hair , and a brown shirt drinking out of a public water fountain . | |||||
A man with blond-hair , and a brown shirt drinking out of a public water fountain . | |||||
Two women who just had lunch hugging and saying goodbye . | |||||
Two women who just had lunch hugging and saying goodbye . | |||||
Two women who just had lunch hugging and saying goodbye . | |||||
Two women , holding food carryout containers , hug . | |||||
Two women , holding food carryout containers , hug . | |||||
Two women , holding food carryout containers , hug . | |||||
A Little League team tries to catch a runner sliding into a base in an afternoon game . | |||||
A Little League team tries to catch a runner sliding into a base in an afternoon game . | |||||
A Little League team tries to catch a runner sliding into a base in an afternoon game . | |||||
The school is having a special event in order to show the american culture on how other cultures are dealt with in parties . | |||||
The school is having a special event in order to show the american culture on how other cultures are dealt with in parties . | |||||
The school is having a special event in order to show the american culture on how other cultures are dealt with in parties . | |||||
High fashion ladies wait outside a tram beside a crowd of people in the city . | |||||
High fashion ladies wait outside a tram beside a crowd of people in the city . | |||||
High fashion ladies wait outside a tram beside a crowd of people in the city . | |||||
A man , woman , and child enjoying themselves on a beach . | |||||
A man , woman , and child enjoying themselves on a beach . | |||||
A man , woman , and child enjoying themselves on a beach . | |||||
People waiting to get on a train or just getting off . | |||||
People waiting to get on a train or just getting off . | |||||
People waiting to get on a train or just getting off . | |||||
A couple playing with a little boy on the beach . | |||||
A couple playing with a little boy on the beach . | |||||
A couple playing with a little boy on the beach . | |||||
A couple play in the tide with their young son . | |||||
A couple play in the tide with their young son . | |||||
A couple play in the tide with their young son . | |||||
A man and a woman cross the street in front of a pizza and gyro restaurant . | |||||
A man and a woman cross the street in front of a pizza and gyro restaurant . | |||||
A man and a woman cross the street in front of a pizza and gyro restaurant . | |||||
A woman in a green jacket and hood over her head looking towards a valley . | |||||
A woman in a green jacket and hood over her head looking towards a valley . | |||||
A woman in a green jacket and hood over her head looking towards a valley . | |||||
Woman in white in foreground and a man slightly behind walking with a sign for John 's Pizza and Gyro in the background . | |||||
Woman in white in foreground and a man slightly behind walking with a sign for John 's Pizza and Gyro in the background . | |||||
Woman in white in foreground and a man slightly behind walking with a sign for John 's Pizza and Gyro in the background . | |||||
Woman in white in foreground and a man slightly behind walking with a sign for John 's Pizza and Gyro in the background . | |||||
Woman in white in foreground and a man slightly behind walking with a sign for John 's Pizza and Gyro in the background . | |||||
Woman in white in foreground and a man slightly behind walking with a sign for John 's Pizza and Gyro in the background . | |||||
Woman in white in foreground and a man slightly behind walking with a sign for John 's Pizza and Gyro in the background . | |||||
Woman in white in foreground and a man slightly behind walking with a sign for John 's Pizza and Gyro in the background . | |||||
Woman in white in foreground and a man slightly behind walking with a sign for John 's Pizza and Gyro in the background . | |||||
Woman in white in foreground and a man slightly behind walking with a sign for John 's Pizza and Gyro in the background . | |||||
Woman in white in foreground and a man slightly behind walking with a sign for John 's Pizza and Gyro in the background . | |||||
Woman in white in foreground and a man slightly behind walking with a sign for John 's Pizza and Gyro in the background . | |||||
Woman in white in foreground and a man slightly behind walking with a sign for John 's Pizza and Gyro in the background . | |||||
Woman in white in foreground and a man slightly behind walking with a sign for John 's Pizza and Gyro in the background . | |||||
Woman in white in foreground and a man slightly behind walking with a sign for John 's Pizza and Gyro in the background . | |||||
Two adults , one female in white , with shades and one male , gray clothes , walking across a street , away from a eatery with a blurred image of a dark colored red shirted person in the foreground . | |||||
Two adults , one female in white , with shades and one male , gray clothes , walking across a street , away from a eatery with a blurred image of a dark colored red shirted person in the foreground . | |||||
Two adults , one female in white , with shades and one male , gray clothes , walking across a street , away from a eatery with a blurred image of a dark colored red shirted person in the foreground . | |||||
Two adults , one female in white , with shades and one male , gray clothes , walking across a street , away from a eatery with a blurred image of a dark colored red shirted person in the foreground . | |||||
Two adults , one female in white , with shades and one male , gray clothes , walking across a street , away from a eatery with a blurred image of a dark colored red shirted person in the foreground . | |||||
Two adults , one female in white , with shades and one male , gray clothes , walking across a street , away from a eatery with a blurred image of a dark colored red shirted person in the foreground . | |||||
Two adults , one female in white , with shades and one male , gray clothes , walking across a street , away from a eatery with a blurred image of a dark colored red shirted person in the foreground . | |||||
Two adults , one female in white , with shades and one male , gray clothes , walking across a street , away from a eatery with a blurred image of a dark colored red shirted person in the foreground . | |||||
Two adults , one female in white , with shades and one male , gray clothes , walking across a street , away from a eatery with a blurred image of a dark colored red shirted person in the foreground . | |||||
Two adults , one female in white , with shades and one male , gray clothes , walking across a street , away from a eatery with a blurred image of a dark colored red shirted person in the foreground . | |||||
Two adults , one female in white , with shades and one male , gray clothes , walking across a street , away from a eatery with a blurred image of a dark colored red shirted person in the foreground . | |||||
Two adults , one female in white , with shades and one male , gray clothes , walking across a street , away from a eatery with a blurred image of a dark colored red shirted person in the foreground . | |||||
Two adults , one female in white , with shades and one male , gray clothes , walking across a street , away from a eatery with a blurred image of a dark colored red shirted person in the foreground . | |||||
Two adults , one female in white , with shades and one male , gray clothes , walking across a street , away from a eatery with a blurred image of a dark colored red shirted person in the foreground . | |||||
Two adults , one female in white , with shades and one male , gray clothes , walking across a street , away from a eatery with a blurred image of a dark colored red shirted person in the foreground . | |||||
A woman wearing all white and eating , walks next to a man holding a briefcase . | |||||
A woman wearing all white and eating , walks next to a man holding a briefcase . | |||||
A woman wearing all white and eating , walks next to a man holding a briefcase . | |||||
A woman is walking across the street eating a banana , while a man is following with his briefcase . | |||||
A woman is walking across the street eating a banana , while a man is following with his briefcase . | |||||
A woman is walking across the street eating a banana , while a man is following with his briefcase . | |||||
A woman is walking across the street eating a banana , while a man is following with his briefcase . | |||||
A woman is walking across the street eating a banana , while a man is following with his briefcase . | |||||
A woman is walking across the street eating a banana , while a man is following with his briefcase . | |||||
A woman is walking across the street eating a banana , while a man is following with his briefcase . | |||||
A woman is walking across the street eating a banana , while a man is following with his briefcase . | |||||
A woman is walking across the street eating a banana , while a man is following with his briefcase . | |||||
A woman is walking across the street eating a banana , while a man is following with his briefcase . |
@@ -0,0 +1,77 @@ | |||||
A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story . 1 | |||||
This quiet , introspective and entertaining independent is worth seeking . 4 | |||||
Even fans of Ismail Merchant 's work , I suspect , would have a hard time sitting through this one . 1 | |||||
A positively thrilling combination of ethnography and all the intrigue , betrayal , deceit and murder of a Shakespearean tragedy or a juicy soap opera . 3 | |||||
Aggressive self-glorification and a manipulative whitewash . 1 | |||||
A comedy-drama of nearly epic proportions rooted in a sincere performance by the title character undergoing midlife crisis . 4 | |||||
Narratively , Trouble Every Day is a plodding mess . 1 | |||||
The Importance of Being Earnest , so thick with wit it plays like a reading from Bartlett 's Familiar Quotations 3 | |||||
But it does n't leave you with much . 1 | |||||
You could hate it for the same reason . 1 | |||||
There 's little to recommend Snow Dogs , unless one considers cliched dialogue and perverse escapism a source of high hilarity . 1 | |||||
Kung Pow is Oedekerk 's realization of his childhood dream to be in a martial-arts flick , and proves that sometimes the dreams of youth should remain just that . 1 | |||||
The performances are an absolute joy . 4 | |||||
Fresnadillo has something serious to say about the ways in which extravagant chance can distort our perspective and throw us off the path of good sense . 3 | |||||
I still like Moonlight Mile , better judgment be damned . 3 | |||||
A welcome relief from baseball movies that try too hard to be mythic , this one is a sweet and modest and ultimately winning story . 3 | |||||
a bilingual charmer , just like the woman who inspired it 3 | |||||
Like a less dizzily gorgeous companion to Mr. Wong 's In the Mood for Love -- very much a Hong Kong movie despite its mainland setting . 2 | |||||
As inept as big-screen remakes of The Avengers and The Wild Wild West . 1 | |||||
It 's everything you 'd expect -- but nothing more . 2 | |||||
Best indie of the year , so far . 4 | |||||
Hatfield and Hicks make the oddest of couples , and in this sense the movie becomes a study of the gambles of the publishing world , offering a case study that exists apart from all the movie 's political ramifications . 3 | |||||
It 's like going to a house party and watching the host defend himself against a frothing ex-girlfriend . 1 | |||||
That the Chuck Norris `` grenade gag '' occurs about 7 times during Windtalkers is a good indication of how serious-minded the film is . 2 | |||||
The plot is romantic comedy boilerplate from start to finish . 2 | |||||
It arrives with an impeccable pedigree , mongrel pep , and almost indecipherable plot complications . 2 | |||||
A film that clearly means to preach exclusively to the converted . 2 | |||||
While The Importance of Being Earnest offers opportunities for occasional smiles and chuckles , it does n't give us a reason to be in the theater beyond Wilde 's wit and the actors ' performances . 1 | |||||
The latest vapid actor 's exercise to appropriate the structure of Arthur Schnitzler 's Reigen . 1 | |||||
More vaudeville show than well-constructed narrative , but on those terms it 's inoffensive and actually rather sweet . 2 | |||||
Nothing more than a run-of-the-mill action flick . 2 | |||||
Hampered -- no , paralyzed -- by a self-indulgent script ... that aims for poetry and ends up sounding like satire . 0 | |||||
Ice Age is the first computer-generated feature cartoon to feel like other movies , and that makes for some glacial pacing early on . 2 | |||||
There 's very little sense to what 's going on here , but the makers serve up the cliches with considerable dash . 2 | |||||
Cattaneo should have followed the runaway success of his first film , The Full Monty , with something different . 2 | |||||
They 're the unnamed , easily substitutable forces that serve as whatever terror the heroes of horror movies try to avoid . 1 | |||||
It almost feels as if the movie is more interested in entertaining itself than in amusing us . 1 | |||||
The movie 's progression into rambling incoherence gives new meaning to the phrase ` fatal script error . ' 0 | |||||
I still like Moonlight Mile , better judgment be damned . 3 | |||||
A welcome relief from baseball movies that try too hard to be mythic , this one is a sweet and modest and ultimately winning story . 3 | |||||
a bilingual charmer , just like the woman who inspired it 3 | |||||
Like a less dizzily gorgeous companion to Mr. Wong 's In the Mood for Love -- very much a Hong Kong movie despite its mainland setting . 2 | |||||
As inept as big-screen remakes of The Avengers and The Wild Wild West . 1 | |||||
It 's everything you 'd expect -- but nothing more . 2 | |||||
Best indie of the year , so far . 4 | |||||
Hatfield and Hicks make the oddest of couples , and in this sense the movie becomes a study of the gambles of the publishing world , offering a case study that exists apart from all the movie 's political ramifications . 3 | |||||
It 's like going to a house party and watching the host defend himself against a frothing ex-girlfriend . 1 | |||||
That the Chuck Norris `` grenade gag '' occurs about 7 times during Windtalkers is a good indication of how serious-minded the film is . 2 | |||||
The plot is romantic comedy boilerplate from start to finish . 2 | |||||
It arrives with an impeccable pedigree , mongrel pep , and almost indecipherable plot complications . 2 | |||||
A film that clearly means to preach exclusively to the converted . 2 | |||||
I still like Moonlight Mile , better judgment be damned . 3 | |||||
A welcome relief from baseball movies that try too hard to be mythic , this one is a sweet and modest and ultimately winning story . 3 | |||||
a bilingual charmer , just like the woman who inspired it 3 | |||||
Like a less dizzily gorgeous companion to Mr. Wong 's In the Mood for Love -- very much a Hong Kong movie despite its mainland setting . 2 | |||||
As inept as big-screen remakes of The Avengers and The Wild Wild West . 1 | |||||
It 's everything you 'd expect -- but nothing more . 2 | |||||
Best indie of the year , so far . 4 | |||||
Hatfield and Hicks make the oddest of couples , and in this sense the movie becomes a study of the gambles of the publishing world , offering a case study that exists apart from all the movie 's political ramifications . 3 | |||||
It 's like going to a house party and watching the host defend himself against a frothing ex-girlfriend . 1 | |||||
That the Chuck Norris `` grenade gag '' occurs about 7 times during Windtalkers is a good indication of how serious-minded the film is . 2 | |||||
The plot is romantic comedy boilerplate from start to finish . 2 | |||||
It arrives with an impeccable pedigree , mongrel pep , and almost indecipherable plot complications . 2 | |||||
A film that clearly means to preach exclusively to the converted . 2 | |||||
I still like Moonlight Mile , better judgment be damned . 3 | |||||
A welcome relief from baseball movies that try too hard to be mythic , this one is a sweet and modest and ultimately winning story . 3 | |||||
a bilingual charmer , just like the woman who inspired it 3 | |||||
Like a less dizzily gorgeous companion to Mr. Wong 's In the Mood for Love -- very much a Hong Kong movie despite its mainland setting . 2 | |||||
As inept as big-screen remakes of The Avengers and The Wild Wild West . 1 | |||||
It 's everything you 'd expect -- but nothing more . 2 | |||||
Best indie of the year , so far . 4 | |||||
Hatfield and Hicks make the oddest of couples , and in this sense the movie becomes a study of the gambles of the publishing world , offering a case study that exists apart from all the movie 's political ramifications . 3 | |||||
It 's like going to a house party and watching the host defend himself against a frothing ex-girlfriend . 1 | |||||
That the Chuck Norris `` grenade gag '' occurs about 7 times during Windtalkers is a good indication of how serious-minded the film is . 2 | |||||
The plot is romantic comedy boilerplate from start to finish . 2 | |||||
It arrives with an impeccable pedigree , mongrel pep , and almost indecipherable plot complications . 2 | |||||
A film that clearly means to preach exclusively to the converted . 2 |
@@ -0,0 +1,77 @@ | |||||
A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story . 1 | |||||
This quiet , introspective and entertaining independent is worth seeking . 4 | |||||
Even fans of Ismail Merchant 's work , I suspect , would have a hard time sitting through this one . 1 | |||||
A positively thrilling combination of ethnography and all the intrigue , betrayal , deceit and murder of a Shakespearean tragedy or a juicy soap opera . 3 | |||||
Aggressive self-glorification and a manipulative whitewash . 1 | |||||
A comedy-drama of nearly epic proportions rooted in a sincere performance by the title character undergoing midlife crisis . 4 | |||||
Narratively , Trouble Every Day is a plodding mess . 1 | |||||
The Importance of Being Earnest , so thick with wit it plays like a reading from Bartlett 's Familiar Quotations 3 | |||||
But it does n't leave you with much . 1 | |||||
You could hate it for the same reason . 1 | |||||
There 's little to recommend Snow Dogs , unless one considers cliched dialogue and perverse escapism a source of high hilarity . 1 | |||||
Kung Pow is Oedekerk 's realization of his childhood dream to be in a martial-arts flick , and proves that sometimes the dreams of youth should remain just that . 1 | |||||
The performances are an absolute joy . 4 | |||||
Fresnadillo has something serious to say about the ways in which extravagant chance can distort our perspective and throw us off the path of good sense . 3 | |||||
I still like Moonlight Mile , better judgment be damned . 3 | |||||
A welcome relief from baseball movies that try too hard to be mythic , this one is a sweet and modest and ultimately winning story . 3 | |||||
a bilingual charmer , just like the woman who inspired it 3 | |||||
Like a less dizzily gorgeous companion to Mr. Wong 's In the Mood for Love -- very much a Hong Kong movie despite its mainland setting . 2 | |||||
As inept as big-screen remakes of The Avengers and The Wild Wild West . 1 | |||||
It 's everything you 'd expect -- but nothing more . 2 | |||||
Best indie of the year , so far . 4 | |||||
Hatfield and Hicks make the oddest of couples , and in this sense the movie becomes a study of the gambles of the publishing world , offering a case study that exists apart from all the movie 's political ramifications . 3 | |||||
It 's like going to a house party and watching the host defend himself against a frothing ex-girlfriend . 1 | |||||
That the Chuck Norris `` grenade gag '' occurs about 7 times during Windtalkers is a good indication of how serious-minded the film is . 2 | |||||
The plot is romantic comedy boilerplate from start to finish . 2 | |||||
It arrives with an impeccable pedigree , mongrel pep , and almost indecipherable plot complications . 2 | |||||
A film that clearly means to preach exclusively to the converted . 2 | |||||
While The Importance of Being Earnest offers opportunities for occasional smiles and chuckles , it does n't give us a reason to be in the theater beyond Wilde 's wit and the actors ' performances . 1 | |||||
The latest vapid actor 's exercise to appropriate the structure of Arthur Schnitzler 's Reigen . 1 | |||||
More vaudeville show than well-constructed narrative , but on those terms it 's inoffensive and actually rather sweet . 2 | |||||
Nothing more than a run-of-the-mill action flick . 2 | |||||
Hampered -- no , paralyzed -- by a self-indulgent script ... that aims for poetry and ends up sounding like satire . 0 | |||||
Ice Age is the first computer-generated feature cartoon to feel like other movies , and that makes for some glacial pacing early on . 2 | |||||
There 's very little sense to what 's going on here , but the makers serve up the cliches with considerable dash . 2 | |||||
Cattaneo should have followed the runaway success of his first film , The Full Monty , with something different . 2 | |||||
They 're the unnamed , easily substitutable forces that serve as whatever terror the heroes of horror movies try to avoid . 1 | |||||
It almost feels as if the movie is more interested in entertaining itself than in amusing us . 1 | |||||
The movie 's progression into rambling incoherence gives new meaning to the phrase ` fatal script error . ' 0 | |||||
I still like Moonlight Mile , better judgment be damned . 3 | |||||
A welcome relief from baseball movies that try too hard to be mythic , this one is a sweet and modest and ultimately winning story . 3 | |||||
a bilingual charmer , just like the woman who inspired it 3 | |||||
Like a less dizzily gorgeous companion to Mr. Wong 's In the Mood for Love -- very much a Hong Kong movie despite its mainland setting . 2 | |||||
As inept as big-screen remakes of The Avengers and The Wild Wild West . 1 | |||||
It 's everything you 'd expect -- but nothing more . 2 | |||||
Best indie of the year , so far . 4 | |||||
Hatfield and Hicks make the oddest of couples , and in this sense the movie becomes a study of the gambles of the publishing world , offering a case study that exists apart from all the movie 's political ramifications . 3 | |||||
It 's like going to a house party and watching the host defend himself against a frothing ex-girlfriend . 1 | |||||
That the Chuck Norris `` grenade gag '' occurs about 7 times during Windtalkers is a good indication of how serious-minded the film is . 2 | |||||
The plot is romantic comedy boilerplate from start to finish . 2 | |||||
It arrives with an impeccable pedigree , mongrel pep , and almost indecipherable plot complications . 2 | |||||
A film that clearly means to preach exclusively to the converted . 2 | |||||
I still like Moonlight Mile , better judgment be damned . 3 | |||||
A welcome relief from baseball movies that try too hard to be mythic , this one is a sweet and modest and ultimately winning story . 3 | |||||
a bilingual charmer , just like the woman who inspired it 3 | |||||
Like a less dizzily gorgeous companion to Mr. Wong 's In the Mood for Love -- very much a Hong Kong movie despite its mainland setting . 2 | |||||
As inept as big-screen remakes of The Avengers and The Wild Wild West . 1 | |||||
It 's everything you 'd expect -- but nothing more . 2 | |||||
Best indie of the year , so far . 4 | |||||
Hatfield and Hicks make the oddest of couples , and in this sense the movie becomes a study of the gambles of the publishing world , offering a case study that exists apart from all the movie 's political ramifications . 3 | |||||
It 's like going to a house party and watching the host defend himself against a frothing ex-girlfriend . 1 | |||||
That the Chuck Norris `` grenade gag '' occurs about 7 times during Windtalkers is a good indication of how serious-minded the film is . 2 | |||||
The plot is romantic comedy boilerplate from start to finish . 2 | |||||
It arrives with an impeccable pedigree , mongrel pep , and almost indecipherable plot complications . 2 | |||||
A film that clearly means to preach exclusively to the converted . 2 | |||||
I still like Moonlight Mile , better judgment be damned . 3 | |||||
A welcome relief from baseball movies that try too hard to be mythic , this one is a sweet and modest and ultimately winning story . 3 | |||||
a bilingual charmer , just like the woman who inspired it 3 | |||||
Like a less dizzily gorgeous companion to Mr. Wong 's In the Mood for Love -- very much a Hong Kong movie despite its mainland setting . 2 | |||||
As inept as big-screen remakes of The Avengers and The Wild Wild West . 1 | |||||
It 's everything you 'd expect -- but nothing more . 2 | |||||
Best indie of the year , so far . 4 | |||||
Hatfield and Hicks make the oddest of couples , and in this sense the movie becomes a study of the gambles of the publishing world , offering a case study that exists apart from all the movie 's political ramifications . 3 | |||||
It 's like going to a house party and watching the host defend himself against a frothing ex-girlfriend . 1 | |||||
That the Chuck Norris `` grenade gag '' occurs about 7 times during Windtalkers is a good indication of how serious-minded the film is . 2 | |||||
The plot is romantic comedy boilerplate from start to finish . 2 | |||||
It arrives with an impeccable pedigree , mongrel pep , and almost indecipherable plot complications . 2 | |||||
A film that clearly means to preach exclusively to the converted . 2 |
@@ -0,0 +1,283 @@ | |||||
# fastNLP Developer Guide
#### This tutorial covers the following classes:
- DataSet | |||||
- Sampler | |||||
- Batch | |||||
- Model | |||||
- Loss | |||||
- Metric | |||||
- Trainer | |||||
- Tester | |||||
#### DataSet: holds the data.
1. Every element in a DataSet must be one of the following three types: `np.float64`, `np.int64`, or `np.str`. Data passed in as `int` is converted to `np.int64`, and `float` is converted to `np.float64`.
2. A DataSet can mark a field as input or as target. Fields marked as input are passed to Model.forward, and this passing is done by key matching. For example, suppose the DataSet has fields 'x1', 'x2', 'x3' marked as input, and
    - the function is Model.forward(self, x1, x3): then 'x1' and 'x3' from the DataSet are passed to forward, and the extra 'x2' is ignored;
    - the function is Model.forward(self, x1, x4): forward needs an extra 'x4', but the DataSet has no input field with that name, so an error is raised;
    - the function is Model.forward(self, x1, **kwargs): 'x1', 'x2', and 'x3' are all passed in; Model.forward(self, x4, **kwargs), however, raises an error because there is no 'x4'.
3. For fields marked as target we recommend the name 'target' (when there is only one value to predict), but this is not enforced. We will explain later why it does not have to be enforced.
DataSet should not require any further development on your side; if you hit a scenario it cannot handle, please raise it in the developer group or open an issue on GitHub. A minimal usage sketch is given below.
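As a minimal sketch of the input/target mechanism described above — the `DataSet(dict)` constructor and the `set_input`/`set_target` method names are assumptions about the API, not taken from this guide:

```python
from fastNLP import DataSet  # assumed import path

# two candidate input fields and one target field
ds = DataSet({'x1': [[1, 2], [3, 4]],
              'x2': [[5, 6], [7, 8]],
              'target': [0, 1]})
ds.set_input('x1', 'x2')   # assumed API for marking input fields
ds.set_target('target')    # assumed API for marking the target field

# A forward(self, x1) would then receive only 'x1'; the unused 'x2' would simply be ignored.
```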
#### Sampler: given a DataSet, returns a list of indices; Batch emits data in that order.
A Sampler must inherit from fastNLP.core.sampler.BaseSampler.
```python | |||||
class BaseSampler(object):
    """The base class of all samplers.
    Sub-classes must implement the __call__ method.
    __call__ takes a DataSet object and returns a list of int - the sampling indices.
    """
    def __call__(self, *args, **kwargs):
        raise NotImplementedError

# Subclasses must override __call__. It may take only one required argument, which must be a DataSet,
# otherwise the Trainer cannot call it.
class SonSampler(BaseSampler):
    def __init__(self, xxx):
        # implementing __init__ is optional
        pass
    def __call__(self, data_set):
        pass
``` | |||||
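As a concrete illustration of this contract, here is a minimal sketch of a sampler that yields indices in reverse order; the class name and its behaviour are purely illustrative and not part of fastNLP:

```python
from fastNLP.core.sampler import BaseSampler

class ReversedSampler(BaseSampler):
    """Illustrative sampler: return indices from the last instance to the first."""
    def __call__(self, data_set):
        # len(data_set) is assumed to give the number of instances in the DataSet
        return list(range(len(data_set) - 1, -1, -1))
```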
#### Batch: takes the fields marked as input and target out of the DataSet to build batch_x and batch_y
Where possible (mainly depending on whether the data type can be converted), the data is converted into PyTorch Tensors. The order in which samples are drawn is determined by the Sampler.
The Sampler receives a DataSet and returns an index list of the same length as the DataSet; Batch then draws batch_size samples at a time (the last batch may contain fewer than batch_size samples).
For example:
1. SequentialSampler samples sequentially.
Suppose the DataSet passed in has length 100; SequentialSampler returns the index list [0, 1, ..., 98, 99]. If batch_size is set to 4, the first batch fetches the four instances [0, 1, 2, 3], the second batch fetches [4, 5, 6, 7], and so on until all samples have been drawn.
2. RandomSampler samples randomly.
Suppose the DataSet passed in has length 100; RandomSampler might return the index list [0, 99, 20, 5, 3, 1, ...]. Samples are then drawn batch_size at a time in that order.
Batch should not need to be subclassed or extended; if you have special requirements, please raise them in the developer group. A short iteration sketch is given below.
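The iteration pattern described here looks roughly like the sketch below; the import paths and the argument-free `SequentialSampler()` constructor are assumptions:

```python
from fastNLP.core.batch import Batch              # assumed import path
from fastNLP.core.sampler import SequentialSampler

# ds is a DataSet whose input/target fields have already been marked
batch = Batch(ds, batch_size=4, sampler=SequentialSampler())
for batch_x, batch_y in batch:
    # batch_x holds the input fields, batch_y the target fields;
    # values are tensors (padded where necessary) when conversion is possible
    print(batch_x.keys(), batch_y.keys())
```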
#### Model: the user-defined Model
It must be a subclass of nn.Module.
1. It must implement a forward method, and forward must not use *args-style parameters. For example:
```python | |||||
def forward(self, word_seq, *args):  # this is not allowed
    # ...
    pass
``` | |||||
The return value must be a dict:
```python | |||||
def forward(self, word_seq, seq_lens):
    xxx = "xxx"
    return {'pred': xxx}  # the return value must be a dict; 'pred' is the recommended (not mandatory) key for predictions, and any number of output entries is allowed
``` | |||||
2. If a predict method is implemented, it is called instead of forward during evaluation; if there is no predict method, forward is used for evaluation as well. predict must likewise avoid *args-style parameters and must also return a dict, again with 'pred' as the recommended key. A complete minimal example follows.
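Putting the two rules together, a minimal model obeying this contract might look like the sketch below; the layer sizes, the field name `word_seq`, and the class name are illustrative, not taken from this guide:

```python
import torch.nn as nn

class TinyClassifier(nn.Module):
    """Illustrative model: expects an input field named 'word_seq', returns the key 'pred'."""
    def __init__(self, vocab_size=1000, n_classes=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, 32)
        self.fc = nn.Linear(32, n_classes)

    def forward(self, word_seq):          # parameter names must match the DataSet's input field names
        hidden = self.embed(word_seq).mean(dim=1)
        return {'pred': self.fc(hidden)}  # must return a dict

    def predict(self, word_seq):          # used instead of forward at evaluation time
        logits = self.forward(word_seq)['pred']
        return {'pred': logits.argmax(dim=-1)}
```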
#### Loss: computes the loss from the prediction returned by model.forward() (a dict) and batch_y
1. First, the "key mapping". As seen in the DataSet and Model sections, fastNLP restricts neither the return value of Model.forward() nor the keys of the target fields in the DataSet. So when computing the loss, how does it know where to take the values from?
Take CrossEntropyLoss as an example. Computing cross entropy generally needs two values: the prediction and the target. CrossEntropyLoss can be initialized with two parameters (pred=None, target=None), both of type str. Given (pred='output', target='label'), CrossEntropyLoss uses the key 'output' to look up the prediction in the forward output and in batch_y, and likewise uses 'label' to look up the target in the forward output and in batch_y. Note that pred and target do not have to come one from model.forward and the other from batch_y; both may come from the forward output alone.
2. How to create your own loss
    - Use fastNLP.LossInForward and include a key named loss in the result of model.forward().
    - When the trainer uses a loss (say loss=CrossEntropyLoss()), what actually happens is
      loss_value = loss(prediction, batch_y), i.e. the `loss.__call__()` method is invoked directly. CrossEntropyLoss does not implement `__call__` itself, because `__call__` is implemented in LossBase. Every loss must inherit from fastNLP.core.loss.LossBase; the main methods of LossBase are described in the next code block.
3. Avoid overriding the `__call__()` and `_init_param_map()` methods whenever possible.
```python | |||||
import torch.nn.functional as F  # needed by the L1Loss example below

class LossBase:
    def __init__(self):
        self.param_map = {}    # normally you do not create this yourself; calling _init_param_map() is preferred
        self._checked = False  # this attribute can be ignored

    def _init_param_map(self, key_map=None, **kwargs):
        # This method registers the loss's "key mapping". There are two ways to pass it:
        # 1) pass a dict via key_map; each value is the key used for the lookup in the forward output and batch_y,
        #    e.g. key_map = {'pred': 'output', 'target': 'label'}
        # 2) pass the mapping directly as keyword arguments:
        #    _init_param_map(pred='output', target='label')
        # Why provide this method? Calling it registers param_map automatically and runs some checks, preventing
        # keys that are not actually parameters of get_loss from being passed in. Note that only arguments taking
        # part in the key mapping may be passed here; do not pass other loss parameters. If (pred=None, target=None)
        # is passed, __call__() will look up the keys 'pred' and 'target' in pred_dict and target_dict.
        # Calling this method is not mandatory.
        pass

    def __call__(self, pred_dict, target_dict, check=False):  # ignore the check parameter; it will likely be removed later
        # This method mainly runs checks, e.g. whether pred_dict and target_dict contain the keys needed to compute
        # the loss. If the checks pass, get_loss is called.
        fast_param = self._fast_param_map(pred_dict, target_dict)
        if fast_param:
            return self.get_loss(**fast_param)
        # if there is no fast_param, the parameters are matched via the key mapping and get_loss is called
        loss = ...  # (parameter matching and the get_loss call are omitted here)
        return loss  # the returned loss is a Tensor

    def _fast_param_map(self, pred_dict, target_dict):
        # This is a shortcut for computing the loss: in many cases the "key mapping" is not needed at all. For
        # example, if pred_dict has only one element and target_dict also has only one element, prediction and
        # ground truth can be matched unambiguously; the base class detects this case (and possibly other
        # unambiguous ones). If _fast_param_map succeeds, the key mapping is not used, so the loss can be computed
        # even when the key mapping was missing or wrong.
        # The return value is a dict: on success it looks like {'pred': value, 'target': value}; an empty dict
        # means the match failed and __call__ continues with the normal path.
        pass

    def get_loss(self, *args, **kwargs):
        # This must be implemented; it is where the loss is actually computed.
        # (1) get_loss must not use *args-style parameters.
        # (2) If it takes **kwargs, all entries of pred_dict and target_dict are passed in; this is discouraged.
        raise NotImplementedError

# L1Loss as an example
class L1Loss(LossBase):  # inherits from LossBase
    # Register the values to be mapped; the names 'pred' and 'target' must match the parameter names of get_loss.
    def __init__(self, pred=None, target=None):
        super(L1Loss, self).__init__()
        # Passing pred and target to _init_param_map makes sure they are registered correctly. This step is not
        # mandatory but is recommended. Only key/value pairs used for the "key mapping" may be passed: if
        # __init__(pred=None, target=None, threshold=0.1) had a threshold parameter controlling the loss
        # computation, threshold must not be passed to _init_param_map.
        self._init_param_map(pred=pred, target=target)

    def get_loss(self, pred, target):
        # 'pred' and 'target' here must be consistent with the mapping registered in __init__.
        return F.l1_loss(input=pred, target=target)  # simply return the loss
``` | |||||
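To connect this back to the "key mapping", a loss defined this way is instantiated with the keys under which its inputs can be found; a short illustrative sketch reusing the L1Loss defined above and the hypothetical 'output'/'label' keys from the earlier CrossEntropyLoss example:

```python
# model.forward returns {'output': ...} and the DataSet's target field is named 'label'
loss = L1Loss(pred='output', target='label')
# inside the Trainer this boils down to: loss_value = loss(pred_dict, batch_y)
```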
#### Metric: computes metrics from the result of Model.forward() or Model.predict()
The design of Metric is similar to that of Loss: both receive pred_dict and target_dict and compute something from them. For a Metric, however, pred_dict may come either from the return value of Model.forward or from Model.predict (predict is called whenever the Model has one); below we refer to it uniformly as pred_dict.
1. The "key mapping" here works just like the loss's "key mapping". For example, with Metric(pred='output', target='label'), the key 'output' is used to look up pred in pred_dict and target_dict, and 'label' is used to look up target.
2. How to create your own Metric
Metric computation differs from loss computation in that it happens in two steps:
    - **the output of every batch** is passed to the Metric's ``__call__(pred_dict, target_dict)`` method, and ``__call__`` in turn calls the evaluate() method (which you must implement);
    - after all batches have been passed in, the Metric's get_metric() method is called to obtain the final metric value;
    - so whenever evaluate is called, the Metric updates its own state from the data it receives (pred_dict and batch_y), e.g. accumulating the number of correct predictions and the total number of samples; when get_metric() is called, it produces the final result.
Every Metric must inherit from fastNLP.core.metrics.MetricBase; see the example in the next code block.
3. Avoid overriding the ``__call__()`` and ``_init_param_map()`` methods whenever possible.
```python | |||||
import torch  # needed by the AccuracyMetric example below

class MetricBase:
    def __init__(self):
        self.param_map = {}    # normally you do not create this yourself; calling _init_param_map() is preferred
        self._checked = False  # this attribute can be ignored

    def _init_param_map(self, key_map=None, **kwargs):
        # This method registers the metric's "key mapping". There are two ways to pass it:
        # 1) pass a dict via key_map; each value is the key used for the lookup in the forward output and batch_y,
        #    e.g. key_map = {'pred': 'output', 'target': 'label'}
        # 2) pass the mapping directly as keyword arguments (recommended):
        #    _init_param_map(pred='output', target='label')
        # Why provide this method? Calling it registers param_map automatically and runs some checks, preventing
        # keys that are not actually parameters of evaluate() from being passed in. Note that only arguments taking
        # part in the key mapping may be passed here; do not pass other evaluate parameters. If (pred=None,
        # target=None) is passed, __call__() will look up the keys 'pred' and 'target' in pred_dict and target_dict.
        # Calling this method is not mandatory.
        pass

    def __call__(self, pred_dict, target_dict, check=False):  # ignore the check parameter; it will likely be removed later
        # This method mainly runs checks, e.g. whether pred_dict and target_dict contain the keys required by
        # evaluate. If the checks pass, evaluate is called.
        fast_param = self._fast_param_map(pred_dict, target_dict)
        if fast_param:
            return self.evaluate(**fast_param)
        # if there is no fast_param, the parameters are matched via the key mapping and evaluate is called
        # ...

    def _fast_param_map(self, pred_dict, target_dict):
        # This is a shortcut: in many cases the "key mapping" is not needed at all. For example, if pred_dict has
        # only one element and target_dict also has only one element, prediction and ground truth can be matched
        # unambiguously for the metric; the base class detects this case (and possibly other unambiguous ones).
        # If _fast_param_map succeeds, the key mapping is not used, so the metric can be computed even when the key
        # mapping was missing or wrong.
        # The return value is a dict: on success it looks like {'pred': value, 'target': value}; an empty dict
        # means the match failed and __call__ keeps trying to match.
        pass

    def evaluate(self, *args, **kwargs):
        # This must be implemented; it accumulates the metric state.
        # (1) evaluate() must not use *args-style parameters.
        # (2) If it takes **kwargs, all entries of pred_dict and target_dict are passed in; this is discouraged.
        raise NotImplementedError

    def get_metric(self, reset=True):
        # This must be implemented; it returns the final metric. The return value must be a dict.
        # It is called after all batches have been passed in.
        raise NotImplementedError

# AccuracyMetric as an example
class AccuracyMetric(MetricBase):  # inherits from MetricBase
    # Register the values to be mapped; the names 'pred' and 'target' must match the parameter names of evaluate().
    def __init__(self, pred=None, target=None):
        super(AccuracyMetric, self).__init__()
        # Passing pred and target to _init_param_map makes sure they are registered correctly. This step is not
        # mandatory but is recommended. Only key/value pairs used for the "key mapping" may be passed: if
        # __init__(pred=None, target=None, threshold=0.1) had a threshold parameter controlling the computation,
        # threshold must not be passed to _init_param_map.
        self._init_param_map(pred=pred, target=target)
        self.total = 0  # accumulates the total number of samples
        self.corr = 0   # accumulates the number of correctly predicted samples

    def evaluate(self, pred, target):
        # basic checks and preprocessing of pred and target
        if pred.size() == target.size() and len(pred.size()) == 1:  # pred has already been argmax-ed
            pass
        elif len(pred.size()) == 2 and len(target.size()) == 1:  # pred has not been argmax-ed yet
            pred = pred.argmax(dim=1)
        else:
            raise ValueError("The shape of pred and target should be ((B, n_classes), (B, )) or "
                             "((B,), (B,)).")
        assert pred.size(0) == target.size(0), "Mismatch batch size."
        # accumulate the state
        self.total += pred.size(0)
        self.corr += torch.sum(torch.eq(pred, target).float()).item()

    def get_metric(self, reset=True):
        # reset indicates whether the accumulated state should be cleared afterwards; default True
        # this method must return a dict and may contain several metrics
        metric = {}
        metric['acc'] = self.corr / self.total
        if reset:
            self.total = 0
            self.corr = 0
        return metric
``` | |||||
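Analogously, the metric is instantiated with its key mapping and then driven batch by batch; a short illustrative sketch using the AccuracyMetric defined above:

```python
metric = AccuracyMetric(pred='pred', target='target')
# per batch:   metric(pred_dict, batch_y)   # accumulates self.total / self.corr via evaluate()
# at the end:  metric.get_metric()          # returns a dict, e.g. {'acc': 0.92}
```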
#### Tester: performs evaluation and should not need modification
The important initialization parameters are data, model, and metric; the most important function is test().
What happens inside test():
``` | |||||
predict_func = model.predict if the model has a predict method, otherwise model.forward
for batch_x, batch_y in batch:
    # (1) move the data to the device the model is on
    # (2) following the parameters of predict_func, take the required data from batch_x, pass it to predict_func, and obtain pred_dict
    # (3) call metric(pred_dict, batch_y)
    # (4) once all batches have been processed, call the metric's get_metric method and use its return value as the evaluation result
metric.get_metric()
``` | |||||
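A minimal sketch of running a Tester as described above; the import path and the keyword-argument names (`data=`, `model=`, `metrics=`) are assumptions about the current API:

```python
from fastNLP import Tester, AccuracyMetric  # assumed import path

# model follows the Model contract above; test_data is a DataSet with marked fields
tester = Tester(data=test_data, model=model,
                metrics=AccuracyMetric(pred='pred', target='target'))
eval_results = tester.test()  # dict of metric values from get_metric()
print(eval_results)
```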
#### Trainer: a wrapper around the training process.
Its most important function is train().
What happens inside train():
``` | |||||
(1) create the Batch
    batch = Batch(dataset, batch_size, sampler=sampler)
    for batch_x, batch_y in batch:
        # ...
    batch_x and batch_y are both dicts: batch_x holds the fields marked as input in the DataSet,
    batch_y holds the fields marked as target. The keys of both dicts are the DataSet keys, and the
    values are tensors, padded as appropriate.
(2) move the tensors in batch_x and batch_y to the device the model is on
(3) following the parameter list of model.forward, take the data needed by forward from batch_x
(4) obtain pred_dict, the output of model.forward, and pass it together with batch_y to the loss function to compute the loss
(5) back-propagate the loss and update the parameters
(6) if there is a dev set, run validation:
    tester = Tester(model, dev_data, metric)
    eval_results = tester.test()
(7) if eval_results is the best result so far, save the model
``` | |||||
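A sketch of driving the whole loop with a Trainer, tying the pieces above together; the exact keyword-argument names (`train_data=`, `dev_data=`, `n_epochs=`, `batch_size=`) are assumptions, not taken from this guide:

```python
from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric  # assumed import path

# train_ds, dev_ds are DataSets with marked fields; model follows the Model contract
trainer = Trainer(train_data=train_ds, model=model,
                  loss=CrossEntropyLoss(pred='pred', target='target'),
                  metrics=AccuracyMetric(pred='pred', target='target'),
                  dev_data=dev_ds, batch_size=32, n_epochs=5)
trainer.train()  # runs steps (1)-(7) above every epoch
```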
#### Miscellaneous
The Trainer also provides a "dry run" feature. It is controlled by check_code_level; if check_code_level is -1, no dry run is performed.
check_code_level=0, 1, 2 correspond to different warning levels.
Currently the levels control how fields that are marked as input or target in the DataSet but are never used are reported:
0 ignores them (default); 1 issues a warning when an unused field occurs; 2 raises an error and aborts the run when an unused field appears.
The dry run serves two main purposes:
- to prevent errors from showing up only at evaluation time after training has finished, which would waste the whole training run;
- because of the "key mapping", errors during a normal run can be hard to debug, whereas errors raised during the dry run come with some debugging hints.
The dry run performs the following operations:
- run only two iterations with a very small batch_size, checking whether batch_x contains the parameters required by Model.forward;
- feed the pred_dict output by Model.forward together with batch_y into the loss and attempt a backward pass; no parameters are updated and the gradients are cleared afterwards;
if dev_data was passed in, the metric is tested as well:
- create a Tester, pass it a small amount of data, and check that it runs normally.
The dry run is executed when the Trainer is initialized.
Under normal circumstances the dry-run code should not need to be changed, but if you find a bug or have a good suggestion, feel free to raise it in the developer group or open an issue on GitHub.
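For reference, continuing the Trainer sketch above, the dry run is controlled at construction time; a hedged sketch assuming `check_code_level` is accepted as a Trainer constructor keyword argument:

```python
# check_code_level=-1 skips the dry run entirely; 0/1/2 select the warning levels described above
trainer = Trainer(train_data=train_ds, model=model,
                  loss=CrossEntropyLoss(pred='pred', target='target'),
                  check_code_level=-1)
trainer.train()
```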