diff --git a/docs/source/_static/notebooks/文本分类.ipynb b/docs/source/_static/notebooks/文本分类.ipynb index d18301ec..66439a76 100644 --- a/docs/source/_static/notebooks/文本分类.ipynb +++ b/docs/source/_static/notebooks/文本分类.ipynb @@ -46,10 +46,8 @@ }, { "cell_type": "code", - "execution_count": 1, - "metadata": { - "collapsed": true - }, + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "from fastNLP.io import ChnSentiCorpLoader\n", @@ -68,22 +66,9 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "In total 3 datasets:\n", - "\tdev has 1200 instances.\n", - "\ttrain has 9600 instances.\n", - "\ttest has 1200 instances.\n", - "In total 0 vocabs:\n", - "\n" - ] - } - ], + "outputs": [], "source": [ "print(data_bundle)" ] @@ -97,20 +82,9 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "DataSet({'raw_chars': 选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般 type=str,\n", - "'target': 1 type=str},\n", - "{'raw_chars': 15.4寸笔记本的键盘确实爽,基本跟台式机差不多了,蛮喜欢数字小键盘,输数字特方便,样子也很美观,做工也相当不错 type=str,\n", - "'target': 1 type=str})\n" - ] - } - ], + "outputs": [], "source": [ "print(data_bundle.get_dataset('train')[:2]) # 查看Train集前两个sample" ] @@ -127,10 +101,8 @@ }, { "cell_type": "code", - "execution_count": 3, - "metadata": { - "collapsed": true - }, + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "from fastNLP.io import ChnSentiCorpPipe\n", @@ -141,24 +113,9 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "In total 3 datasets:\n", - "\tdev has 1200 instances.\n", - "\ttrain has 9600 instances.\n", - "\ttest has 1200 
instances.\n", - "In total 2 vocabs:\n", - "\tchars has 4409 entries.\n", - "\ttarget has 2 entries.\n", - "\n" - ] - } - ], + "outputs": [], "source": [ "print(data_bundle) # 打印data_bundle,查看其变化" ] @@ -172,24 +129,9 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "DataSet({'raw_chars': 选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般 type=str,\n", - "'target': 1 type=int,\n", - "'chars': [338, 464, 1400, 784, 468, 739, 3, 289, 151, 21, 5, 88, 143, 2, 9, 81, 134, 2573, 766, 233, 196, 23, 536, 342, 297, 2, 405, 698, 132, 281, 74, 744, 1048, 74, 420, 387, 74, 412, 433, 74, 2021, 180, 8, 219, 1929, 213, 4, 34, 31, 96, 363, 8, 230, 2, 66, 18, 229, 331, 768, 4, 11, 1094, 479, 17, 35, 593, 3, 1126, 967, 2, 151, 245, 12, 44, 2, 6, 52, 260, 263, 635, 5, 152, 162, 4, 11, 336, 3, 154, 132, 5, 236, 443, 3, 2, 18, 229, 761, 700, 4, 11, 48, 59, 653, 2, 8, 230] type=list,\n", - "'seq_len': 106 type=int},\n", - "{'raw_chars': 15.4寸笔记本的键盘确实爽,基本跟台式机差不多了,蛮喜欢数字小键盘,输数字特方便,样子也很美观,做工也相当不错 type=str,\n", - "'target': 1 type=int,\n", - "'chars': [50, 133, 20, 135, 945, 520, 343, 24, 3, 301, 176, 350, 86, 785, 2, 456, 24, 461, 163, 443, 128, 109, 6, 47, 7, 2, 916, 152, 162, 524, 296, 44, 301, 176, 2, 1384, 524, 296, 259, 88, 143, 2, 92, 67, 26, 12, 277, 269, 2, 188, 223, 26, 228, 83, 6, 63] type=list,\n", - "'seq_len': 56 type=int})\n" - ] - } - ], + "outputs": [], "source": [ "print(data_bundle.get_dataset('train')[:2])" ] @@ -203,17 +145,9 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Vocabulary(['选', '择', '珠', '江', '花']...)\n" - ] - } - ], + "outputs": [], "source": [ "char_vocab = data_bundle.get_vocab('chars')\n", "print(char_vocab)" @@ -228,18 +162,9 @@ }, { "cell_type": 
"code", - "execution_count": 7, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "'选'的index是338\n", - "index:338对应的汉字是选\n" - ] - } - ], + "outputs": [], "source": [ "index = char_vocab.to_index('选')\n", "print(\"'选'的index是{}\".format(index)) # 这个值与上面打印出来的第一个instance的chars的第一个index是一致的\n", @@ -256,17 +181,9 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Found 4321 out of 4409 words in the pre-training embedding.\n" - ] - } - ], + "outputs": [], "source": [ "from fastNLP.embeddings import StaticEmbedding\n", "\n", @@ -283,10 +200,8 @@ }, { "cell_type": "code", - "execution_count": 9, - "metadata": { - "collapsed": true - }, + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "from torch import nn\n", @@ -329,288 +244,9 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "input fields after batch(if batch size is 2):\n", - "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", - "\tchars: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 106]) \n", - "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", - "target fields after batch(if batch size is 2):\n", - "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", - "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", - "\n", - "Evaluate data in 0.01 seconds!\n", - "training epochs started 2019-09-03-23-57-10\n" - ] - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=3000), HTML(value='')), layout=Layout(display…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - 
"data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\r", - "Evaluate data in 0.43 seconds!\n", - "\r", - "Evaluation on dev at Epoch 1/10. Step:300/3000: \n", - "\r", - "AccuracyMetric: acc=0.81\n", - "\n" - ] - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\r", - "Evaluate data in 0.44 seconds!\n", - "\r", - "Evaluation on dev at Epoch 2/10. Step:600/3000: \n", - "\r", - "AccuracyMetric: acc=0.8675\n", - "\n" - ] - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\r", - "Evaluate data in 0.44 seconds!\n", - "\r", - "Evaluation on dev at Epoch 3/10. Step:900/3000: \n", - "\r", - "AccuracyMetric: acc=0.878333\n", - "\n" - ] - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\r", - "Evaluate data in 0.43 seconds!\n", - "\r", - "Evaluation on dev at Epoch 4/10. 
Step:1200/3000: \n", - "\r", - "AccuracyMetric: acc=0.873333\n", - "\n" - ] - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\r", - "Evaluate data in 0.44 seconds!\n", - "\r", - "Evaluation on dev at Epoch 5/10. Step:1500/3000: \n", - "\r", - "AccuracyMetric: acc=0.878333\n", - "\n" - ] - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\r", - "Evaluate data in 0.42 seconds!\n", - "\r", - "Evaluation on dev at Epoch 6/10. Step:1800/3000: \n", - "\r", - "AccuracyMetric: acc=0.895833\n", - "\n" - ] - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\r", - "Evaluate data in 0.44 seconds!\n", - "\r", - "Evaluation on dev at Epoch 7/10. Step:2100/3000: \n", - "\r", - "AccuracyMetric: acc=0.8975\n", - "\n" - ] - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\r", - "Evaluate data in 0.43 seconds!\n", - "\r", - "Evaluation on dev at Epoch 8/10. 
Step:2400/3000: \n", - "\r", - "AccuracyMetric: acc=0.894167\n", - "\n" - ] - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\r", - "Evaluate data in 0.48 seconds!\n", - "\r", - "Evaluation on dev at Epoch 9/10. Step:2700/3000: \n", - "\r", - "AccuracyMetric: acc=0.8875\n", - "\n" - ] - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\r", - "Evaluate data in 0.43 seconds!\n", - "\r", - "Evaluation on dev at Epoch 10/10. Step:3000/3000: \n", - "\r", - "AccuracyMetric: acc=0.895833\n", - "\n", - "\r\n", - "In Epoch:7/Step:2100, got best dev performance:\n", - "AccuracyMetric: acc=0.8975\n", - "Reloaded the best model.\n" - ] - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=19), HTML(value='')), layout=Layout(display='…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\r", - "Evaluate data in 0.34 seconds!\n", - "[tester] \n", - "AccuracyMetric: acc=0.8975\n" - ] - }, - { - "data": { - "text/plain": [ - "{'AccuracyMetric': {'acc': 0.8975}}" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "from fastNLP import Trainer\n", "from fastNLP import CrossEntropyLoss\n", @@ -643,139 +279,9 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "loading vocabulary file 
/home/yh/.fastNLP/embedding/bert-chinese-wwm/vocab.txt\n", - "Load pre-trained BERT parameters from file /home/yh/.fastNLP/embedding/bert-chinese-wwm/chinese_wwm_pytorch.bin.\n", - "Start to generating word pieces for word.\n", - "Found(Or segment into word pieces) 4286 words out of 4409.\n", - "input fields after batch(if batch size is 2):\n", - "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", - "\tchars: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 106]) \n", - "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", - "target fields after batch(if batch size is 2):\n", - "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", - "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", - "\n", - "Evaluate data in 0.05 seconds!\n", - "training epochs started 2019-09-04-00-02-37\n" - ] - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=3600), HTML(value='')), layout=Layout(display…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=150), HTML(value='')), layout=Layout(display=…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\r", - "Evaluate data in 15.89 seconds!\n", - "\r", - "Evaluation on dev at Epoch 1/3. Step:1200/3600: \n", - "\r", - "AccuracyMetric: acc=0.9\n", - "\n" - ] - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=150), HTML(value='')), layout=Layout(display=…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\r", - "Evaluate data in 15.92 seconds!\n", - "\r", - "Evaluation on dev at Epoch 2/3. 
Step:2400/3600: \n", - "\r", - "AccuracyMetric: acc=0.904167\n", - "\n" - ] - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=150), HTML(value='')), layout=Layout(display=…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\r", - "Evaluate data in 15.91 seconds!\n", - "\r", - "Evaluation on dev at Epoch 3/3. Step:3600/3600: \n", - "\r", - "AccuracyMetric: acc=0.918333\n", - "\n", - "\r\n", - "In Epoch:3/Step:3600, got best dev performance:\n", - "AccuracyMetric: acc=0.918333\n", - "Reloaded the best model.\n", - "Performance on test is:\n" - ] - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=19), HTML(value='')), layout=Layout(display='…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\r", - "Evaluate data in 29.24 seconds!\n", - "[tester] \n", - "AccuracyMetric: acc=0.919167\n" - ] - }, - { - "data": { - "text/plain": [ - "{'AccuracyMetric': {'acc': 0.919167}}" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# 只需要切换一下Embedding即可\n", "from fastNLP.embeddings import BertEmbedding\n", @@ -840,9 +346,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [ "from fastNLP.io import ChnSentiCorpLoader\n", @@ -861,9 +365,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [ "import os\n", @@ -912,15 +414,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [ "from fastHan import FastHan\n", "from fastNLP import Vocabulary\n", "\n", "model=FastHan()\n", + "# 
model.set_device('cuda')\n", "\n", "# 定义分词处理操作\n", "def word_seg(ins):\n", @@ -933,6 +434,8 @@ " # apply函数将对内部的instance依次执行word_seg操作,并把其返回值放入到raw_words这个field\n", " ds.apply(word_seg, new_field_name='raw_words')\n", " # 除了apply函数,fastNLP还支持apply_field, apply_more(可同时创建多个field)等操作\n", + " # 同时我们增加一个seq_len的field\n", + " ds.add_seq_len('raw_words')\n", "\n", "vocab = Vocabulary()\n", "\n", @@ -961,11 +464,14 @@ "# 我们把words和target分别设置为input和target,这样它们才会在训练循环中被取出并自动padding, 有关这部分更多的内容参考\n", "# http://www.fastnlp.top/docs/fastNLP/tutorials/tutorial_6_datasetiter.html\n", "data_bundle.set_target('target')\n", - "data_bundle.set_input('words') # DataSet也有这两个接口\n", + "data_bundle.set_input('words', 'seq_len') # DataSet也有这两个接口\n", "# 如果某些field,您希望它被设置为target或者input,但是不希望fastNLP自动padding或需要使用特定的padding方式,请参考\n", "# http://www.fastnlp.top/docs/fastNLP/fastNLP.core.dataset.html\n", "\n", - "print(data_bundle.get_dataset('train')[:2]) # 我们可以看一下当前dataset的内容" + "print(data_bundle.get_dataset('train')[:2]) # 我们可以看一下当前dataset的内容\n", + "\n", + "# 由于之后需要使用之前定义的BiLSTMMaxPoolCls模型,所以需要将words这个field修改为chars(因为该模型的forward接受chars参数)\n", + "data_bundle.rename_field('words', 'chars')" ] }, { @@ -985,9 +491,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [ "from fastNLP.embeddings import StaticEmbedding\n", @@ -999,11 +503,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [ + "from fastNLP import Trainer\n", + "from fastNLP import CrossEntropyLoss\n", + "from torch.optim import Adam\n", + "from fastNLP import AccuracyMetric\n", + "\n", "# 初始化模型\n", "model = BiLSTMMaxPoolCls(word2vec_embed, len(data_bundle.get_vocab('target')))\n", "\n", @@ -1024,6 +531,13 @@ "tester = Tester(data=data_bundle.get_dataset('test'), model=model, metrics=metric, batch_size=64, device=device)\n", "tester.test()" ] + }, + { + "cell_type": 
"code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -1042,7 +556,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.10" + "version": "3.6.8" } }, "nbformat": 4, diff --git a/docs/source/tutorials/文本分类.rst b/docs/source/tutorials/文本分类.rst index 4b882cde..73686916 100644 --- a/docs/source/tutorials/文本分类.rst +++ b/docs/source/tutorials/文本分类.rst @@ -447,6 +447,7 @@ PS: 基于词进行文本分类 from fastNLP import Vocabulary model=FastHan() + # model.set_device('cuda') # 可以取消注释这一行以增加速度 # 定义分词处理操作 def word_seg(ins):