* pass all tests * prepare CWS & POS API * update tutorials * add README.md in tutorials/ & api/tags/v0.3.0^2
@@ -0,0 +1,33 @@ | |||
# fastNLP 高级接口 | |||
### 环境与配置 | |||
1. 系统环境:linux/ubuntu(推荐) | |||
2. 编程语言:Python>=3.6 | |||
3. Python包依赖 | |||
- **torch==1.0** | |||
- numpy>=1.14.2 | |||
### 中文分词 | |||
```python | |||
text = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||
'那么这款无人机到底有多厉害?'] | |||
from fastNLP.api import CWS | |||
cws = CWS(device='cpu') | |||
print(cws.predict(text)) | |||
``` | |||
### 中文分词+词性标注 | |||
```python | |||
text = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||
'那么这款无人机到底有多厉害?'] | |||
from fastNLP.api import POS | |||
pos = POS(device='cpu') | |||
print(pos.predict(text)) | |||
``` | |||
### 中文分词+词性标注+句法分析 | |||
敬请期待 | |||
完整样例见`examples.py` |
@@ -0,0 +1 @@ | |||
from .api import CWS, POS, Parser |
@@ -7,7 +7,7 @@ import os | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.api.model_zoo import load_url | |||
from fastNLP.api.utils import load_url | |||
from fastNLP.api.processor import ModelProcessor | |||
from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader | |||
from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader | |||
@@ -17,12 +17,14 @@ from fastNLP.api.pipeline import Pipeline | |||
from fastNLP.core.metrics import SpanFPreRecMetric | |||
from fastNLP.api.processor import IndexerProcessor | |||
# TODO add pretrain urls | |||
model_urls = { | |||
'cws': "http://123.206.98.91:8888/download/cws_crf_1_11-457fc899.pkl" | |||
"cws": "http://123.206.98.91:8888/download/cws_crf_1_11-457fc899.pkl", | |||
"pos": "http://123.206.98.91:8888/download/pos_tag_model_20190108-f3c60ee5.pkl", | |||
"parser": "http://123.206.98.91:8888/download/biaffine_parser-3a2f052c.pkl" | |||
} | |||
class API: | |||
def __init__(self): | |||
self.pipeline = None | |||
@@ -50,6 +52,7 @@ class POS(API): | |||
:param str device: device name such as "cpu" or "cuda:0". Use the same notation as PyTorch. | |||
""" | |||
def __init__(self, model_path=None, device='cpu'): | |||
super(POS, self).__init__() | |||
if model_path is None: | |||
@@ -246,8 +249,8 @@ class Parser(API): | |||
# 2. 组建dataset | |||
dataset = DataSet() | |||
dataset.add_field('wp', pos_out) | |||
dataset.apply(lambda x: ['<BOS>']+[w.split('/')[0] for w in x['wp']], new_field_name='words') | |||
dataset.apply(lambda x: ['<BOS>']+[w.split('/')[1] for w in x['wp']], new_field_name='pos') | |||
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[0] for w in x['wp']], new_field_name='words') | |||
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[1] for w in x['wp']], new_field_name='pos') | |||
# 3. 使用pipeline | |||
self.pipeline(dataset) | |||
@@ -328,35 +331,3 @@ class Analyzer: | |||
output_dict['parser'] = parser_output | |||
return output_dict | |||
if __name__ == "__main__": | |||
# pos_model_path = '/home/zyfeng/fastnlp/reproduction/pos_tag_model/model_pp.pkl' | |||
# pos = POS(pos_model_path, device='cpu') | |||
# s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||
# '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||
# '那么这款无人机到底有多厉害?'] | |||
# print(pos.test("/home/zyfeng/data/sample.conllx")) | |||
# print(pos.predict(s)) | |||
# cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf_1_11.pkl' | |||
cws = CWS(device='cpu') | |||
s = ['本品是一个抗酸抗胆汁的胃黏膜保护剂' , | |||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||
parser_path = '/home/yfshao/workdir/fastnlp/reproduction/Biaffine_parser/pipe.pkl' | |||
parser = Parser(parser_path, device='cpu') | |||
# print(parser.test('/Users/yh/Desktop/test_data/parser_test2.conll')) | |||
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||
'那么这款无人机到底有多厉害?'] | |||
print(cws.test('/home/hyan/ctb3/test.conllx')) | |||
print(cws.predict(s)) | |||
print(cws.predict('本品是一个抗酸抗胆汁的胃黏膜保护剂')) | |||
# parser = Parser(device='cpu') | |||
# print(parser.test('/Users/yh/Desktop/test_data/parser_test2.conll')) | |||
# s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||
# '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||
# '那么这款无人机到底有多厉害?'] | |||
# print(parser.predict(s)) | |||
print(parser.predict(s)) |
@@ -0,0 +1,24 @@ | |||
""" | |||
api/example.py contains all API examples provided by fastNLP. | |||
It is used as a tutorial for API or a test script since it is difficult to test APIs in travis. | |||
""" | |||
from fastNLP.api import CWS, POS | |||
text = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||
'那么这款无人机到底有多厉害?'] | |||
def chinese_word_segmentation(): | |||
cws = CWS(device='cpu') | |||
print(cws.predict(text)) | |||
def pos_tagging(): | |||
pos = POS(device='cpu') | |||
print(pos.predict(text)) | |||
if __name__ == "__main__": | |||
pos_tagging() |
@@ -3,12 +3,11 @@ from . import decoder | |||
from . import encoder | |||
from .aggregator import * | |||
from .decoder import * | |||
from .encoder import * | |||
from .dropout import TimestepDropout | |||
from .encoder import * | |||
__version__ = '0.0.0' | |||
__all__ = ['encoder', | |||
'decoder', | |||
'aggregator', | |||
'TimestepDropout'] | |||
'aggregator'] |
@@ -1,4 +1,2 @@ | |||
from .CRF import ConditionalRandomField | |||
from .MLP import MLP | |||
__all__ = ["ConditionalRandomField", "MLP"] |
@@ -0,0 +1,12 @@ | |||
# fastNLP 教程 | |||
### 上手教程 Quick Start | |||
- 一分钟上手:`fastnlp_1min_tutorial.ipynb`  | |||
- 十分钟上手:`fastnlp_10min_tutorial.ipynb`  | |||
### 进阶教程 Advanced Tutorial | |||
- `fastnlp_advanced_tutorial/advance_tutorial.ipynb`  | |||
### 开发者指南 Developer Guide | |||
- `tutorial_for_developer.md`  |
@@ -4,12 +4,29 @@ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"fastNLP上手教程\n", | |||
"fastNLP10 分钟上手教程\n", | |||
"-------\n", | |||
"\n", | |||
"fastNLP提供方便的数据预处理,训练和测试模型的功能" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"如果您还没有通过pip安装fastNLP,可以执行下面的操作加载当前模块" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 4, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"import sys\n", | |||
"sys.path.append(\"../\")" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
@@ -24,21 +41,14 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 9, | |||
"execution_count": 6, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"8529" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
"77\n" | |||
] | |||
} | |||
], | |||
@@ -47,27 +57,23 @@ | |||
"from fastNLP import Instance\n", | |||
"\n", | |||
"# 从csv读取数据到DataSet\n", | |||
"dataset = DataSet.read_csv('../sentence.csv', headers=('raw_sentence', 'label'), sep='\\t')\n", | |||
"dataset = DataSet.read_csv('sample_data/tutorial_sample_dataset.csv', headers=('raw_sentence', 'label'), sep='\\t')\n", | |||
"print(len(dataset))" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 10, | |||
"execution_count": 7, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"{'raw_sentence': A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1}" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
"{'raw_sentence': A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story . type=str,\n", | |||
"'label': 1 type=str}\n", | |||
"{'raw_sentence': The plot is romantic comedy boilerplate from start to finish . type=str,\n", | |||
"'label': 2 type=str}\n" | |||
] | |||
} | |||
], | |||
@@ -91,16 +97,17 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 11, | |||
"execution_count": 8, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"text/plain": [ | |||
"{'raw_sentence': fake data,\n'label': 0}" | |||
"{'raw_sentence': fake data type=str,\n", | |||
"'label': 0 type=str}" | |||
] | |||
}, | |||
"execution_count": 11, | |||
"execution_count": 8, | |||
"metadata": {}, | |||
"output_type": "execute_result" | |||
} | |||
@@ -121,21 +128,15 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 12, | |||
"execution_count": 9, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1}" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story . type=str,\n", | |||
"'label': 1 type=str}\n" | |||
] | |||
} | |||
], | |||
@@ -147,21 +148,15 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 13, | |||
"execution_count": 10, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1}" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story . type=str,\n", | |||
"'label': 1 type=int}\n" | |||
] | |||
} | |||
], | |||
@@ -173,21 +168,16 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 14, | |||
"execution_count": 11, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1,\n'words': ['a', 'series', 'of', 'escapades', 'demonstrating', 'the', 'adage', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gander', ',', 'some', 'of', 'which', 'occasionally', 'amuses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.']}" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story . type=str,\n", | |||
"'label': 1 type=int,\n", | |||
"'words': ['a', 'series', 'of', 'escapades', 'demonstrating', 'the', 'adage', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gander', ',', 'some', 'of', 'which', 'occasionally', 'amuses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.'] type=list}\n" | |||
] | |||
} | |||
], | |||
@@ -201,21 +191,17 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 15, | |||
"execution_count": 12, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1,\n'words': ['a', 'series', 'of', 'escapades', 'demonstrating', 'the', 'adage', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gander', ',', 'some', 'of', 'which', 'occasionally', 'amuses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.'],\n'seq_len': 37}" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story . type=str,\n", | |||
"'label': 1 type=int,\n", | |||
"'words': ['a', 'series', 'of', 'escapades', 'demonstrating', 'the', 'adage', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gander', ',', 'some', 'of', 'which', 'occasionally', 'amuses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.'] type=list,\n", | |||
"'seq_len': 37 type=int}\n" | |||
] | |||
} | |||
], | |||
@@ -235,25 +221,19 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 16, | |||
"execution_count": 13, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"8358" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
"77\n" | |||
] | |||
} | |||
], | |||
"source": [ | |||
"# 删除低于某个长度的词语\n", | |||
"dataset.drop(lambda x: x['seq_len'] <= 3)\n", | |||
"print(len(dataset))" | |||
] | |||
@@ -269,7 +249,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 17, | |||
"execution_count": 14, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -283,35 +263,15 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 18, | |||
"execution_count": 15, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"5851" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"2507" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
"54\n", | |||
"23\n" | |||
] | |||
} | |||
], | |||
@@ -335,21 +295,17 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 19, | |||
"execution_count": 16, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"{'raw_sentence': the project 's filmmakers forgot to include anything even halfway scary as they poorly rejigger fatal attraction into a high school setting .,\n'label': 0,\n'words': [4, 423, 9, 316, 1, 8, 1, 312, 72, 1478, 885, 14, 86, 725, 1, 1913, 1431, 53, 5, 455, 736, 1, 2],\n'seq_len': 23}" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
"{'raw_sentence': a welcome relief from baseball movies that try too hard to be mythic , this one is a sweet and modest and ultimately winning story . type=str,\n", | |||
"'label': 3 type=int,\n", | |||
"'words': [4, 1, 1, 18, 1, 1, 13, 1, 1, 1, 8, 26, 1, 5, 35, 1, 11, 4, 1, 10, 1, 10, 1, 1, 1, 2] type=list,\n", | |||
"'seq_len': 26 type=int}\n" | |||
] | |||
} | |||
], | |||
@@ -369,6 +325,23 @@ | |||
"print(test_data[0])" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"# 如果你们需要做强化学习或者GAN之类的项目,你们也可以使用这些数据预处理的工具\n", | |||
"from fastNLP.core.batch import Batch\n", | |||
"from fastNLP.core.sampler import RandomSampler\n", | |||
"\n", | |||
"batch_iterator = Batch(dataset=train_data, batch_size=2, sampler=RandomSampler())\n", | |||
"for batch_x, batch_y in batch_iterator:\n", | |||
" print(\"batch_x has: \", batch_x)\n", | |||
" print(\"batch_y has: \", batch_y)\n", | |||
" break" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
@@ -379,16 +352,32 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 20, | |||
"execution_count": 17, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"text/plain": [ | |||
"CNNText(\n (embed): Embedding(\n (embed): Embedding(3459, 50, padding_idx=0)\n (dropout): Dropout(p=0.0)\n )\n (conv_pool): ConvMaxpool(\n (convs): ModuleList(\n (0): Conv1d(50, 3, kernel_size=(3,), stride=(1,), padding=(2,))\n (1): Conv1d(50, 4, kernel_size=(4,), stride=(1,), padding=(2,))\n (2): Conv1d(50, 5, kernel_size=(5,), stride=(1,), padding=(2,))\n )\n )\n (dropout): Dropout(p=0.1)\n (fc): Linear(\n (linear): Linear(in_features=12, out_features=5, bias=True)\n )\n)" | |||
"CNNText(\n", | |||
" (embed): Embedding(\n", | |||
" (embed): Embedding(59, 50, padding_idx=0)\n", | |||
" (dropout): Dropout(p=0.0)\n", | |||
" )\n", | |||
" (conv_pool): ConvMaxpool(\n", | |||
" (convs): ModuleList(\n", | |||
" (0): Conv1d(50, 3, kernel_size=(3,), stride=(1,), padding=(2,))\n", | |||
" (1): Conv1d(50, 4, kernel_size=(4,), stride=(1,), padding=(2,))\n", | |||
" (2): Conv1d(50, 5, kernel_size=(5,), stride=(1,), padding=(2,))\n", | |||
" )\n", | |||
" )\n", | |||
" (dropout): Dropout(p=0.1)\n", | |||
" (fc): Linear(\n", | |||
" (linear): Linear(in_features=12, out_features=5, bias=True)\n", | |||
" )\n", | |||
")" | |||
] | |||
}, | |||
"execution_count": 20, | |||
"execution_count": 17, | |||
"metadata": {}, | |||
"output_type": "execute_result" | |||
} | |||
@@ -459,7 +448,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 21, | |||
"execution_count": 18, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -496,7 +485,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 22, | |||
"execution_count": 19, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -519,7 +508,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 23, | |||
"execution_count": 20, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -528,149 +517,61 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 24, | |||
"execution_count": 21, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"training epochs started 2018-12-07 14:11:31" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
"input fields after batch(if batch size is 2):\n", | |||
"\tword_seq: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) \n", | |||
"target fields after batch(if batch size is 2):\n", | |||
"\tlabel_seq: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", | |||
"\n", | |||
"training epochs started 2019-01-12 17-07-51\n" | |||
] | |||
}, | |||
{ | |||
"data": { | |||
"text/plain": [ | |||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=915), HTML(value='')), layout=Layout(display=…" | |||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=10), HTML(value='')), layout=Layout(display='…" | |||
] | |||
}, | |||
"execution_count": 0, | |||
"metadata": {}, | |||
"output_type": "execute_result" | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Epoch 1/5. Step:183/915. AccuracyMetric: acc=0.350367" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
"Evaluation at Epoch 1/5. Step:2/10. AccuracyMetric: acc=0.425926\n", | |||
"Evaluation at Epoch 2/5. Step:4/10. AccuracyMetric: acc=0.425926\n", | |||
"Evaluation at Epoch 3/5. Step:6/10. AccuracyMetric: acc=0.611111\n", | |||
"Evaluation at Epoch 4/5. Step:8/10. AccuracyMetric: acc=0.648148\n", | |||
"Evaluation at Epoch 5/5. Step:10/10. AccuracyMetric: acc=0.703704\n", | |||
"\n", | |||
"In Epoch:5/Step:10, got best dev performance:AccuracyMetric: acc=0.703704\n", | |||
"Reloaded the best model.\n" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Epoch 2/5. Step:366/915. AccuracyMetric: acc=0.409332" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Epoch 3/5. Step:549/915. AccuracyMetric: acc=0.572552" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Epoch 4/5. Step:732/915. AccuracyMetric: acc=0.711331" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Epoch 5/5. Step:915/915. AccuracyMetric: acc=0.801572" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
] | |||
"data": { | |||
"text/plain": [ | |||
"{'best_eval': {'AccuracyMetric': {'acc': 0.703704}},\n", | |||
" 'best_epoch': 5,\n", | |||
" 'best_step': 10,\n", | |||
" 'seconds': 0.62}" | |||
] | |||
}, | |||
"execution_count": 21, | |||
"metadata": {}, | |||
"output_type": "execute_result" | |||
} | |||
], | |||
"source": [ | |||
"# 实例化Trainer,传入模型和数据,进行训练\n", | |||
"# 先在test_data拟合\n", | |||
"# 先在test_data拟合(确保模型的实现是正确的)\n", | |||
"copy_model = deepcopy(model)\n", | |||
"overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data,\n", | |||
" loss=loss,\n", | |||
@@ -683,143 +584,43 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 25, | |||
"execution_count": 22, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"training epochs started 2018-12-07 14:12:21" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
"input fields after batch(if batch size is 2):\n", | |||
"\tword_seq: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 20]) \n", | |||
"target fields after batch(if batch size is 2):\n", | |||
"\tlabel_seq: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", | |||
"\n", | |||
"training epochs started 2019-01-12 17-09-05\n" | |||
] | |||
}, | |||
{ | |||
"data": { | |||
"text/plain": [ | |||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=395), HTML(value='')), layout=Layout(display=…" | |||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=5), HTML(value='')), layout=Layout(display='i…" | |||
] | |||
}, | |||
"execution_count": 0, | |||
"metadata": {}, | |||
"output_type": "execute_result" | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Epoch 1/5. Step:79/395. AccuracyMetric: acc=0.250043" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Epoch 2/5. Step:158/395. AccuracyMetric: acc=0.280807" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Epoch 3/5. Step:237/395. AccuracyMetric: acc=0.280978" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Epoch 4/5. Step:316/395. AccuracyMetric: acc=0.285592" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Epoch 5/5. Step:395/395. AccuracyMetric: acc=0.278927" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
"Evaluation at Epoch 1/5. Step:1/5. AccuracyMetric: acc=0.37037\n", | |||
"Evaluation at Epoch 2/5. Step:2/5. AccuracyMetric: acc=0.37037\n", | |||
"Evaluation at Epoch 3/5. Step:3/5. AccuracyMetric: acc=0.462963\n", | |||
"Evaluation at Epoch 4/5. Step:4/5. AccuracyMetric: acc=0.425926\n", | |||
"Evaluation at Epoch 5/5. Step:5/5. AccuracyMetric: acc=0.481481\n", | |||
"\n", | |||
"In Epoch:5/Step:5, got best dev performance:AccuracyMetric: acc=0.481481\n", | |||
"Reloaded the best model.\n", | |||
"Train finished!\n" | |||
] | |||
} | |||
], | |||
@@ -837,35 +638,16 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 26, | |||
"execution_count": 23, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"[tester] \nAccuracyMetric: acc=0.280636" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"{'AccuracyMetric': {'acc': 0.280636}}" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
"[tester] \n", | |||
"AccuracyMetric: acc=0.481481\n", | |||
"{'AccuracyMetric': {'acc': 0.481481}}\n" | |||
] | |||
} | |||
], | |||
@@ -879,6 +661,75 @@ | |||
"print(acc)" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"# In summary\n", | |||
"\n", | |||
"## fastNLP Trainer的伪代码逻辑\n", | |||
"### 1. 准备DataSet,假设DataSet中共有如下的fields\n", | |||
" ['raw_sentence', 'word_seq1', 'word_seq2', 'raw_label','label']\n", | |||
" 通过\n", | |||
" DataSet.set_input('word_seq1', word_seq2', flag=True)将'word_seq1', 'word_seq2'设置为input\n", | |||
" 通过\n", | |||
" DataSet.set_target('label', flag=True)将'label'设置为target\n", | |||
"### 2. 初始化模型\n", | |||
" class Model(nn.Module):\n", | |||
" def __init__(self):\n", | |||
" xxx\n", | |||
" def forward(self, word_seq1, word_seq2):\n", | |||
" # (1) 这里使用的形参名必须和DataSet中的input field的名称对应。因为我们是通过形参名, 进行赋值的\n", | |||
" # (2) input field的数量可以多于这里的形参数量。但是不能少于。\n", | |||
" xxxx\n", | |||
" # 输出必须是一个dict\n", | |||
"### 3. Trainer的训练过程\n", | |||
" (1) 从DataSet中按照batch_size取出一个batch,调用Model.forward\n", | |||
" (2) 将 Model.forward的结果 与 标记为target的field 传入Losser当中。\n", | |||
" 由于每个人写的Model.forward的output的dict可能key并不一样,比如有人是{'pred':xxx}, {'output': xxx}; \n", | |||
" 另外每个人将target可能也会设置为不同的名称, 比如有人是label, 有人设置为target;\n", | |||
" 为了解决以上的问题,我们的loss提供映射机制\n", | |||
" 比如CrossEntropyLosser的需要的输入是(prediction, target)。但是forward的output是{'output': xxx}; 'label'是target\n", | |||
" 那么初始化losser的时候写为CrossEntropyLosser(prediction='output', target='label')即可\n", | |||
" (3) 对于Metric是同理的\n", | |||
" Metric计算也是从 forward的结果中取值 与 设置target的field中取值。 也是可以通过映射找到对应的值 \n", | |||
" \n", | |||
" \n", | |||
"\n", | |||
"## 一些问题.\n", | |||
"### 1. DataSet中为什么需要设置input和target\n", | |||
" 只有被设置为input或者target的数据才会在train的过程中被取出来\n", | |||
" (1.1) 我们只会在设置为input的field中寻找传递给Model.forward的参数。\n", | |||
" (1.2) 我们在传递值给losser或者metric的时候会使用来自: \n", | |||
" (a)Model.forward的output\n", | |||
" (b)被设置为target的field\n", | |||
" \n", | |||
"\n", | |||
"### 2. 我们是通过forwad中的形参名将DataSet中的field赋值给对应的参数\n", | |||
" (1.1) 构建模型过程中,\n", | |||
" 例如:\n", | |||
" DataSet中x,seq_lens是input,那么forward就应该是\n", | |||
" def forward(self, x, seq_lens):\n", | |||
" pass\n", | |||
" 我们是通过形参名称进行匹配的field的\n", | |||
" \n", | |||
"\n", | |||
"\n", | |||
"### 1. 加载数据到DataSet\n", | |||
"### 2. 使用apply操作对DataSet进行预处理\n", | |||
" (2.1) 处理过程中将某些field设置为input,某些field设置为target\n", | |||
"### 3. 构建模型\n", | |||
" (3.1) 构建模型过程中,需要注意forward函数的形参名需要和DataSet中设置为input的field名称是一致的。\n", | |||
" 例如:\n", | |||
" DataSet中x,seq_lens是input,那么forward就应该是\n", | |||
" def forward(self, x, seq_lens):\n", | |||
" pass\n", | |||
" 我们是通过形参名称进行匹配的field的\n", | |||
" (3.2) 模型的forward的output需要是dict类型的。\n", | |||
" 建议将输出设置为{\"pred\": xx}.\n", | |||
" \n" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, |
@@ -1,860 +0,0 @@ | |||
{ | |||
"cells": [ | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"fastNLP上手教程\n", | |||
"-------\n", | |||
"\n", | |||
"fastNLP提供方便的数据预处理,训练和测试模型的功能" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"DataSet & Instance\n", | |||
"------\n", | |||
"\n", | |||
"fastNLP用DataSet和Instance保存和处理数据。每个DataSet表示一个数据集,每个Instance表示一个数据样本。一个DataSet存有多个Instance,每个Instance可以自定义存哪些内容。\n", | |||
"\n", | |||
"有一些read_*方法,可以轻松从文件读取数据,存成DataSet。" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"{'raw_sentence': A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1}" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
} | |||
], | |||
"source": [ | |||
"from fastNLP import DataSet\n", | |||
"from fastNLP import Instance\n", | |||
"\n", | |||
"# 从csv读取数据到DataSet\n", | |||
"win_path = \"C:\\\\Users\\zyfeng\\Desktop\\FudanNLP\\\\fastNLP\\\\test\\\\data_for_tests\\\\tutorial_sample_dataset.csv\"\n", | |||
"dataset = DataSet.read_csv(win_path, headers=('raw_sentence', 'label'), sep='\\t')\n", | |||
"print(dataset[0])" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 2, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"text/plain": [ | |||
"{'raw_sentence': fake data,\n'label': 0}" | |||
] | |||
}, | |||
"execution_count": 2, | |||
"metadata": {}, | |||
"output_type": "execute_result" | |||
} | |||
], | |||
"source": [ | |||
"# DataSet.append(Instance)加入新数据\n", | |||
"\n", | |||
"dataset.append(Instance(raw_sentence='fake data', label='0'))\n", | |||
"dataset[-1]" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 3, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"# DataSet.apply(func, new_field_name)对数据预处理\n", | |||
"\n", | |||
"# 将所有数字转为小写\n", | |||
"dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')\n", | |||
"# label转int\n", | |||
"dataset.apply(lambda x: int(x['label']), new_field_name='label_seq', is_target=True)\n", | |||
"# 使用空格分割句子\n", | |||
"dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0)\n", | |||
"def split_sent(ins):\n", | |||
" return ins['raw_sentence'].split()\n", | |||
"dataset.apply(split_sent, new_field_name='words', is_input=True)" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 4, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"# DataSet.drop(func)筛除数据\n", | |||
"# 删除低于某个长度的词语\n", | |||
"dataset.drop(lambda x: len(x['words']) <= 3)" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 7, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Train size: " | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
" " | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"54" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Test size: " | |||
] | |||
} | |||
], | |||
"source": [ | |||
"# 分出测试集、训练集\n", | |||
"\n", | |||
"test_data, train_data = dataset.split(0.3)\n", | |||
"print(\"Train size: \", len(test_data))\n", | |||
"print(\"Test size: \", len(train_data))" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"Vocabulary\n", | |||
"------\n", | |||
"\n", | |||
"fastNLP中的Vocabulary轻松构建词表,将词转成数字" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 8, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"{'raw_sentence': the plot is romantic comedy boilerplate from start to finish .,\n'label': 2,\n'label_seq': 2,\n'words': ['the', 'plot', 'is', 'romantic', 'comedy', 'boilerplate', 'from', 'start', 'to', 'finish', '.'],\n'word_seq': [2, 13, 9, 24, 25, 26, 15, 27, 11, 28, 3]}" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
} | |||
], | |||
"source": [ | |||
"from fastNLP import Vocabulary\n", | |||
"\n", | |||
"# 构建词表, Vocabulary.add(word)\n", | |||
"vocab = Vocabulary(min_freq=2)\n", | |||
"train_data.apply(lambda x: [vocab.add(word) for word in x['words']])\n", | |||
"vocab.build_vocab()\n", | |||
"\n", | |||
"# index句子, Vocabulary.to_index(word)\n", | |||
"train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)\n", | |||
"test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)\n", | |||
"\n", | |||
"\n", | |||
"print(test_data[0])" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 8, | |||
"metadata": { | |||
"scrolled": true | |||
}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"batch_x has: {'words': array([list(['this', 'kind', 'of', 'hands-on', 'storytelling', 'is', 'ultimately', 'what', 'makes', 'shanghai', 'ghetto', 'move', 'beyond', 'a', 'good', ',', 'dry', ',', 'reliable', 'textbook', 'and', 'what', 'allows', 'it', 'to', 'rank', 'with', 'its', 'worthy', 'predecessors', '.']),\n", | |||
" list(['the', 'entire', 'movie', 'is', 'filled', 'with', 'deja', 'vu', 'moments', '.'])],\n", | |||
" dtype=object), 'word_seq': tensor([[ 19, 184, 6, 1, 481, 9, 206, 50, 91, 1210, 1609, 1330,\n", | |||
" 495, 5, 63, 4, 1269, 4, 1, 1184, 7, 50, 1050, 10,\n", | |||
" 8, 1611, 16, 21, 1039, 1, 2],\n", | |||
" [ 3, 711, 22, 9, 1282, 16, 2482, 2483, 200, 2, 0, 0,\n", | |||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||
" 0, 0, 0, 0, 0, 0, 0]])}\n", | |||
"batch_y has: {'label_seq': tensor([3, 2])}\n" | |||
] | |||
} | |||
], | |||
"source": [ | |||
"# 假设你们需要做强化学习或者gan之类的项目,也许你们可以使用这里的dataset\n", | |||
"from fastNLP.core.batch import Batch\n", | |||
"from fastNLP.core.sampler import RandomSampler\n", | |||
"\n", | |||
"batch_iterator = Batch(dataset=train_data, batch_size=2, sampler=RandomSampler())\n", | |||
"for batch_x, batch_y in batch_iterator:\n", | |||
" print(\"batch_x has: \", batch_x)\n", | |||
" print(\"batch_y has: \", batch_y)\n", | |||
" break" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"# Model\n" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 9, | |||
"metadata": { | |||
"collapsed": false | |||
}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"text/plain": [ | |||
"CNNText(\n (embed): Embedding(\n (embed): Embedding(77, 50, padding_idx=0)\n (dropout): Dropout(p=0.0)\n )\n (conv_pool): ConvMaxpool(\n (convs): ModuleList(\n (0): Conv1d(50, 3, kernel_size=(3,), stride=(1,), padding=(2,))\n (1): Conv1d(50, 4, kernel_size=(4,), stride=(1,), padding=(2,))\n (2): Conv1d(50, 5, kernel_size=(5,), stride=(1,), padding=(2,))\n )\n )\n (dropout): Dropout(p=0.1)\n (fc): Linear(\n (linear): Linear(in_features=12, out_features=5, bias=True)\n )\n)" | |||
] | |||
}, | |||
"execution_count": 9, | |||
"metadata": {}, | |||
"output_type": "execute_result" | |||
} | |||
], | |||
"source": [ | |||
"# 定义一个简单的Pytorch模型\n", | |||
"\n", | |||
"from fastNLP.models import CNNText\n", | |||
"model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)\n", | |||
"model" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"Trainer & Tester\n", | |||
"------\n", | |||
"\n", | |||
"使用fastNLP的Trainer训练模型" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 11, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"from fastNLP import Trainer\n", | |||
"from copy import deepcopy\n", | |||
"from fastNLP import CrossEntropyLoss\n", | |||
"from fastNLP import AccuracyMetric" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 12, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"training epochs started 2018-12-07 14:07:20" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"data": { | |||
"text/plain": [ | |||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=20), HTML(value='')), layout=Layout(display='…" | |||
] | |||
}, | |||
"execution_count": 0, | |||
"metadata": {}, | |||
"output_type": "execute_result" | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Epoch 1/10. Step:2/20. AccuracyMetric: acc=0.037037" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Epoch 2/10. Step:4/20. AccuracyMetric: acc=0.296296" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Epoch 3/10. Step:6/20. AccuracyMetric: acc=0.333333" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Epoch 4/10. Step:8/20. AccuracyMetric: acc=0.555556" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Epoch 5/10. Step:10/20. AccuracyMetric: acc=0.611111" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Epoch 6/10. Step:12/20. AccuracyMetric: acc=0.481481" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Epoch 7/10. Step:14/20. AccuracyMetric: acc=0.62963" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Epoch 8/10. Step:16/20. AccuracyMetric: acc=0.685185" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Epoch 9/10. Step:18/20. AccuracyMetric: acc=0.722222" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Epoch 10/10. Step:20/20. AccuracyMetric: acc=0.777778" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
] | |||
} | |||
], | |||
"source": [ | |||
"# 进行overfitting测试\n", | |||
"copy_model = deepcopy(model)\n", | |||
"overfit_trainer = Trainer(model=copy_model, \n", | |||
" train_data=test_data, \n", | |||
" dev_data=test_data,\n", | |||
" loss=CrossEntropyLoss(pred=\"output\", target=\"label_seq\"),\n", | |||
" metrics=AccuracyMetric(),\n", | |||
" n_epochs=10,\n", | |||
" save_path=None)\n", | |||
"overfit_trainer.train()" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 14, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"training epochs started 2018-12-07 14:08:10" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"data": { | |||
"text/plain": [ | |||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=5), HTML(value='')), layout=Layout(display='i…" | |||
] | |||
}, | |||
"execution_count": 0, | |||
"metadata": {}, | |||
"output_type": "execute_result" | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Epoch 1/5. Step:1/5. AccuracyMetric: acc=0.037037" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Epoch 2/5. Step:2/5. AccuracyMetric: acc=0.037037" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Epoch 3/5. Step:3/5. AccuracyMetric: acc=0.037037" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Epoch 4/5. Step:4/5. AccuracyMetric: acc=0.185185" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Epoch 5/5. Step:5/5. AccuracyMetric: acc=0.240741" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\r" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Train finished!" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
} | |||
], | |||
"source": [ | |||
"# 实例化Trainer,传入模型和数据,进行训练\n", | |||
"trainer = Trainer(model=model, \n", | |||
" train_data=train_data, \n", | |||
" dev_data=test_data,\n", | |||
" loss=CrossEntropyLoss(pred=\"output\", target=\"label_seq\"),\n", | |||
" metrics=AccuracyMetric(),\n", | |||
" n_epochs=5)\n", | |||
"trainer.train()\n", | |||
"print('Train finished!')" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 15, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"[tester] \nAccuracyMetric: acc=0.240741" | |||
] | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"\n" | |||
] | |||
} | |||
], | |||
"source": [ | |||
"from fastNLP import Tester\n", | |||
"\n", | |||
"tester = Tester(data=test_data, model=model, metrics=AccuracyMetric())\n", | |||
"acc = tester.test()" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"# In summary\n", | |||
"\n", | |||
"## fastNLP Trainer的伪代码逻辑\n", | |||
"### 1. 准备DataSet,假设DataSet中共有如下的fields\n", | |||
" ['raw_sentence', 'word_seq1', 'word_seq2', 'raw_label','label']\n", | |||
" 通过\n", | |||
" DataSet.set_input('word_seq1', word_seq2', flag=True)将'word_seq1', 'word_seq2'设置为input\n", | |||
" 通过\n", | |||
" DataSet.set_target('label', flag=True)将'label'设置为target\n", | |||
"### 2. 初始化模型\n", | |||
" class Model(nn.Module):\n", | |||
" def __init__(self):\n", | |||
" xxx\n", | |||
" def forward(self, word_seq1, word_seq2):\n", | |||
" # (1) 这里使用的形参名必须和DataSet中的input field的名称对应。因为我们是通过形参名, 进行赋值的\n", | |||
" # (2) input field的数量可以多于这里的形参数量。但是不能少于。\n", | |||
" xxxx\n", | |||
" # 输出必须是一个dict\n", | |||
"### 3. Trainer的训练过程\n", | |||
" (1) 从DataSet中按照batch_size取出一个batch,调用Model.forward\n", | |||
" (2) 将 Model.forward的结果 与 标记为target的field 传入Losser当中。\n", | |||
" 由于每个人写的Model.forward的output的dict可能key并不一样,比如有人是{'pred':xxx}, {'output': xxx}; \n", | |||
" 另外每个人将target可能也会设置为不同的名称, 比如有人是label, 有人设置为target;\n", | |||
" 为了解决以上的问题,我们的loss提供映射机制\n", | |||
" 比如CrossEntropyLosser的需要的输入是(prediction, target)。但是forward的output是{'output': xxx}; 'label'是target\n", | |||
" 那么初始化losser的时候写为CrossEntropyLosser(prediction='output', target='label')即可\n", | |||
" (3) 对于Metric是同理的\n", | |||
" Metric计算也是从 forward的结果中取值 与 设置target的field中取值。 也是可以通过映射找到对应的值 \n", | |||
" \n", | |||
" \n", | |||
"\n", | |||
"## 一些问题.\n", | |||
"### 1. DataSet中为什么需要设置input和target\n", | |||
" 只有被设置为input或者target的数据才会在train的过程中被取出来\n", | |||
" (1.1) 我们只会在设置为input的field中寻找传递给Model.forward的参数。\n", | |||
" (1.2) 我们在传递值给losser或者metric的时候会使用来自: \n", | |||
" (a)Model.forward的output\n", | |||
" (b)被设置为target的field\n", | |||
" \n", | |||
"\n", | |||
"### 2. 我们是通过forwad中的形参名将DataSet中的field赋值给对应的参数\n", | |||
" (1.1) 构建模型过程中,\n", | |||
" 例如:\n", | |||
" DataSet中x,seq_lens是input,那么forward就应该是\n", | |||
" def forward(self, x, seq_lens):\n", | |||
" pass\n", | |||
" 我们是通过形参名称进行匹配的field的\n", | |||
" \n", | |||
"\n", | |||
"\n", | |||
"### 1. 加载数据到DataSet\n", | |||
"### 2. 使用apply操作对DataSet进行预处理\n", | |||
" (2.1) 处理过程中将某些field设置为input,某些field设置为target\n", | |||
"### 3. 构建模型\n", | |||
" (3.1) 构建模型过程中,需要注意forward函数的形参名需要和DataSet中设置为input的field名称是一致的。\n", | |||
" 例如:\n", | |||
" DataSet中x,seq_lens是input,那么forward就应该是\n", | |||
" def forward(self, x, seq_lens):\n", | |||
" pass\n", | |||
" 我们是通过形参名称进行匹配的field的\n", | |||
" (3.2) 模型的forward的output需要是dict类型的。\n", | |||
" 建议将输出设置为{\"pred\": xx}.\n", | |||
" \n" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [] | |||
} | |||
], | |||
"metadata": { | |||
"kernelspec": { | |||
"display_name": "Python 3", | |||
"language": "python", | |||
"name": "python3" | |||
}, | |||
"language_info": { | |||
"codemirror_mode": { | |||
"name": "ipython", | |||
"version": 3 | |||
}, | |||
"file_extension": ".py", | |||
"mimetype": "text/x-python", | |||
"name": "python", | |||
"nbconvert_exporter": "python", | |||
"pygments_lexer": "ipython3", | |||
"version": "3.6.7" | |||
} | |||
}, | |||
"nbformat": 4, | |||
"nbformat_minor": 2 | |||
} |
@@ -6,7 +6,7 @@ | |||
"collapsed": true | |||
}, | |||
"source": [ | |||
"# FastNLP 1分钟上手教程" | |||
"# fastNLP 1分钟上手教程" | |||
] | |||
}, | |||
{ | |||
@@ -19,14 +19,14 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 3, | |||
"execution_count": 1, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stderr", | |||
"output_type": "stream", | |||
"text": [ | |||
"/Users/yh/miniconda2/envs/python3/lib/python3.6/site-packages/tqdm/autonotebook/__init__.py:14: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n", | |||
"c:\\users\\zyfeng\\miniconda3\\envs\\fastnlp\\lib\\site-packages\\tqdm\\autonotebook\\__init__.py:14: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n", | |||
" \" (e.g. in jupyter console)\", TqdmExperimentalWarning)\n" | |||
] | |||
} | |||
@@ -37,26 +37,23 @@ | |||
"\n", | |||
"from fastNLP import DataSet\n", | |||
"\n", | |||
"# linux_path = \"../test/data_for_tests/tutorial_sample_dataset.csv\"\n", | |||
"win_path = \"../test/data_for_tests/tutorial_sample_dataset.csv\"\n", | |||
"ds = DataSet.read_csv(win_path, headers=('raw_sentence', 'label'), sep='\\t')" | |||
"data_path = \"./sample_data/tutorial_sample_dataset.csv\"\n", | |||
"ds = DataSet.read_csv(data_path, headers=('raw_sentence', 'label'), sep='\\t')" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 8, | |||
"execution_count": 2, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"text/plain": [ | |||
"{'raw_sentence': this quiet , introspective and entertaining independent is worth seeking .,\n", | |||
"'label': 4,\n", | |||
"'label_seq': 4,\n", | |||
"'words': ['this', 'quiet', ',', 'introspective', 'and', 'entertaining', 'independent', 'is', 'worth', 'seeking', '.']}" | |||
"{'raw_sentence': This quiet , introspective and entertaining independent is worth seeking . type=str,\n", | |||
"'label': 4 type=str}" | |||
] | |||
}, | |||
"execution_count": 8, | |||
"execution_count": 2, | |||
"metadata": {}, | |||
"output_type": "execute_result" | |||
} | |||
@@ -78,7 +75,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 4, | |||
"execution_count": 3, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -94,7 +91,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 5, | |||
"execution_count": 4, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
@@ -115,7 +112,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 6, | |||
"execution_count": 5, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -138,7 +135,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 62, | |||
"execution_count": 6, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
@@ -156,33 +153,46 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 63, | |||
"execution_count": 7, | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"training epochs started 2018-12-07 14:03:41\n" | |||
"input fields after batch(if batch size is 2):\n", | |||
"\twords: (1)type:numpy.ndarray (2)dtype:object, (3)shape:(2,) \n", | |||
"\tword_seq: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 25]) \n", | |||
"target fields after batch(if batch size is 2):\n", | |||
"\tlabel_seq: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", | |||
"\n", | |||
"training epochs started 2019-01-12 17-00-48\n" | |||
] | |||
}, | |||
{ | |||
"data": { | |||
"application/vnd.jupyter.widget-view+json": { | |||
"model_id": "23979df0f63e446fbb0406b919b91dd3", | |||
"version_major": 2, | |||
"version_minor": 0 | |||
}, | |||
"text/plain": [ | |||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6), HTML(value='')), layout=Layout(display='i…" | |||
] | |||
}, | |||
"execution_count": 0, | |||
"metadata": {}, | |||
"output_type": "execute_result" | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Epoch 1/3. Step:2/6. AccuracyMetric: acc=0.26087\n", | |||
"Epoch 2/3. Step:4/6. AccuracyMetric: acc=0.347826\n", | |||
"Epoch 3/3. Step:6/6. AccuracyMetric: acc=0.608696\n", | |||
"Evaluation at Epoch 1/3. Step:2/6. AccuracyMetric: acc=0.173913\n", | |||
"Evaluation at Epoch 2/3. Step:4/6. AccuracyMetric: acc=0.26087\n", | |||
"Evaluation at Epoch 3/3. Step:6/6. AccuracyMetric: acc=0.304348\n", | |||
"\n", | |||
"In Epoch:3/Step:6, got best dev performance:AccuracyMetric: acc=0.304348\n", | |||
"Reloaded the best model.\n", | |||
"Train finished!\n" | |||
] | |||
} |
@@ -1,101 +0,0 @@ | |||
{ | |||
"cells": [ | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": { | |||
"collapsed": true | |||
}, | |||
"source": [ | |||
"## FastNLP 进阶教程\n", | |||
"本教程阅读时间平均30分钟" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"## 数据部分\n", | |||
"### DataSet\n" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"### Instance" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"### Vocabulary" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"## 模型部分\n", | |||
"### model" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"## 训练测试部分\n", | |||
"### Loss" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"### Metric" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"### Trainer" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"### Tester" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [] | |||
} | |||
], | |||
"metadata": { | |||
"kernelspec": { | |||
"display_name": "Python 2", | |||
"language": "python", | |||
"name": "python2" | |||
}, | |||
"language_info": { | |||
"codemirror_mode": { | |||
"name": "ipython", | |||
"version": 2 | |||
}, | |||
"file_extension": ".py", | |||
"mimetype": "text/x-python", | |||
"name": "python", | |||
"nbconvert_exporter": "python", | |||
"pygments_lexer": "ipython2", | |||
"version": "2.7.6" | |||
} | |||
}, | |||
"nbformat": 4, | |||
"nbformat_minor": 0 | |||
} |
@@ -1137,27 +1137,6 @@ | |||
"tester.test()" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
@@ -1182,7 +1161,7 @@ | |||
"name": "python", | |||
"nbconvert_exporter": "python", | |||
"pygments_lexer": "ipython3", | |||
"version": "3.7.2" | |||
"version": "3.6.7" | |||
} | |||
}, | |||
"nbformat": 4, | |||
@@ -0,0 +1,77 @@ | |||
A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story . 1 | |||
This quiet , introspective and entertaining independent is worth seeking . 4 | |||
Even fans of Ismail Merchant 's work , I suspect , would have a hard time sitting through this one . 1 | |||
A positively thrilling combination of ethnography and all the intrigue , betrayal , deceit and murder of a Shakespearean tragedy or a juicy soap opera . 3 | |||
Aggressive self-glorification and a manipulative whitewash . 1 | |||
A comedy-drama of nearly epic proportions rooted in a sincere performance by the title character undergoing midlife crisis . 4 | |||
Narratively , Trouble Every Day is a plodding mess . 1 | |||
The Importance of Being Earnest , so thick with wit it plays like a reading from Bartlett 's Familiar Quotations 3 | |||
But it does n't leave you with much . 1 | |||
You could hate it for the same reason . 1 | |||
There 's little to recommend Snow Dogs , unless one considers cliched dialogue and perverse escapism a source of high hilarity . 1 | |||
Kung Pow is Oedekerk 's realization of his childhood dream to be in a martial-arts flick , and proves that sometimes the dreams of youth should remain just that . 1 | |||
The performances are an absolute joy . 4 | |||
Fresnadillo has something serious to say about the ways in which extravagant chance can distort our perspective and throw us off the path of good sense . 3 | |||
I still like Moonlight Mile , better judgment be damned . 3 | |||
A welcome relief from baseball movies that try too hard to be mythic , this one is a sweet and modest and ultimately winning story . 3 | |||
a bilingual charmer , just like the woman who inspired it 3 | |||
Like a less dizzily gorgeous companion to Mr. Wong 's In the Mood for Love -- very much a Hong Kong movie despite its mainland setting . 2 | |||
As inept as big-screen remakes of The Avengers and The Wild Wild West . 1 | |||
It 's everything you 'd expect -- but nothing more . 2 | |||
Best indie of the year , so far . 4 | |||
Hatfield and Hicks make the oddest of couples , and in this sense the movie becomes a study of the gambles of the publishing world , offering a case study that exists apart from all the movie 's political ramifications . 3 | |||
It 's like going to a house party and watching the host defend himself against a frothing ex-girlfriend . 1 | |||
That the Chuck Norris `` grenade gag '' occurs about 7 times during Windtalkers is a good indication of how serious-minded the film is . 2 | |||
The plot is romantic comedy boilerplate from start to finish . 2 | |||
It arrives with an impeccable pedigree , mongrel pep , and almost indecipherable plot complications . 2 | |||
A film that clearly means to preach exclusively to the converted . 2 | |||
While The Importance of Being Earnest offers opportunities for occasional smiles and chuckles , it does n't give us a reason to be in the theater beyond Wilde 's wit and the actors ' performances . 1 | |||
The latest vapid actor 's exercise to appropriate the structure of Arthur Schnitzler 's Reigen . 1 | |||
More vaudeville show than well-constructed narrative , but on those terms it 's inoffensive and actually rather sweet . 2 | |||
Nothing more than a run-of-the-mill action flick . 2 | |||
Hampered -- no , paralyzed -- by a self-indulgent script ... that aims for poetry and ends up sounding like satire . 0 | |||
Ice Age is the first computer-generated feature cartoon to feel like other movies , and that makes for some glacial pacing early on . 2 | |||
There 's very little sense to what 's going on here , but the makers serve up the cliches with considerable dash . 2 | |||
Cattaneo should have followed the runaway success of his first film , The Full Monty , with something different . 2 | |||
They 're the unnamed , easily substitutable forces that serve as whatever terror the heroes of horror movies try to avoid . 1 | |||
It almost feels as if the movie is more interested in entertaining itself than in amusing us . 1 | |||
The movie 's progression into rambling incoherence gives new meaning to the phrase ` fatal script error . ' 0 | |||
I still like Moonlight Mile , better judgment be damned . 3 | |||
A welcome relief from baseball movies that try too hard to be mythic , this one is a sweet and modest and ultimately winning story . 3 | |||
a bilingual charmer , just like the woman who inspired it 3 | |||
Like a less dizzily gorgeous companion to Mr. Wong 's In the Mood for Love -- very much a Hong Kong movie despite its mainland setting . 2 | |||
As inept as big-screen remakes of The Avengers and The Wild Wild West . 1 | |||
It 's everything you 'd expect -- but nothing more . 2 | |||
Best indie of the year , so far . 4 | |||
Hatfield and Hicks make the oddest of couples , and in this sense the movie becomes a study of the gambles of the publishing world , offering a case study that exists apart from all the movie 's political ramifications . 3 | |||
It 's like going to a house party and watching the host defend himself against a frothing ex-girlfriend . 1 | |||
That the Chuck Norris `` grenade gag '' occurs about 7 times during Windtalkers is a good indication of how serious-minded the film is . 2 | |||
The plot is romantic comedy boilerplate from start to finish . 2 | |||
It arrives with an impeccable pedigree , mongrel pep , and almost indecipherable plot complications . 2 | |||
A film that clearly means to preach exclusively to the converted . 2 | |||
I still like Moonlight Mile , better judgment be damned . 3 | |||
A welcome relief from baseball movies that try too hard to be mythic , this one is a sweet and modest and ultimately winning story . 3 | |||
a bilingual charmer , just like the woman who inspired it 3 | |||
Like a less dizzily gorgeous companion to Mr. Wong 's In the Mood for Love -- very much a Hong Kong movie despite its mainland setting . 2 | |||
As inept as big-screen remakes of The Avengers and The Wild Wild West . 1 | |||
It 's everything you 'd expect -- but nothing more . 2 | |||
Best indie of the year , so far . 4 | |||
Hatfield and Hicks make the oddest of couples , and in this sense the movie becomes a study of the gambles of the publishing world , offering a case study that exists apart from all the movie 's political ramifications . 3 | |||
It 's like going to a house party and watching the host defend himself against a frothing ex-girlfriend . 1 | |||
That the Chuck Norris `` grenade gag '' occurs about 7 times during Windtalkers is a good indication of how serious-minded the film is . 2 | |||
The plot is romantic comedy boilerplate from start to finish . 2 | |||
It arrives with an impeccable pedigree , mongrel pep , and almost indecipherable plot complications . 2 | |||
A film that clearly means to preach exclusively to the converted . 2 | |||
I still like Moonlight Mile , better judgment be damned . 3 | |||
A welcome relief from baseball movies that try too hard to be mythic , this one is a sweet and modest and ultimately winning story . 3 | |||
a bilingual charmer , just like the woman who inspired it 3 | |||
Like a less dizzily gorgeous companion to Mr. Wong 's In the Mood for Love -- very much a Hong Kong movie despite its mainland setting . 2 | |||
As inept as big-screen remakes of The Avengers and The Wild Wild West . 1 | |||
It 's everything you 'd expect -- but nothing more . 2 | |||
Best indie of the year , so far . 4 | |||
Hatfield and Hicks make the oddest of couples , and in this sense the movie becomes a study of the gambles of the publishing world , offering a case study that exists apart from all the movie 's political ramifications . 3 | |||
It 's like going to a house party and watching the host defend himself against a frothing ex-girlfriend . 1 | |||
That the Chuck Norris `` grenade gag '' occurs about 7 times during Windtalkers is a good indication of how serious-minded the film is . 2 | |||
The plot is romantic comedy boilerplate from start to finish . 2 | |||
It arrives with an impeccable pedigree , mongrel pep , and almost indecipherable plot complications . 2 | |||
A film that clearly means to preach exclusively to the converted . 2 |
@@ -1,353 +0,0 @@ | |||
{ | |||
"cells": [ | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"### 一共会涉及到如下的几个类\n", | |||
"\n", | |||
"#### DataSet\n", | |||
"#### Sampler\n", | |||
"#### Batch\n", | |||
"#### Model\n", | |||
"#### Loss\n", | |||
"#### Metric\n", | |||
"#### Trainer\n", | |||
"#### Tester" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"### 下面具体讲一下它们的作用" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"#### DataSet: 用于承载数据。\n", | |||
"(1) DataSet里面每个元素只能是以下的三类np.float64, np.int64, np.str。如果传入的数据是int则被转换为np.int64, float被转为np.float64。 \n", | |||
"(2) DataSet可以将field设置为input,target。其中被设置为input的field会被传递给Model.forward, 这个过程中我们是通过键匹配完成传递的。举例来说,假设DataSet中有'x1', 'x2', 'x3'被设置为了input,而 \n", | |||
"   (2.1)函数是Model.forward(self, x1, x3), 那么DataSet中'x1', 'x3'会被传递给forward函数。多余的'x2'会被忽略 \n", | |||
"   (2.2)函数是Model.forward(self, x1, x4), 这里多需要了一个'x4', 但是DataSet的input field中没有这个field,会报错。 \n", | |||
"   (2.3)函数是Model.forward(self, x1, **kwargs), 会把'x1', 'x2', 'x3'都传入。但如果是Model.forward(self, x4, **kwargs)就会发生报错,因为没有'x4'。 \n", | |||
"(3) 对于设置为target的field的名称,我们建议取名为'target'(如果只有一个需要predict的值),但是不强制。后面会讲为什么target可以不强制。 \n", | |||
"DataSet应该是不需要单独再开发的,如果有不能满足的场景,请在开发群提出或者github提交issue。" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"#### Sampler: 给定一个DataSet,返回一个序号的list,Batch按照这个list输出数据。\n", | |||
"Sampler需要继承fastNLP.core.sampler.BaseSampler" | |||
] | |||
}, | |||
{ | |||
"cell_type": "raw", | |||
"metadata": {}, | |||
"source": [ | |||
"class BaseSampler(object):\n", | |||
"\"\"\"The base class of all samplers.\n", | |||
"\n", | |||
" Sub-classes must implement the __call__ method.\n", | |||
" __call__ takes a DataSet object and returns a list of int - the sampling indices.\n", | |||
"\"\"\"\n", | |||
"def __call__(self, *args, **kwargs):\n", | |||
" raise NotImplementedError\n", | |||
" \n", | |||
"# 子类需要复写__call__方法。这个函数只能有一个必选参数, 且必须是DataSet类别, 否则Trainer没法调\n", | |||
"class SonSampler(BaseSample):\n", | |||
" def __init__(self, xxx):\n", | |||
" # 可以实现init也不可以不实现。\n", | |||
" def __call__(self, data_set):\n", | |||
" pass" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"#### Batch: 将DataSet中设置为input和target的field取出来构成batch_x, batch_y\n", | |||
"并且根据情况(主要根据数据类型能不能转为Tensor)将数据转换为pytorch的Tensor。batch中sample的取出顺序是由Sampler决定的。 \n", | |||
"Sampler是传入一个DataSet,返回一个与DataSet等长的序号list,Batch一次会取出batch_size个sample(最后一个batch可能数量不足batch_size个)。 \n", | |||
"举例: \n", | |||
"(1) SequentialSampler是顺序采样\n", | |||
" 假设传入的DataSet长度是100, SequentialSampler返回的序号list就是[0, 1, ...,98, 99]. batch_size如果被设置为4,那么第一个batch所获取的instance就是[0, 1, 2, 3]这四个instance. 第二个batch所获取instace就是[4, 5, 6, 7], ...直到采完所有的sample。 \n", | |||
"(2) RandomSampler是随机采样 \n", | |||
" 假设传入的DataSet长度是100, RandomSampler返回的序号list可能是[0, 99, 20, 5, 3, 1, ...]. 依次按照batch_size的大小取出sample。 \n", | |||
"Batch应该不需要继承与开发,如果你有特殊需求请在开发群里提出。" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"#### Model:用户自定的Model\n", | |||
"必须是nn.Module的子类, \n", | |||
"(1) 必须实现forward方法,并且forward方法不能出现*arg这种参数. 例如 \n", | |||
"   def forward(self, word_seq, *args): #这是不允许的. \n", | |||
"      xxx \n", | |||
"返回值必须是dict的 \n", | |||
"   def forward(self, word_seq, seq_lens): \n", | |||
"      xxxx \n", | |||
"   return {'pred': xxx} #return的值必须是dict的。里面的预测的key推荐使用pred,但是不做强制限制。输出元素数目不限。 \n", | |||
"(2) 如果实现了predict方法,在做evaluation的时候将调用predict方法而不是forward。如果没有predict方法,则在evaluation时调用forward方法。predict方法也不能使用*args这种参数形式,同时结果也必须返回一个dict,同样推荐key为'pred'。" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"#### Loss: 根据model.forward()返回的prediction(是一个dict)和batch_y计算相应的loss。 \n", | |||
"(1) 先介绍\"键映射\"。 如在DataSet, Model一节所看见的那样,fastNLP并不限制Model.forward()的返回值,也不限制DataSet中target field的key。计算的loss的时候,怎么才能知道从哪里取值呢? \n", | |||
"这里以CrossEntropyLoss为例,一般情况下, 计算CrossEntropy需要prediction和target两个值。而在CrossEntropyLoss初始化时可以传入两个参数(pred=None, target=None), 这两个参数接受的类型是str,假设(pred='output', target='label'),那么CrossEntropyLoss会使用'output'这个key在forward的output与batch_y中寻找值;'label'也是在forward的output与batch_y中寻找值。注意这里pred或target的来源并不一定非要来自于model.forward与batch_y,也可以只来自于forward的结果。 \n", | |||
"(2)如何创建一个自己的loss \n", | |||
"   (2.1)使用fastNLP.LossInForward, 在model.forward()的结果中包含一个为loss的key。 \n", | |||
"   (2.2) trainer中使用loss(假设loss=CrossEntropyLoss())的时候其实是 \n", | |||
"    los = loss(prediction, batch_y)\n", | |||
" 即直接调用的是loss.\\__call__()方法,但是CrossEntropyLoss里面并没有自己实现\\__call__方法,这是因为\\__call__在LossBase中实现了。所有的loss必须继承fastNLP.core.loss.LossBase, 下面先说一下LossBase的几个方法,见下一个cell。 \n", | |||
"(3) 尽量不要复写\\__call__(), _init_param_map()方法。" | |||
] | |||
}, | |||
{ | |||
"cell_type": "raw", | |||
"metadata": {}, | |||
"source": [ | |||
"class LossBase():\n", | |||
" def __init__(self):\n", | |||
" self.param_map = {} # 一般情况下也不需要自己创建。调用_init_param_map()更好\n", | |||
" self._checked = False # 这个参数可以忽略\n", | |||
"\n", | |||
" def _init_param_map(self, key_map=None, **kwargs):\n", | |||
" # 这个函数是用于注册Loss的“键映射”,有两种传值方法,\n", | |||
" # 第一种是通过key_map传入dict,取值是用value到forward和batch_y取\n", | |||
" # key_map = {'pred': 'output', 'target': 'label'} \n", | |||
" # 第二种是自己写\n", | |||
" # _init_param_map(pred='output', target='label')\n", | |||
" # 为什么会提供这么一个方法?通过调用这个方法会自动注册param_map,并会做一些检查,防止出现传入的key其实并不是get_loss\n", | |||
" # 的一个参数。注意传入这个方法的参数必须都是需要做键映射的内容,其它loss参数不要传入。如果传入(pred=None, target=None)\n", | |||
" # 则__call__()会到pred_dict与target_dict去寻找key为'pred'和'target'的值。\n", | |||
" # 但这个参数不是必须要调用的。\n", | |||
"\n", | |||
" def __call__(self, pred_dict, target_dict, check=False): # check=False忽略这个参数,之后应该会被删除的\n", | |||
" # 这个函数主要会做一些check的工作,比如pred_dict与target_dict中是否包含了计算loss所必须的key等。检查通过,则调用get_loss\n", | |||
" # 方法。\n", | |||
" fast_param = self._fast_param_map(predict_dict, target_dict):\n", | |||
" if fast_param:\n", | |||
" return self.get_loss(**fast_param)\n", | |||
" # 如果没有fast_param则通过匹配参数然后调用get_loss完成\n", | |||
" xxxx\n", | |||
" return loss # 返回为Tensor的loss\n", | |||
" def _fast_param_map(self, pred_dict, target_dict):\n", | |||
" # 这是一种快速计算loss的机制,因为在很多情况下其实都不需要通过\"键映射\",比如计算loss时,pred_dict只有一个元素,\n", | |||
" # target_dict也只有一个元素,那么无歧义地就可以把预测值与实际值用于计算loss, 基类判断了这种情况(可能还有其它无歧义的情况)。\n", | |||
" # 即_fast_param_map成功的话,就不需要使用键映射,这样即使在没有传递或者传递错误\"键映射\"的情况也可以直接计算loss。\n", | |||
" # 返回值是一个dict, 如果匹配成功,应该返回类似{'pred':value, 'target': value}的结果;如果dict为空则说明匹配失败,\n", | |||
" # __call__方法会继续执行。\n", | |||
"\n", | |||
" def get_loss(self, *args, **kwargs):\n", | |||
" # 这个是一定需要实现的,计算loss的地方。\n", | |||
" # (1) get_loss中一定不能包含*arg这种参数形式。\n", | |||
" # (2) 如果包含**kwargs这种参数,这会将pred_dict与target_dict中所有参数传入。但是建议不要用这个参数\n", | |||
" raise NotImplementedError\n", | |||
"\n", | |||
"# 下面使用L1Loss举例\n", | |||
"class L1Loss(LossBase): # 继承LossBase\n", | |||
" # 初始化需要映射的值,这里需要映射的值'pred', 'target'必须与get_loss需要参数名是对应的\n", | |||
" def __init__(self, pred=None, target=None): \n", | |||
" super(L1Loss, self).__init__()\n", | |||
" # 这里传入_init_param_map以使得pred和target被正确注册,但这一步不是必须的, 建议调用。传入_init_param_map的是用于\n", | |||
" # “键映射\"的键值对。假设初始化__init__(pred=None, target=None, threshold=0.1)中threshold是用于控制loss计算的,则\n", | |||
" # 不要将threshold传入_init_param_map.\n", | |||
" self._init_param_map(pred=pred, target=target)\n", | |||
"\n", | |||
" def get_loss(self, pred, target):\n", | |||
" # 这里'pred', 'target'必须和初始化的映射是一致的。\n", | |||
" return F.l1_loss(input=pred, target=target) #直接返回一个loss即可" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"### Metric: 根据Model.forward()或者Model.predict()的结果计算metric \n", | |||
"metric的设计和loss的设计类似。都是传入pred_dict与target_dict进行计算。但是metric的pred_dict来源可能是Model.forward的返回值, 也可能是Model.predict(如果Model具有predict方法则会调用predict方法)的返回值,下面统一用pred_dict代替。 \n", | |||
"(1) 这里的\"键映射\"与loss的\"键映射\"是类似的。举例来说,若Metric(pred='output', target='label'),则使用'output'到pred_dict和target_dict中寻找pred, 用'label'寻找target。 \n", | |||
"(2) 如何创建一个自己的Metric方法 \n", | |||
"  Metric与loss的计算不同在于,Metric的计算有两个步骤。 \n", | |||
"  (2.1) <b>每个batch的输出</b>都会调用Metric的\\__call__(pred_dict, target_dict)方法,而\\__call__方法会调用evaluate()(需要实现)方法。 \n", | |||
"  (2.2) 在所有batch传入之后,调用Metric的get_metric()方法得到最终的metric值。 \n", | |||
"  所以Metric在调用evaluate方法时,根据拿到的数据: pred_dict与batch_y, 改变自己的状态(比如累加正确的次数,总的sample数等)。在调用get_metric()的时候给出一个最终计算结果。 \n", | |||
"所有的Metric必须继承自fastNLP.core.metrics.MetricBase. 例子见下一个cell \n", | |||
"(3) 尽量不要复写\\__call__(), _init_param_map()方法。\n" | |||
] | |||
}, | |||
{ | |||
"cell_type": "raw", | |||
"metadata": {}, | |||
"source": [ | |||
"MetricBase: \n", | |||
" def __init__(self):\n", | |||
" self.param_map = {} # 一般情况下也不需要自己创建。调用_init_param_map()更好\n", | |||
" self._checked = False # 这个参数可以忽略\n", | |||
"\n", | |||
" def _init_param_map(self, key_map=None, **kwargs):\n", | |||
" # 这个函数是用于注册Metric的“键映射”,有两种传值方法,\n", | |||
" # 第一种是通过key_map传入dict,取值是用value到forward和batch_y取\n", | |||
" # key_map = {'pred': 'output', 'target': 'label'} \n", | |||
" # 第二种是自己写(建议使用改种方式)\n", | |||
" # _init_param_map(pred='output', target='label')\n", | |||
" # 为什么会提供这么一个方法?通过调用这个方法会自动注册param_map,并会做一些检查,防止出现传入的key其实并不是evaluate()\n", | |||
" # 的一个参数。注意传入这个方法的参数必须都是需要做键映射的内容,其它evaluate参数不要传入。如果传入(pred=None, target=None)\n", | |||
" # 则__call__()会到pred_dict与target_dict去寻找key为'pred'和'target'的值。\n", | |||
" # 但这个参数不是必须要调用的。\n", | |||
"\n", | |||
" def __call__(self, pred_dict, target_dict, check=False): # check=False忽略这个参数,之后应该会被删除的\n", | |||
" # 这个函数主要会做一些check的工作,比如pred_dict与target_dict中是否包含了计算evaluate所必须的key等。检查通过,则调用\n", | |||
" # evaluate方法。\n", | |||
" fast_param = self._fast_param_map(predict_dict, target_dict):\n", | |||
" if fast_param:\n", | |||
" return self.evaluate(**fast_param)\n", | |||
" # 如果没有fast_param则通过匹配参数然后调用get_loss完成\n", | |||
" xxxx\n", | |||
"\n", | |||
" def _fast_param_map(self, pred_dict, target_dict):\n", | |||
" # 这是一种快速计算loss的机制,因为在很多情况下其实都不需要通过\"键映射\",比如evaluate时,pred_dict只有一个元素,\n", | |||
" # target_dict也只有一个元素,那么无歧义地就可以把预测值与实际值用于计算metric, 基类判断了这种情况(可能还有其它无歧义的\n", | |||
" # 情况)。即_fast_param_map成功的话,就不需要使用键映射,这样即使在没有传递或者传递错误\"键映射\"的情况也可以直接计算metric。\n", | |||
" # 返回值是一个dict, 如果匹配成功,应该返回类似{'pred':value, 'target': value}的结果;如果dict为空则说明匹配失败,\n", | |||
" # __call__方法会继续尝试匹配。\n", | |||
"\n", | |||
" def evaluate(self, *args, **kwargs):\n", | |||
" # 这个是一定需要实现的,累加metric状态\n", | |||
" # (1) evaluate()中一定不能包含*arg这种参数形式。\n", | |||
" # (2) 如果包含**kwargs这种参数,这会将pred_dict与target_dict中所有参数传入。但是建议不要用这个参数\n", | |||
" raise NotImplementedError\n", | |||
"\n", | |||
" def get_metric(self, reset=True):\n", | |||
" # 这是一定需要实现的,获取最终的metric。返回值必须是一个dict。会在所有batch传入之后调用\n", | |||
" raise NotImplemented\n", | |||
"\n", | |||
"下面使用AccuracyMetric举例\n", | |||
"class AccuracyMetric(MetricBase): # MetricBase\n", | |||
" # 初始化需要映射的值,这里需要映射的值'pred', 'target'必须与evaluate()需要参数名是对应的\n", | |||
" def __init__(self, pred=None, target=None): \n", | |||
" super(AccuracyMetric, self).__init__()\n", | |||
" # 这里传入_init_param_map以使得pred和target被正确注册,但这一步不是必须的, 建议调用。传入_init_param_map的是用于\n", | |||
" # “键映射\"的键值对。假设初始化__init__(pred=None, target=None, threshold=0.1)中threshold是用于控制loss计算的,则\n", | |||
" # 不要将threshold传入_init_param_map.\n", | |||
" self._init_param_map(pred=pred, target=target)\n", | |||
"\n", | |||
" self.total = 0 # 用于累加一共有多少sample\n", | |||
" self.corr = 0 # 用于累加一共有多少正确的sample\n", | |||
"\n", | |||
" def evaluate(self, pred, target):\n", | |||
" # 对pred和target做一些基本的判断或者预处理等\n", | |||
" if pred.size()==target.size() and len(pred.size())=1: #如果pred已经做了argmax\n", | |||
" pass\n", | |||
" elif len(pred.size())==2 and len(target.size())==1: # pred还没有进行argmax\n", | |||
" pred = pred.argmax(dim=1)\n", | |||
" else:\n", | |||
" raise ValueError(\"The shape of pred and target should be ((B, n_classes), (B, )) or (\"\n", | |||
" \"(B,),(B,)).\")\n", | |||
" assert pred.size(0)==target.size(0), \"Mismatch batch size.\"\n", | |||
" # 进行相应的累加\n", | |||
" self.total += pred.size(0)\n", | |||
" self.corr += torch.sum(torch.eq(pred, target).float()).item()\n", | |||
"\n", | |||
" def get_metric(self, reset=True):\n", | |||
" # reset用于指示是否清空累加信息。默认为True\n", | |||
" # 这个函数需要返回dict,可以包含多个metric。\n", | |||
" metric = {}\n", | |||
" metric['acc'] = self.corr/self.total\n", | |||
" if reset:\n", | |||
" self.total = 0\n", | |||
" self.corr = 0\n", | |||
" return metric" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"#### Tester: 用于做evaluation,应该不需要更改\n", | |||
"重要的初始化参数有,data, model, metric \n", | |||
"比较重要的function是test() \n", | |||
"test中的运行过程 \n", | |||
"  predict_func = 如果有model.predict则为model.predict, 否则是model.forward \n", | |||
"  for batch_x, batch_y in batch: \n", | |||
"    # (1) 同步数据与model \n", | |||
"    # (2) 根据predict_func的参数从batch_x中取出数据传入到predict_func中,得到结果pred_dict \n", | |||
"    # (3) 调用metric(pred_dict, batch_y \n", | |||
"    #(4) 当所有batch都运行完毕,会调用metric的get_metric方法,并且以返回的值作为evaluation的结果 \n", | |||
"  metric.get_metric()" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"metadata": {}, | |||
"source": [ | |||
"#### Trainer: 对训练过程的封装。 \n", | |||
"里面比较重要的function是train() \n", | |||
"train()中的运行过程 \n", | |||
"  # (1) 创建batch \n", | |||
"  batch = Batch(dataset, batch_size, sampler=sampler) \n", | |||
"  for batch_x, batch_y in batch: \n", | |||
"    \"\"\" \n", | |||
"    batch_x,batch_y都是dict。batch_x是DataSet中被设置为input的field;batch_y是DataSet中被设置为target的field。 \n", | |||
"    两个dict中的key就是DataSet中的key,value会根据情况做好padding的tensor。 \n", | |||
"    \"\"\" \n", | |||
"    # (2)会将batch_x, batch_y中tensor移动到model所在的device \n", | |||
"    # (3)根据model.forward的参数列表, 从batch_x中取出需要传递给forward的数据。 \n", | |||
"    # (4)获取model.forward的输出结果pred_dict,并与batch_y一起传递给loss函数, 求得loss \n", | |||
"    # (5)对loss进行反向梯度并更新参数 \n", | |||
"  # (6) 如果有验证集,则需要做验证 \n", | |||
"  tester = Tester(model, dev_data,metric) \n", | |||
"  eval_results = tester.test() \n", | |||
"  # (7) 如果eval_results是当前的最佳结果,则保存模型。 " | |||
] | |||
}, | |||
{ | |||
"cell_type": "raw", | |||
"metadata": {}, | |||
"source": [ | |||
"除了以上的内容,\n", | |||
"Trainer中还提供了\"预跑\"的功能。该功能通过check_code_level管理,如果check_code_level为-1,则不进行\"预跑\"。\n", | |||
"check_code_level=0,1,2代表不同的提醒级别。目前不同提醒级别对应的是对DataSet中设置为input或target但又没有使用的field的提醒级别。\n", | |||
"0是忽略(默认);1是会warning发生了未使用field的情况;2是出现了unused会直接报错并退出运行\n", | |||
"\"预跑\"的主要目的有两个: (1) 防止train完了之后进行evaluation的时候出现错误。之前的train就白费了\n", | |||
" (2) 由于存在\"键映射\",直接运行导致的报错可能不太容易debug,通过\"预跑\"过程的报错会有一些debug提示\n", | |||
"\"预跑\"会进行以下的操作:(1) 使用很小的batch_size, 检查batch_x中是否包含Model.forward所需要的参数。只会运行两个循环。\n", | |||
" (2) 将Model.foward的输出pred_dict与batch_y输入到loss中, 并尝试backward. 不会更新参数,而且grad会被清零\n", | |||
" 如果传入了dev_data,还将进行metric的测试\n", | |||
" (3) 创建Tester,并传入少量数据,检测是否可以正常运行\n", | |||
"\"预跑\"操作是在Trainer初始化的时候执行的。\n", | |||
"正常情况下,应该不需要改动\"预跑\"的代码。但如果你遇到bug或者有什么好的建议,欢迎在开发群或者github提交issue。" | |||
] | |||
} | |||
], | |||
"metadata": { | |||
"kernelspec": { | |||
"display_name": "Python 3", | |||
"language": "python", | |||
"name": "python3" | |||
}, | |||
"language_info": { | |||
"codemirror_mode": { | |||
"name": "ipython", | |||
"version": 3 | |||
}, | |||
"file_extension": ".py", | |||
"mimetype": "text/x-python", | |||
"name": "python", | |||
"nbconvert_exporter": "python", | |||
"pygments_lexer": "ipython3", | |||
"version": "3.6.7" | |||
} | |||
}, | |||
"nbformat": 4, | |||
"nbformat_minor": 2 | |||
} |
@@ -0,0 +1,283 @@ | |||
# fastNLP开发者指南 | |||
#### 本教程涉及以下类: | |||
- DataSet | |||
- Sampler | |||
- Batch | |||
- Model | |||
- Loss | |||
- Metric | |||
- Trainer | |||
- Tester | |||
#### DataSet: 用于承载数据。 | |||
1. DataSet里面每个元素只能是以下的三类`np.float64`, `np.int64`, `np.str`。如果传入的数据是`int`则被转换为`np.int64`, `float`被转为`np.float64`。 | |||
2. DataSet可以将field设置为input或者target。其中被设置为input的field会被传递给Model.forward, 这个过程中我们是通过键匹配完成传递的。举例来说,假设DataSet中有'x1', 'x2', 'x3'被设置为了input,而 | |||
- 函数是Model.forward(self, x1, x3), 那么DataSet中'x1', 'x3'会被传递给forward函数。多余的'x2'会被忽略 | |||
- 函数是Model.forward(self, x1, x4), 这里多需要了一个'x4', 但是DataSet的input field中没有这个field,会报错。 | |||
- 函数是Model.forward(self, x1, **kwargs), 会把'x1', 'x2', 'x3'都传入。但如果是Model.forward(self, x4, **kwargs)就会发生报错,因为没有'x4'。 | |||
3. 对于设置为target的field的名称,我们建议取名为'target'(如果只有一个需要predict的值),但是不强制。后面会讲为什么target可以不强制。 | |||
DataSet应该是不需要单独再开发的,如果有不能满足的场景,请在开发群提出或者github提交issue。 | |||
#### Sampler: 给定一个DataSet,返回一个序号的list,Batch按照这个list输出数据。 | |||
Sampler需要继承fastNLP.core.sampler.BaseSampler | |||
```python | |||
class BaseSampler(object): | |||
"""The base class of all samplers. | |||
Sub-classes must implement the __call__ method. | |||
__call__ takes a DataSet object and returns a list of int - the sampling indices. | |||
""" | |||
def __call__(self, *args, **kwargs): | |||
raise NotImplementedError | |||
# 子类需要复写__call__方法。这个函数只能有一个必选参数, 且必须是DataSet类别, 否则Trainer没法调 | |||
class SonSampler(BaseSampler): | |||
def __init__(self, xxx): | |||
# 可以实现init也不可以不实现。 | |||
pass | |||
def __call__(self, data_set): | |||
pass | |||
``` | |||
#### Batch: 将DataSet中设置为input和target的field取出来构成batch_x, batch_y | |||
并且根据情况(主要根据数据类型能不能转为Tensor)将数据转换为pytorch的Tensor。batch中sample的取出顺序是由Sampler决定的。 | |||
Sampler是传入一个DataSet,返回一个与DataSet等长的序号list,Batch一次会取出batch_size个sample(最后一个batch可能数量不足batch_size个)。 | |||
举例: | |||
1. SequentialSampler是顺序采样 | |||
假设传入的DataSet长度是100, SequentialSampler返回的序号list就是[0, 1, ...,98, 99]. batch_size如果被设置为4,那么第一个batch所获取的instance就是[0, 1, 2, 3]这四个instance. 第二个batch所获取instace就是[4, 5, 6, 7], ...直到采完所有的sample。 | |||
2. RandomSampler是随机采样 | |||
假设传入的DataSet长度是100, RandomSampler返回的序号list可能是[0, 99, 20, 5, 3, 1, ...]. 依次按照batch_size的大小取出sample。 | |||
Batch应该不需要继承与开发,如果你有特殊需求请在开发群里提出。 | |||
#### Model:用户自定的Model | |||
必须是nn.Module的子类 | |||
1. 必须实现forward方法,并且forward方法不能出现*arg这种参数. 例如 | |||
```python | |||
def forward(self, word_seq, *args): #这是不允许的. | |||
# ... | |||
pass | |||
``` | |||
返回值必须是dict的 | |||
```python | |||
def forward(self, word_seq, seq_lens): | |||
xxx = "xxx" | |||
return {'pred': xxx} #return的值必须是dict的。里面的预测的key推荐使用pred,但是不做强制限制。输出元素数目不限。 | |||
``` | |||
2. 如果实现了predict方法,在做evaluation的时候将调用predict方法而不是forward。如果没有predict方法,则在evaluation时调用forward方法。predict方法也不能使用*args这种参数形式,同时结果也必须返回一个dict,同样推荐key为'pred'。 | |||
#### Loss: 根据model.forward()返回的prediction(是一个dict)和batch_y计算相应的loss | |||
1. 先介绍"键映射"。 如在DataSet, Model一节所看见的那样,fastNLP并不限制Model.forward()的返回值,也不限制DataSet中target field的key。计算的loss的时候,怎么才能知道从哪里取值呢? | |||
这里以CrossEntropyLoss为例,一般情况下, 计算CrossEntropy需要prediction和target两个值。而在CrossEntropyLoss初始化时可以传入两个参数(pred=None, target=None), 这两个参数接受的类型是str,假设(pred='output', target='label'),那么CrossEntropyLoss会使用'output'这个key在forward的output与batch_y中寻找值;'label'也是在forward的output与batch_y中寻找值。注意这里pred或target的来源并不一定非要来自于model.forward与batch_y,也可以只来自于forward的结果。 | |||
2. 如何创建一个自己的loss | |||
- 使用fastNLP.LossInForward, 在model.forward()的结果中包含一个为loss的key。 | |||
- trainer中使用loss(假设loss=CrossEntropyLoss())的时候其实是 | |||
los = loss(prediction, batch_y),即直接调用的是`loss.__call__()`方法,但是CrossEntropyLoss里面并没有自己实现`__call__`方法,这是因为`__call__`在LossBase中实现了。所有的loss必须继承fastNLP.core.loss.LossBase, 下面先说一下LossBase的几个方法,见下一节。 | |||
3. 尽量不要复写`__call__()`, `_init_param_map()`方法。 | |||
```python | |||
class LossBase(): | |||
def __init__(self): | |||
self.param_map = {} # 一般情况下也不需要自己创建。调用_init_param_map()更好 | |||
self._checked = False # 这个参数可以忽略 | |||
def _init_param_map(self, key_map=None, **kwargs): | |||
# 这个函数是用于注册Loss的“键映射”,有两种传值方法, | |||
# 第一种是通过key_map传入dict,取值是用value到forward和batch_y取 | |||
# key_map = {'pred': 'output', 'target': 'label'} | |||
# 第二种是自己写 | |||
# _init_param_map(pred='output', target='label') | |||
# 为什么会提供这么一个方法?通过调用这个方法会自动注册param_map,并会做一些检查,防止出现传入的key其实并不是get_loss | |||
# 的一个参数。注意传入这个方法的参数必须都是需要做键映射的内容,其它loss参数不要传入。如果传入(pred=None, target=None) | |||
# 则__call__()会到pred_dict与target_dict去寻找key为'pred'和'target'的值。 | |||
# 但这个参数不是必须要调用的。 | |||
def __call__(self, pred_dict, target_dict, check=False): # check=False忽略这个参数,之后应该会被删除的 | |||
# 这个函数主要会做一些check的工作,比如pred_dict与target_dict中是否包含了计算loss所必须的key等。检查通过,则调用get_loss | |||
# 方法。 | |||
fast_param = self._fast_param_map(predict_dict, target_dict): | |||
if fast_param: | |||
return self.get_loss(**fast_param) | |||
# 如果没有fast_param则通过匹配参数然后调用get_loss完成 | |||
xxxx | |||
return loss # 返回为Tensor的loss | |||
def _fast_param_map(self, pred_dict, target_dict): | |||
# 这是一种快速计算loss的机制,因为在很多情况下其实都不需要通过"键映射",比如计算loss时,pred_dict只有一个元素, | |||
# target_dict也只有一个元素,那么无歧义地就可以把预测值与实际值用于计算loss, 基类判断了这种情况(可能还有其它无歧义的情况)。 | |||
# 即_fast_param_map成功的话,就不需要使用键映射,这样即使在没有传递或者传递错误"键映射"的情况也可以直接计算loss。 | |||
# 返回值是一个dict, 如果匹配成功,应该返回类似{'pred':value, 'target': value}的结果;如果dict为空则说明匹配失败, | |||
# __call__方法会继续执行。 | |||
def get_loss(self, *args, **kwargs): | |||
# 这个是一定需要实现的,计算loss的地方。 | |||
# (1) get_loss中一定不能包含*arg这种参数形式。 | |||
# (2) 如果包含**kwargs这种参数,这会将pred_dict与target_dict中所有参数传入。但是建议不要用这个参数 | |||
raise NotImplementedError | |||
# 下面使用L1Loss举例 | |||
class L1Loss(LossBase): # 继承LossBase | |||
# 初始化需要映射的值,这里需要映射的值'pred', 'target'必须与get_loss需要参数名是对应的 | |||
def __init__(self, pred=None, target=None): | |||
super(L1Loss, self).__init__() | |||
# 这里传入_init_param_map以使得pred和target被正确注册,但这一步不是必须的, 建议调用。传入_init_param_map的是用于 | |||
# “键映射"的键值对。假设初始化__init__(pred=None, target=None, threshold=0.1)中threshold是用于控制loss计算的,则 | |||
# 不要将threshold传入_init_param_map. | |||
self._init_param_map(pred=pred, target=target) | |||
def get_loss(self, pred, target): | |||
# 这里'pred', 'target'必须和初始化的映射是一致的。 | |||
return F.l1_loss(input=pred, target=target) #直接返回一个loss即可 | |||
``` | |||
### Metric: 根据Model.forward()或者Model.predict()的结果计算metric | |||
metric的设计和loss的设计类似。都是传入pred_dict与target_dict进行计算。但是metric的pred_dict来源可能是Model.forward的返回值, 也可能是Model.predict(如果Model具有predict方法则会调用predict方法)的返回值,下面统一用pred_dict代替。 | |||
1. 这里的"键映射"与loss的"键映射"是类似的。举例来说,若Metric(pred='output', target='label'),则使用'output'到pred_dict和target_dict中寻找pred, 用'label'寻找target。 | |||
2. 如何创建一个自己的Metric方法 | |||
Metric与loss的计算不同在于,Metric的计算有两个步骤。 | |||
- **每个batch的输出**都会调用Metric的``__call__(pred_dict, target_dict)``方法,而``__call__``方法会调用evaluate()(需要实现)方法。 | |||
- 在所有batch传入之后,调用Metric的get_metric()方法得到最终的metric值。 | |||
- 所以Metric在调用evaluate方法时,根据拿到的数据: pred_dict与batch_y, 改变自己的状态(比如累加正确的次数,总的sample数等)。在调用get_metric()的时候给出一个最终计算结果。 | |||
所有的Metric必须继承自fastNLP.core.metrics.MetricBase. 例子见下一个cell | |||
3. 尽量不要复写``__call__()``,``_init_param_map()``方法。 | |||
```python | |||
class MetricBase: | |||
def __init__(self): | |||
self.param_map = {} # 一般情况下也不需要自己创建。调用_init_param_map()更好 | |||
self._checked = False # 这个参数可以忽略 | |||
def _init_param_map(self, key_map=None, **kwargs): | |||
# 这个函数是用于注册Metric的“键映射”,有两种传值方法, | |||
# 第一种是通过key_map传入dict,取值是用value到forward和batch_y取 | |||
# key_map = {'pred': 'output', 'target': 'label'} | |||
# 第二种是自己写(建议使用改种方式) | |||
# _init_param_map(pred='output', target='label') | |||
# 为什么会提供这么一个方法?通过调用这个方法会自动注册param_map,并会做一些检查,防止出现传入的key其实并不是evaluate() | |||
# 的一个参数。注意传入这个方法的参数必须都是需要做键映射的内容,其它evaluate参数不要传入。如果传入(pred=None, target=None) | |||
# 则__call__()会到pred_dict与target_dict去寻找key为'pred'和'target'的值。 | |||
# 但这个参数不是必须要调用的。 | |||
pass | |||
def __call__(self, pred_dict, target_dict, check=False): # check=False忽略这个参数,之后应该会被删除的 | |||
# 这个函数主要会做一些check的工作,比如pred_dict与target_dict中是否包含了计算evaluate所必须的key等。检查通过,则调用 | |||
# evaluate方法。 | |||
fast_param = self._fast_param_map(predict_dict, target_dict): | |||
if fast_param: | |||
return self.evaluate(**fast_param) | |||
# 如果没有fast_param则通过匹配参数然后调用get_loss完成 | |||
# xxxx | |||
def _fast_param_map(self, pred_dict, target_dict): | |||
# 这是一种快速计算loss的机制,因为在很多情况下其实都不需要通过"键映射",比如evaluate时,pred_dict只有一个元素, | |||
# target_dict也只有一个元素,那么无歧义地就可以把预测值与实际值用于计算metric, 基类判断了这种情况(可能还有其它无歧义的 | |||
# 情况)。即_fast_param_map成功的话,就不需要使用键映射,这样即使在没有传递或者传递错误"键映射"的情况也可以直接计算metric。 | |||
# 返回值是一个dict, 如果匹配成功,应该返回类似{'pred':value, 'target': value}的结果;如果dict为空则说明匹配失败, | |||
# __call__方法会继续尝试匹配。 | |||
pass | |||
def evaluate(self, *args, **kwargs): | |||
# 这个是一定需要实现的,累加metric状态 | |||
# (1) evaluate()中一定不能包含*arg这种参数形式。 | |||
# (2) 如果包含**kwargs这种参数,这会将pred_dict与target_dict中所有参数传入。但是建议不要用这个参数 | |||
raise NotImplementedError | |||
def get_metric(self, reset=True): | |||
# 这是一定需要实现的,获取最终的metric。返回值必须是一个dict。会在所有batch传入之后调用 | |||
raise NotImplementedError | |||
# 下面使用AccuracyMetric举例 | |||
class AccuracyMetric(MetricBase): # MetricBase | |||
# 初始化需要映射的值,这里需要映射的值'pred', 'target'必须与evaluate()需要参数名是对应的 | |||
def __init__(self, pred=None, target=None): | |||
super(AccuracyMetric, self).__init__() | |||
# 这里传入_init_param_map以使得pred和target被正确注册,但这一步不是必须的, 建议调用。传入_init_param_map的是用于 | |||
# “键映射"的键值对。假设初始化__init__(pred=None, target=None, threshold=0.1)中threshold是用于控制loss计算的,则 | |||
# 不要将threshold传入_init_param_map. | |||
self._init_param_map(pred=pred, target=target) | |||
self.total = 0 # 用于累加一共有多少sample | |||
self.corr = 0 # 用于累加一共有多少正确的sample | |||
def evaluate(self, pred, target): | |||
# 对pred和target做一些基本的判断或者预处理等 | |||
if pred.size()==target.size() and len(pred.size())=1: #如果pred已经做了argmax | |||
pass | |||
elif len(pred.size())==2 and len(target.size())==1: # pred还没有进行argmax | |||
pred = pred.argmax(dim=1) | |||
else: | |||
raise ValueError("The shape of pred and target should be ((B, n_classes), (B, )) or (" | |||
"(B,),(B,)).") | |||
assert pred.size(0)==target.size(0), "Mismatch batch size." | |||
# 进行相应的累加 | |||
self.total += pred.size(0) | |||
self.corr += torch.sum(torch.eq(pred, target).float()).item() | |||
def get_metric(self, reset=True): | |||
# reset用于指示是否清空累加信息。默认为True | |||
# 这个函数需要返回dict,可以包含多个metric。 | |||
metric = {} | |||
metric['acc'] = self.corr/self.total | |||
if reset: | |||
self.total = 0 | |||
self.corr = 0 | |||
return metric | |||
``` | |||
#### Tester: 用于做evaluation,应该不需要更改 | |||
重要的初始化参数有data, model, metric;比较重要的function是test()。 | |||
test中的运行过程 | |||
``` | |||
predict_func = 如果有model.predict则为model.predict, 否则是model.forward | |||
for batch_x, batch_y in batch: | |||
# (1) 同步数据与model | |||
# (2) 根据predict_func的参数从batch_x中取出数据传入到predict_func中,得到结果pred_dict | |||
# (3) 调用metric(pred_dict, batch_y | |||
# (4) 当所有batch都运行完毕,会调用metric的get_metric方法,并且以返回的值作为evaluation的结果 | |||
metric.get_metric() | |||
``` | |||
#### Trainer: 对训练过程的封装。 | |||
里面比较重要的function是train() | |||
train()中的运行过程 | |||
``` | |||
(1) 创建batch | |||
batch = Batch(dataset, batch_size, sampler=sampler) | |||
for batch_x, batch_y in batch: | |||
# ... | |||
batch_x,batch_y都是dict。batch_x是DataSet中被设置为input的field;batch_y是DataSet中被设置为target的field。 | |||
两个dict中的key就是DataSet中的key,value会根据情况做好padding的tensor。 | |||
(2)会将batch_x, batch_y中tensor移动到model所在的device | |||
(3)根据model.forward的参数列表, 从batch_x中取出需要传递给forward的数据。 | |||
(4)获取model.forward的输出结果pred_dict,并与batch_y一起传递给loss函数, 求得loss | |||
(5)对loss进行反向梯度并更新参数 | |||
(6) 如果有验证集,则需要做验证 | |||
tester = Tester(model, dev_data,metric) | |||
eval_results = tester.test() | |||
(7) 如果eval_results是当前的最佳结果,则保存模型。 | |||
``` | |||
#### 其他 | |||
Trainer中还提供了"预跑"的功能。该功能通过check_code_level管理,如果check_code_level为-1,则不进行"预跑"。 | |||
check_code_level=0,1,2代表不同的提醒级别。 | |||
目前不同提醒级别对应的是对DataSet中设置为input或target但又没有使用的field的提醒级别。 | |||
0是忽略(默认);1是会warning发生了未使用field的情况;2是出现了unused会直接报错并退出运行 | |||
"预跑"的主要目的有两个: | |||
- 防止train完了之后进行evaluation的时候出现错误。之前的train就白费了 | |||
- 由于存在"键映射",直接运行导致的报错可能不太容易debug,通过"预跑"过程的报错会有一些debug提示 | |||
"预跑"会进行以下的操作: | |||
- 使用很小的batch_size, 检查batch_x中是否包含Model.forward所需要的参数。只会运行两个循环。 | |||
- 将Model.foward的输出pred_dict与batch_y输入到loss中, 并尝试backward. 不会更新参数,而且grad会被清零 | |||
如果传入了dev_data,还将进行metric的测试 | |||
- 创建Tester,并传入少量数据,检测是否可以正常运行 | |||
"预跑"操作是在Trainer初始化的时候执行的。 | |||
正常情况下,应该不需要改动"预跑"的代码。但如果你遇到bug或者有什么好的建议,欢迎在开发群或者github提交issue。 | |||