Browse Source

update tutorial

tags/v1.0.0alpha
yh_cc 3 years ago
parent
commit
d6072ba1d3
2 changed files with 51 additions and 536 deletions
  1. +50
    -536
      docs/source/_static/notebooks/文本分类.ipynb
  2. +1
    -0
      docs/source/tutorials/文本分类.rst

+ 50
- 536
docs/source/_static/notebooks/文本分类.ipynb View File

@@ -46,10 +46,8 @@
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastNLP.io import ChnSentiCorpLoader\n",
@@ -68,22 +66,9 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"In total 3 datasets:\n",
"\tdev has 1200 instances.\n",
"\ttrain has 9600 instances.\n",
"\ttest has 1200 instances.\n",
"In total 0 vocabs:\n",
"\n"
]
}
],
"outputs": [],
"source": [
"print(data_bundle)"
]
@@ -97,20 +82,9 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DataSet({'raw_chars': 选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般 type=str,\n",
"'target': 1 type=str},\n",
"{'raw_chars': 15.4寸笔记本的键盘确实爽,基本跟台式机差不多了,蛮喜欢数字小键盘,输数字特方便,样子也很美观,做工也相当不错 type=str,\n",
"'target': 1 type=str})\n"
]
}
],
"outputs": [],
"source": [
"print(data_bundle.get_dataset('train')[:2]) # 查看Train集前两个sample"
]
@@ -127,10 +101,8 @@
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastNLP.io import ChnSentiCorpPipe\n",
@@ -141,24 +113,9 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"In total 3 datasets:\n",
"\tdev has 1200 instances.\n",
"\ttrain has 9600 instances.\n",
"\ttest has 1200 instances.\n",
"In total 2 vocabs:\n",
"\tchars has 4409 entries.\n",
"\ttarget has 2 entries.\n",
"\n"
]
}
],
"outputs": [],
"source": [
"print(data_bundle) # 打印data_bundle,查看其变化"
]
@@ -172,24 +129,9 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DataSet({'raw_chars': 选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般 type=str,\n",
"'target': 1 type=int,\n",
"'chars': [338, 464, 1400, 784, 468, 739, 3, 289, 151, 21, 5, 88, 143, 2, 9, 81, 134, 2573, 766, 233, 196, 23, 536, 342, 297, 2, 405, 698, 132, 281, 74, 744, 1048, 74, 420, 387, 74, 412, 433, 74, 2021, 180, 8, 219, 1929, 213, 4, 34, 31, 96, 363, 8, 230, 2, 66, 18, 229, 331, 768, 4, 11, 1094, 479, 17, 35, 593, 3, 1126, 967, 2, 151, 245, 12, 44, 2, 6, 52, 260, 263, 635, 5, 152, 162, 4, 11, 336, 3, 154, 132, 5, 236, 443, 3, 2, 18, 229, 761, 700, 4, 11, 48, 59, 653, 2, 8, 230] type=list,\n",
"'seq_len': 106 type=int},\n",
"{'raw_chars': 15.4寸笔记本的键盘确实爽,基本跟台式机差不多了,蛮喜欢数字小键盘,输数字特方便,样子也很美观,做工也相当不错 type=str,\n",
"'target': 1 type=int,\n",
"'chars': [50, 133, 20, 135, 945, 520, 343, 24, 3, 301, 176, 350, 86, 785, 2, 456, 24, 461, 163, 443, 128, 109, 6, 47, 7, 2, 916, 152, 162, 524, 296, 44, 301, 176, 2, 1384, 524, 296, 259, 88, 143, 2, 92, 67, 26, 12, 277, 269, 2, 188, 223, 26, 228, 83, 6, 63] type=list,\n",
"'seq_len': 56 type=int})\n"
]
}
],
"outputs": [],
"source": [
"print(data_bundle.get_dataset('train')[:2])"
]
@@ -203,17 +145,9 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Vocabulary(['选', '择', '珠', '江', '花']...)\n"
]
}
],
"outputs": [],
"source": [
"char_vocab = data_bundle.get_vocab('chars')\n",
"print(char_vocab)"
@@ -228,18 +162,9 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"'选'的index是338\n",
"index:338对应的汉字是选\n"
]
}
],
"outputs": [],
"source": [
"index = char_vocab.to_index('选')\n",
"print(\"'选'的index是{}\".format(index)) # 这个值与上面打印出来的第一个instance的chars的第一个index是一致的\n",
@@ -256,17 +181,9 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found 4321 out of 4409 words in the pre-training embedding.\n"
]
}
],
"outputs": [],
"source": [
"from fastNLP.embeddings import StaticEmbedding\n",
"\n",
@@ -283,10 +200,8 @@
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch import nn\n",
@@ -329,288 +244,9 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input fields after batch(if batch size is 2):\n",
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\tchars: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 106]) \n",
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"target fields after batch(if batch size is 2):\n",
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\n",
"Evaluate data in 0.01 seconds!\n",
"training epochs started 2019-09-03-23-57-10\n"
]
},
{
"data": {
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=3000), HTML(value='')), layout=Layout(display…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.43 seconds!\n",
"\r",
"Evaluation on dev at Epoch 1/10. Step:300/3000: \n",
"\r",
"AccuracyMetric: acc=0.81\n",
"\n"
]
},
{
"data": {
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.44 seconds!\n",
"\r",
"Evaluation on dev at Epoch 2/10. Step:600/3000: \n",
"\r",
"AccuracyMetric: acc=0.8675\n",
"\n"
]
},
{
"data": {
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.44 seconds!\n",
"\r",
"Evaluation on dev at Epoch 3/10. Step:900/3000: \n",
"\r",
"AccuracyMetric: acc=0.878333\n",
"\n"
]
},
{
"data": {
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.43 seconds!\n",
"\r",
"Evaluation on dev at Epoch 4/10. Step:1200/3000: \n",
"\r",
"AccuracyMetric: acc=0.873333\n",
"\n"
]
},
{
"data": {
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.44 seconds!\n",
"\r",
"Evaluation on dev at Epoch 5/10. Step:1500/3000: \n",
"\r",
"AccuracyMetric: acc=0.878333\n",
"\n"
]
},
{
"data": {
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.42 seconds!\n",
"\r",
"Evaluation on dev at Epoch 6/10. Step:1800/3000: \n",
"\r",
"AccuracyMetric: acc=0.895833\n",
"\n"
]
},
{
"data": {
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.44 seconds!\n",
"\r",
"Evaluation on dev at Epoch 7/10. Step:2100/3000: \n",
"\r",
"AccuracyMetric: acc=0.8975\n",
"\n"
]
},
{
"data": {
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.43 seconds!\n",
"\r",
"Evaluation on dev at Epoch 8/10. Step:2400/3000: \n",
"\r",
"AccuracyMetric: acc=0.894167\n",
"\n"
]
},
{
"data": {
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.48 seconds!\n",
"\r",
"Evaluation on dev at Epoch 9/10. Step:2700/3000: \n",
"\r",
"AccuracyMetric: acc=0.8875\n",
"\n"
]
},
{
"data": {
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.43 seconds!\n",
"\r",
"Evaluation on dev at Epoch 10/10. Step:3000/3000: \n",
"\r",
"AccuracyMetric: acc=0.895833\n",
"\n",
"\r\n",
"In Epoch:7/Step:2100, got best dev performance:\n",
"AccuracyMetric: acc=0.8975\n",
"Reloaded the best model.\n"
]
},
{
"data": {
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=19), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 0.34 seconds!\n",
"[tester] \n",
"AccuracyMetric: acc=0.8975\n"
]
},
{
"data": {
"text/plain": [
"{'AccuracyMetric': {'acc': 0.8975}}"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"from fastNLP import Trainer\n",
"from fastNLP import CrossEntropyLoss\n",
@@ -643,139 +279,9 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loading vocabulary file /home/yh/.fastNLP/embedding/bert-chinese-wwm/vocab.txt\n",
"Load pre-trained BERT parameters from file /home/yh/.fastNLP/embedding/bert-chinese-wwm/chinese_wwm_pytorch.bin.\n",
"Start to generating word pieces for word.\n",
"Found(Or segment into word pieces) 4286 words out of 4409.\n",
"input fields after batch(if batch size is 2):\n",
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\tchars: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 106]) \n",
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"target fields after batch(if batch size is 2):\n",
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\n",
"Evaluate data in 0.05 seconds!\n",
"training epochs started 2019-09-04-00-02-37\n"
]
},
{
"data": {
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=3600), HTML(value='')), layout=Layout(display…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=150), HTML(value='')), layout=Layout(display=…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 15.89 seconds!\n",
"\r",
"Evaluation on dev at Epoch 1/3. Step:1200/3600: \n",
"\r",
"AccuracyMetric: acc=0.9\n",
"\n"
]
},
{
"data": {
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=150), HTML(value='')), layout=Layout(display=…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 15.92 seconds!\n",
"\r",
"Evaluation on dev at Epoch 2/3. Step:2400/3600: \n",
"\r",
"AccuracyMetric: acc=0.904167\n",
"\n"
]
},
{
"data": {
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=150), HTML(value='')), layout=Layout(display=…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 15.91 seconds!\n",
"\r",
"Evaluation on dev at Epoch 3/3. Step:3600/3600: \n",
"\r",
"AccuracyMetric: acc=0.918333\n",
"\n",
"\r\n",
"In Epoch:3/Step:3600, got best dev performance:\n",
"AccuracyMetric: acc=0.918333\n",
"Reloaded the best model.\n",
"Performance on test is:\n"
]
},
{
"data": {
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=19), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 29.24 seconds!\n",
"[tester] \n",
"AccuracyMetric: acc=0.919167\n"
]
},
{
"data": {
"text/plain": [
"{'AccuracyMetric': {'acc': 0.919167}}"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"# 只需要切换一下Embedding即可\n",
"from fastNLP.embeddings import BertEmbedding\n",
@@ -840,9 +346,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"from fastNLP.io import ChnSentiCorpLoader\n",
@@ -861,9 +365,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"import os\n",
@@ -912,15 +414,14 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"from fastHan import FastHan\n",
"from fastNLP import Vocabulary\n",
"\n",
"model=FastHan()\n",
"# model.set_device('cuda')\n",
"\n",
"# 定义分词处理操作\n",
"def word_seg(ins):\n",
@@ -933,6 +434,8 @@
" # apply函数将对内部的instance依次执行word_seg操作,并把其返回值放入到raw_words这个field\n",
" ds.apply(word_seg, new_field_name='raw_words')\n",
" # 除了apply函数,fastNLP还支持apply_field, apply_more(可同时创建多个field)等操作\n",
" # 同时我们增加一个seq_len的field\n",
" ds.add_seq_len('raw_words')\n",
"\n",
"vocab = Vocabulary()\n",
"\n",
@@ -961,11 +464,14 @@
"# 我们把words和target分别设置为input和target,这样它们才会在训练循环中被取出并自动padding, 有关这部分更多的内容参考\n",
"# http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_6_datasetiter.html\n",
"data_bundle.set_target('target')\n",
"data_bundle.set_input('words') # DataSet也有这两个接口\n",
"data_bundle.set_input('words', 'seq_len') # DataSet也有这两个接口\n",
"# 如果某些field,您希望它被设置为target或者input,但是不希望fastNLP自动padding或需要使用特定的padding方式,请参考\n",
"# http://www.fastnlp.top/docs/fastNLP/fastNLP.core.dataset.html\n",
"\n",
"print(data_bundle.get_dataset('train')[:2]) # 我们可以看一下当前dataset的内容"
"print(data_bundle.get_dataset('train')[:2]) # 我们可以看一下当前dataset的内容\n",
"\n",
"# 由于之后需要使用之前定义的BiLSTMMaxPoolCls模型,所以需要将words这个field修改为chars(因为该模型的forward接受chars参数)\n",
"data_bundle.rename_field('words', 'chars')"
]
},
{
@@ -985,9 +491,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"from fastNLP.embeddings import StaticEmbedding\n",
@@ -999,11 +503,14 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"from fastNLP import Trainer\n",
"from fastNLP import CrossEntropyLoss\n",
"from torch.optim import Adam\n",
"from fastNLP import AccuracyMetric\n",
"\n",
"# 初始化模型\n",
"model = BiLSTMMaxPoolCls(word2vec_embed, len(data_bundle.get_vocab('target')))\n",
"\n",
@@ -1024,6 +531,13 @@
"tester = Tester(data=data_bundle.get_dataset('test'), model=model, metrics=metric, batch_size=64, device=device)\n",
"tester.test()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -1042,7 +556,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.10"
"version": "3.6.8"
}
},
"nbformat": 4,


+ 1
- 0
docs/source/tutorials/文本分类.rst View File

@@ -447,6 +447,7 @@ PS: 基于词进行文本分类
from fastNLP import Vocabulary

model=FastHan()
# model.set_device('cuda') # 可以注释掉这一行增加速度

# 定义分词处理操作
def word_seg(ins):


Loading…
Cancel
Save