@@ -0,0 +1,260 @@ | |||||
{ | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"# BertEmbedding的各种用法\n", | |||||
"Bert自从在 BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding 中被提出后,因其性能卓越受到了极大的关注,在这里我们展示一下在fastNLP中如何使用Bert进行各类任务。其中中文Bert我们使用的模型的权重来自于 中文Bert预训练 。\n", | |||||
"\n", | |||||
"为了方便大家的使用,fastNLP提供了预训练的Embedding权重及数据集的自动下载,支持自动下载的Embedding和数据集见 数据集 。或您可从 使用Embedding模块将文本转成向量 与 使用Loader和Pipe加载并处理数据集 了解更多相关信息\n", | |||||
"\n", | |||||
"\n", | |||||
"下面我们将介绍通过使用Bert来进行文本分类, 中文命名实体识别, 文本匹配, 中文问答。\n", | |||||
"\n", | |||||
"## 1. 使用Bert进行文本分类\n", | |||||
"\n", | |||||
"文本分类是指给定一段文字,判定其所属的类别。例如下面的文本情感分类\n", | |||||
"\n", | |||||
" *1, 商务大床房,房间很大,床有2M宽,整体感觉经济实惠不错!*\n", | |||||
"\n", | |||||
"这里我们使用fastNLP提供自动下载的微博分类进行测试" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"from fastNLP.io import WeiboSenti100kPipe\n", | |||||
"from fastNLP.embeddings import BertEmbedding\n", | |||||
"from fastNLP.models import BertForSequenceClassification\n", | |||||
"from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam\n", | |||||
"import torch\n", | |||||
"\n", | |||||
"data_bundle =WeiboSenti100kPipe().process_from_file()\n", | |||||
"data_bundle.rename_field('chars', 'words')\n", | |||||
"\n", | |||||
"# 载入BertEmbedding\n", | |||||
"embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn-wwm', include_cls_sep=True)\n", | |||||
"\n", | |||||
"# 载入模型\n", | |||||
"model = BertForSequenceClassification(embed, len(data_bundle.get_vocab('target')))\n", | |||||
"\n", | |||||
"# 训练模型\n", | |||||
"device = 0 if torch.cuda.is_available() else 'cpu' \n", | |||||
"trainer = Trainer(data_bundle.get_dataset('train'), model,\n", | |||||
" optimizer=Adam(model_params=model.parameters(), lr=2e-5),\n", | |||||
" loss=CrossEntropyLoss(), device=device,\n", | |||||
" batch_size=8, dev_data=data_bundle.get_dataset('dev'),\n", | |||||
" metrics=AccuracyMetric(), n_epochs=2, print_every=1)\n", | |||||
"trainer.train()\n", | |||||
"\n", | |||||
"# 测试结果\n", | |||||
"from fastNLP import Tester\n", | |||||
"\n", | |||||
"tester = Tester(data_bundle.get_dataset('test'), model, batch_size=128, metrics=AccuracyMetric())\n", | |||||
"tester.test()" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## 2. 使用Bert进行命名实体识别\n", | |||||
"\n", | |||||
"命名实体识别是给定一句话,标记出其中的实体。一般序列标注的任务都使用conll格式,conll格式是至一行中通过制表符分隔不同的内容,使用空行分隔 两句话,例如下面的例子\n", | |||||
"\n", | |||||
"```\n", | |||||
" 中 B-ORG\n", | |||||
" 共 I-ORG\n", | |||||
" 中 I-ORG\n", | |||||
" 央 I-ORG\n", | |||||
" 致 O\n", | |||||
" 中 B-ORG\n", | |||||
" 国 I-ORG\n", | |||||
" 致 I-ORG\n", | |||||
" 公 I-ORG\n", | |||||
" 党 I-ORG\n", | |||||
" 十 I-ORG\n", | |||||
" 一 I-ORG\n", | |||||
" 大 I-ORG\n", | |||||
" 的 O\n", | |||||
" 贺 O\n", | |||||
" 词 O\n", | |||||
"```\n", | |||||
"\n", | |||||
"这部分内容请参考 快速实现序列标注模型\n", | |||||
"\n", | |||||
"## 3. 使用Bert进行文本匹配\n", | |||||
"\n", | |||||
"文本匹配任务是指给定两句话判断他们的关系。比如,给定两句话判断前一句是否和后一句具有因果关系或是否是矛盾关系;或者给定两句话判断两句话是否 具有相同的意思。这里我们使用" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"from fastNLP.io import CNXNLIBertPipe\n", | |||||
"from fastNLP.embeddings import BertEmbedding\n", | |||||
"from fastNLP.models import BertForSentenceMatching\n", | |||||
"from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam\n", | |||||
"from fastNLP.core.optimizer import AdamW\n", | |||||
"from fastNLP.core.callback import WarmupCallback\n", | |||||
"from fastNLP import Tester\n", | |||||
"import torch\n", | |||||
"\n", | |||||
"data_bundle = CNXNLIBertPipe().process_from_file()\n", | |||||
"data_bundle.rename_field('chars', 'words')\n", | |||||
"print(data_bundle)\n", | |||||
"\n", | |||||
"# 载入BertEmbedding\n", | |||||
"embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn-wwm', include_cls_sep=True)\n", | |||||
"\n", | |||||
"# 载入模型\n", | |||||
"model = BertForSentenceMatching(embed, len(data_bundle.get_vocab('target')))\n", | |||||
"\n", | |||||
"# 训练模型\n", | |||||
"callbacks = [WarmupCallback(warmup=0.1, schedule='linear'), ]\n", | |||||
"device = 0 if torch.cuda.is_available() else 'cpu' \n", | |||||
"trainer = Trainer(data_bundle.get_dataset('train'), model,\n", | |||||
" optimizer=AdamW(params=model.parameters(), lr=4e-5),\n", | |||||
" loss=CrossEntropyLoss(), device=device,\n", | |||||
" batch_size=8, dev_data=data_bundle.get_dataset('dev'),\n", | |||||
" metrics=AccuracyMetric(), n_epochs=5, print_every=1,\n", | |||||
" update_every=8, callbacks=callbacks)\n", | |||||
"trainer.train()\n", | |||||
"\n", | |||||
"tester = Tester(data_bundle.get_dataset('test'), model, batch_size=8, metrics=AccuracyMetric())\n", | |||||
"tester.test()" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## 4. 使用Bert进行中文问答\n", | |||||
"\n", | |||||
"问答任务是给定一段内容,以及一个问题,需要从这段内容中找到答案。 例如:\n", | |||||
"\n", | |||||
"```\n", | |||||
"\"context\": \"锣鼓经是大陆传统器乐及戏曲里面常用的打击乐记谱方法,以中文字的声音模拟敲击乐的声音,纪录打击乐的各种不同的演奏方法。常\n", | |||||
"用的节奏型称为「锣鼓点」。而锣鼓是戏曲节奏的支柱,除了加强演员身段动作的节奏感,也作为音乐的引子和尾声,提示音乐的板式和速度,以及\n", | |||||
"作为唱腔和念白的伴奏,令诗句的韵律更加抑扬顿锉,段落分明。锣鼓的运用有约定俗成的程式,依照角色行当的身份、性格、情绪以及环境,配合\n", | |||||
"相应的锣鼓点。锣鼓亦可以模仿大自然的音响效果,如雷电、波浪等等。戏曲锣鼓所运用的敲击乐器主要分为鼓、锣、钹和板四类型:鼓类包括有单\n", | |||||
"皮鼓(板鼓)、大鼓、大堂鼓(唐鼓)、小堂鼓、怀鼓、花盆鼓等;锣类有大锣、小锣(手锣)、钲锣、筛锣、马锣、镗锣、云锣;钹类有铙钹、大\n", | |||||
"钹、小钹、水钹、齐钹、镲钹、铰子、碰钟等;打拍子用的檀板、木鱼、梆子等。因为京剧的锣鼓通常由四位乐师负责,又称为四大件,领奏的师\n", | |||||
"傅称为:「鼓佬」,其职责有如西方乐队的指挥,负责控制速度以及利用各种手势提示乐师演奏不同的锣鼓点。粤剧吸收了部份京剧的锣鼓,但以木鱼\n", | |||||
"和沙的代替了京剧的板和鼓,作为打拍子的主要乐器。以下是京剧、昆剧和粤剧锣鼓中乐器对应的口诀用字:\",\n", | |||||
"\"question\": \"锣鼓经是什么?\",\n", | |||||
"\"answers\": [\n", | |||||
" {\n", | |||||
" \"text\": \"大陆传统器乐及戏曲里面常用的打击乐记谱方法\",\n", | |||||
" \"answer_start\": 4\n", | |||||
" },\n", | |||||
" {\n", | |||||
" \"text\": \"大陆传统器乐及戏曲里面常用的打击乐记谱方法\",\n", | |||||
" \"answer_start\": 4\n", | |||||
" },\n", | |||||
" {\n", | |||||
" \"text\": \"大陆传统器乐及戏曲里面常用的打击乐记谱方法\",\n", | |||||
" \"answer_start\": 4\n", | |||||
" }\n", | |||||
"]\n", | |||||
"```" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"您可以通过以下的代码训练 (原文代码:[CMRC2018](https://github.com/ymcui/cmrc2018) )" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"from fastNLP.embeddings import BertEmbedding\n", | |||||
"from fastNLP.models import BertForQuestionAnswering\n", | |||||
"from fastNLP.core.losses import CMRC2018Loss\n", | |||||
"from fastNLP.core.metrics import CMRC2018Metric\n", | |||||
"from fastNLP.io.pipe.qa import CMRC2018BertPipe\n", | |||||
"from fastNLP import Trainer, BucketSampler\n", | |||||
"from fastNLP import WarmupCallback, GradientClipCallback\n", | |||||
"from fastNLP.core.optimizer import AdamW\n", | |||||
"import torch\n", | |||||
"\n", | |||||
"data_bundle = CMRC2018BertPipe().process_from_file()\n", | |||||
"data_bundle.rename_field('chars', 'words')\n", | |||||
"\n", | |||||
"print(data_bundle)\n", | |||||
"\n", | |||||
"embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='cn', requires_grad=True, include_cls_sep=False, auto_truncate=True,\n", | |||||
" dropout=0.5, word_dropout=0.01)\n", | |||||
"model = BertForQuestionAnswering(embed)\n", | |||||
"loss = CMRC2018Loss()\n", | |||||
"metric = CMRC2018Metric()\n", | |||||
"\n", | |||||
"wm_callback = WarmupCallback(schedule='linear')\n", | |||||
"gc_callback = GradientClipCallback(clip_value=1, clip_type='norm')\n", | |||||
"callbacks = [wm_callback, gc_callback]\n", | |||||
"\n", | |||||
"optimizer = AdamW(model.parameters(), lr=5e-5)\n", | |||||
"\n", | |||||
"device = 0 if torch.cuda.is_available() else 'cpu' \n", | |||||
"trainer = Trainer(data_bundle.get_dataset('train'), model, loss=loss, optimizer=optimizer,\n", | |||||
" sampler=BucketSampler(seq_len_field_name='context_len'),\n", | |||||
" dev_data=data_bundle.get_dataset('dev'), metrics=metric,\n", | |||||
" callbacks=callbacks, device=device, batch_size=6, num_workers=2, n_epochs=2, print_every=1,\n", | |||||
" test_use_tqdm=False, update_every=10)\n", | |||||
"trainer.train(load_best_model=False)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"训练结果(和论文中报道的基本一致):\n", | |||||
"\n", | |||||
"```\n", | |||||
" In Epoch:2/Step:1692, got best dev performance:\n", | |||||
" CMRC2018Metric: f1=85.61, em=66.08\n", | |||||
"```" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
} | |||||
], | |||||
"metadata": { | |||||
"kernelspec": { | |||||
"display_name": "Python Now", | |||||
"language": "python", | |||||
"name": "now" | |||||
}, | |||||
"language_info": { | |||||
"codemirror_mode": { | |||||
"name": "ipython", | |||||
"version": 3 | |||||
}, | |||||
"file_extension": ".py", | |||||
"mimetype": "text/x-python", | |||||
"name": "python", | |||||
"nbconvert_exporter": "python", | |||||
"pygments_lexer": "ipython3", | |||||
"version": "3.8.0" | |||||
} | |||||
}, | |||||
"nbformat": 4, | |||||
"nbformat_minor": 2 | |||||
} |
@@ -0,0 +1,292 @@ | |||||
{ | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"# fastNLP中的DataSet" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 1, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"+------------------------------+---------------------------------------------+---------+\n", | |||||
"| raw_words | words | seq_len |\n", | |||||
"+------------------------------+---------------------------------------------+---------+\n", | |||||
"| This is the first instance . | ['this', 'is', 'the', 'first', 'instance... | 6 |\n", | |||||
"| Second instance . | ['Second', 'instance', '.'] | 3 |\n", | |||||
"| Third instance . | ['Third', 'instance', '.'] | 3 |\n", | |||||
"+------------------------------+---------------------------------------------+---------+\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP import DataSet\n", | |||||
"data = {'raw_words':[\"This is the first instance .\", \"Second instance .\", \"Third instance .\"],\n", | |||||
" 'words': [['this', 'is', 'the', 'first', 'instance', '.'], ['Second', 'instance', '.'], ['Third', 'instance', '.']],\n", | |||||
" 'seq_len': [6, 3, 3]}\n", | |||||
"dataset = DataSet(data)\n", | |||||
"# 传入的dict的每个key的value应该为具有相同长度的list\n", | |||||
"print(dataset)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## DataSet的构建" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 2, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"+----------------------------+---------------------------------------------+---------+\n", | |||||
"| raw_words | words | seq_len |\n", | |||||
"+----------------------------+---------------------------------------------+---------+\n", | |||||
"| This is the first instance | ['this', 'is', 'the', 'first', 'instance... | 6 |\n", | |||||
"+----------------------------+---------------------------------------------+---------+" | |||||
] | |||||
}, | |||||
"execution_count": 2, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP import DataSet\n", | |||||
"from fastNLP import Instance\n", | |||||
"dataset = DataSet()\n", | |||||
"instance = Instance(raw_words=\"This is the first instance\",\n", | |||||
" words=['this', 'is', 'the', 'first', 'instance', '.'],\n", | |||||
" seq_len=6)\n", | |||||
"dataset.append(instance)\n", | |||||
"dataset" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 3, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"+----------------------------+---------------------------------------------+---------+\n", | |||||
"| raw_words | words | seq_len |\n", | |||||
"+----------------------------+---------------------------------------------+---------+\n", | |||||
"| This is the first instance | ['this', 'is', 'the', 'first', 'instance... | 6 |\n", | |||||
"| Second instance . | ['Second', 'instance', '.'] | 3 |\n", | |||||
"+----------------------------+---------------------------------------------+---------+" | |||||
] | |||||
}, | |||||
"execution_count": 3, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP import DataSet\n", | |||||
"from fastNLP import Instance\n", | |||||
"dataset = DataSet([\n", | |||||
" Instance(raw_words=\"This is the first instance\",\n", | |||||
" words=['this', 'is', 'the', 'first', 'instance', '.'],\n", | |||||
" seq_len=6),\n", | |||||
" Instance(raw_words=\"Second instance .\",\n", | |||||
" words=['Second', 'instance', '.'],\n", | |||||
" seq_len=3)\n", | |||||
" ])\n", | |||||
"dataset" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## DataSet的删除" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 4, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"+----+---+\n", | |||||
"| a | c |\n", | |||||
"+----+---+\n", | |||||
"| -5 | 0 |\n", | |||||
"| -4 | 0 |\n", | |||||
"| -3 | 0 |\n", | |||||
"| -2 | 0 |\n", | |||||
"| -1 | 0 |\n", | |||||
"| 0 | 0 |\n", | |||||
"| 1 | 0 |\n", | |||||
"| 2 | 0 |\n", | |||||
"| 3 | 0 |\n", | |||||
"| 4 | 0 |\n", | |||||
"+----+---+" | |||||
] | |||||
}, | |||||
"execution_count": 4, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP import DataSet\n", | |||||
"dataset = DataSet({'a': range(-5, 5), 'c': [0]*10})\n", | |||||
"dataset" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 5, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"+---+\n", | |||||
"| c |\n", | |||||
"+---+\n", | |||||
"| 0 |\n", | |||||
"| 0 |\n", | |||||
"| 0 |\n", | |||||
"| 0 |\n", | |||||
"+---+" | |||||
] | |||||
}, | |||||
"execution_count": 5, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 不改变dataset,生成一个删除了满足条件的instance的新 DataSet\n", | |||||
"dropped_dataset = dataset.drop(lambda ins:ins['a']<0, inplace=False)\n", | |||||
"# 在dataset中删除满足条件的instance\n", | |||||
"dataset.drop(lambda ins:ins['a']<0)\n", | |||||
"# 删除第3个instance\n", | |||||
"dataset.delete_instance(2)\n", | |||||
"# 删除名为'a'的field\n", | |||||
"dataset.delete_field('a')\n", | |||||
"dataset" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## 简单的数据预处理" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 6, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"False\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"4" | |||||
] | |||||
}, | |||||
"execution_count": 6, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 检查是否存在名为'a'的field\n", | |||||
"print(dataset.has_field('a')) # 或 ('a' in dataset)\n", | |||||
"# 将名为'a'的field改名为'b'\n", | |||||
"dataset.rename_field('c', 'b')\n", | |||||
"# DataSet的长度\n", | |||||
"len(dataset)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 7, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"+------------------------------+-------------------------------------------------+\n", | |||||
"| raw_words | words |\n", | |||||
"+------------------------------+-------------------------------------------------+\n", | |||||
"| This is the first instance . | ['This', 'is', 'the', 'first', 'instance', '.'] |\n", | |||||
"| Second instance . | ['Second', 'instance', '.'] |\n", | |||||
"| Third instance . | ['Third', 'instance', '.'] |\n", | |||||
"+------------------------------+-------------------------------------------------+" | |||||
] | |||||
}, | |||||
"execution_count": 7, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP import DataSet\n", | |||||
"data = {'raw_words':[\"This is the first instance .\", \"Second instance .\", \"Third instance .\"]}\n", | |||||
"dataset = DataSet(data)\n", | |||||
"\n", | |||||
"# 将句子分成单词形式, 详见DataSet.apply()方法\n", | |||||
"dataset.apply(lambda ins: ins['raw_words'].split(), new_field_name='words')\n", | |||||
"\n", | |||||
"# 或使用DataSet.apply_field()\n", | |||||
"dataset.apply_field(lambda sent:sent.split(), field_name='raw_words', new_field_name='words')\n", | |||||
"\n", | |||||
"# 除了匿名函数,也可以定义函数传递进去\n", | |||||
"def get_words(instance):\n", | |||||
" sentence = instance['raw_words']\n", | |||||
" words = sentence.split()\n", | |||||
" return words\n", | |||||
"dataset.apply(get_words, new_field_name='words')\n", | |||||
"dataset" | |||||
] | |||||
} | |||||
], | |||||
"metadata": { | |||||
"kernelspec": { | |||||
"display_name": "Python Now", | |||||
"language": "python", | |||||
"name": "now" | |||||
}, | |||||
"language_info": { | |||||
"codemirror_mode": { | |||||
"name": "ipython", | |||||
"version": 3 | |||||
}, | |||||
"file_extension": ".py", | |||||
"mimetype": "text/x-python", | |||||
"name": "python", | |||||
"nbconvert_exporter": "python", | |||||
"pygments_lexer": "ipython3", | |||||
"version": "3.8.0" | |||||
} | |||||
}, | |||||
"nbformat": 4, | |||||
"nbformat_minor": 2 | |||||
} |
@@ -0,0 +1,343 @@ | |||||
{ | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"# fastNLP中的 Vocabulary\n", | |||||
"## 构建 Vocabulary" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 1, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"from fastNLP import Vocabulary\n", | |||||
"\n", | |||||
"vocab = Vocabulary()\n", | |||||
"vocab.add_word_lst(['复', '旦', '大', '学']) # 加入新的字\n", | |||||
"vocab.add_word('上海') # `上海`会作为一个整体\n", | |||||
"vocab.to_index('复') # 应该会为3\n", | |||||
"vocab.to_index('我') # 会输出1,Vocabulary中默认pad的index为0, unk(没有找到的词)的index为1\n", | |||||
"\n", | |||||
"# 在构建target的Vocabulary时,词表中应该用不上pad和unk,可以通过以下的初始化\n", | |||||
"vocab = Vocabulary(unknown=None, padding=None)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 2, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"Vocabulary(['positive', 'negative']...)" | |||||
] | |||||
}, | |||||
"execution_count": 2, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"vocab.add_word_lst(['positive', 'negative'])" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 3, | |||||
"metadata": { | |||||
"scrolled": true | |||||
}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"0" | |||||
] | |||||
}, | |||||
"execution_count": 3, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"vocab.to_index('positive')" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### 没有设置 unk 的情况" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 4, | |||||
"metadata": { | |||||
"scrolled": true | |||||
}, | |||||
"outputs": [ | |||||
{ | |||||
"ename": "ValueError", | |||||
"evalue": "word `neutral` not in vocabulary", | |||||
"output_type": "error", | |||||
"traceback": [ | |||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |||||
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", | |||||
"\u001b[0;32m<ipython-input-4-c6d424040b45>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mvocab\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'neutral'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# 会报错,因为没有unk这种情况\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |||||
"\u001b[0;32m~/anaconda3/envs/now/lib/python3.8/site-packages/FastNLP-0.5.0-py3.8.egg/fastNLP/core/vocabulary.py\u001b[0m in \u001b[0;36mto_index\u001b[0;34m(self, w)\u001b[0m\n\u001b[1;32m 414\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;32mreturn\u001b[0m \u001b[0mint\u001b[0m \u001b[0mindex\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mnumber\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 415\u001b[0m \"\"\"\n\u001b[0;32m--> 416\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__getitem__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 417\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 418\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |||||
"\u001b[0;32m~/anaconda3/envs/now/lib/python3.8/site-packages/FastNLP-0.5.0-py3.8.egg/fastNLP/core/vocabulary.py\u001b[0m in \u001b[0;36m_wrapper\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_word2idx\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrebuild\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuild_vocab\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 44\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 45\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0m_wrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |||||
"\u001b[0;32m~/anaconda3/envs/now/lib/python3.8/site-packages/FastNLP-0.5.0-py3.8.egg/fastNLP/core/vocabulary.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, w)\u001b[0m\n\u001b[1;32m 272\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_word2idx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munknown\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 273\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 274\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"word `{}` not in vocabulary\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 275\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 276\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0m_check_build_vocab\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |||||
"\u001b[0;31mValueError\u001b[0m: word `neutral` not in vocabulary" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"vocab.to_index('neutral') # 会报错,因为没有unk这种情况" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### 设置 unk 的情况" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 25, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"(0, '<unk>')" | |||||
] | |||||
}, | |||||
"execution_count": 25, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP import Vocabulary\n", | |||||
"\n", | |||||
"vocab = Vocabulary(unknown='<unk>', padding=None)\n", | |||||
"vocab.add_word_lst(['positive', 'negative'])\n", | |||||
"vocab.to_index('neutral'), vocab.to_word(vocab.to_index('neutral'))" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 8, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"Vocabulary(['positive', 'negative']...)" | |||||
] | |||||
}, | |||||
"execution_count": 8, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"vocab" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 7, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"+---------------------------------------------------+--------+\n", | |||||
"| chars | target |\n", | |||||
"+---------------------------------------------------+--------+\n", | |||||
"| [4, 2, 2, 5, 6, 7, 3] | 0 |\n", | |||||
"| [8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 3] | 1 |\n", | |||||
"+---------------------------------------------------+--------+\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP import Vocabulary\n", | |||||
"from fastNLP import DataSet\n", | |||||
"\n", | |||||
"dataset = DataSet({'chars': [\n", | |||||
" ['今', '天', '天', '气', '很', '好', '。'],\n", | |||||
" ['被', '这', '部', '电', '影', '浪', '费', '了', '两', '个', '小', '时', '。']\n", | |||||
" ],\n", | |||||
" 'target': ['neutral', 'negative']\n", | |||||
"})\n", | |||||
"\n", | |||||
"vocab = Vocabulary()\n", | |||||
"vocab.from_dataset(dataset, field_name='chars')\n", | |||||
"vocab.index_dataset(dataset, field_name='chars')\n", | |||||
"\n", | |||||
"target_vocab = Vocabulary(padding=None, unknown=None)\n", | |||||
"target_vocab.from_dataset(dataset, field_name='target')\n", | |||||
"target_vocab.index_dataset(dataset, field_name='target')\n", | |||||
"print(dataset)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 8, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"Vocabulary(['今', '天', '心', '情', '很']...)" | |||||
] | |||||
}, | |||||
"execution_count": 8, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP import Vocabulary\n", | |||||
"from fastNLP import DataSet\n", | |||||
"\n", | |||||
"tr_data = DataSet({'chars': [\n", | |||||
" ['今', '天', '心', '情', '很', '好', '。'],\n", | |||||
" ['被', '这', '部', '电', '影', '浪', '费', '了', '两', '个', '小', '时', '。']\n", | |||||
" ],\n", | |||||
" 'target': ['positive', 'negative']\n", | |||||
"})\n", | |||||
"dev_data = DataSet({'chars': [\n", | |||||
" ['住', '宿', '条', '件', '还', '不', '错'],\n", | |||||
" ['糟', '糕', '的', '天', '气', ',', '无', '法', '出', '行', '。']\n", | |||||
" ],\n", | |||||
" 'target': ['positive', 'negative']\n", | |||||
"})\n", | |||||
"\n", | |||||
"vocab = Vocabulary()\n", | |||||
"# 将验证集或者测试集在建立词表是放入no_create_entry_dataset这个参数中。\n", | |||||
"vocab.from_dataset(tr_data, field_name='chars', no_create_entry_dataset=[dev_data])\n" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 9, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stderr", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
" 4%|▎ | 2.31M/63.5M [00:00<00:02, 22.9MB/s]" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"http://212.129.155.247/embedding/glove.6B.50d.zip not found in cache, downloading to /tmp/tmpvziobj_e\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stderr", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"100%|██████████| 63.5M/63.5M [00:01<00:00, 41.3MB/s]\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Finish download from http://212.129.155.247/embedding/glove.6B.50d.zip\n", | |||||
"Copy file to /remote-home/ynzheng/.fastNLP/embedding/glove.6B.50d\n", | |||||
"Found 2 out of 6 words in the pre-training embedding.\n", | |||||
"tensor([[ 0.9497, 0.3433, 0.8450, -0.8852, -0.7208, -0.2931, -0.7468, 0.6512,\n", | |||||
" 0.4730, -0.7401, 0.1877, -0.3828, -0.5590, 0.4295, -0.2698, -0.4238,\n", | |||||
" -0.3124, 1.3423, -0.7857, -0.6302, 0.9182, 0.2113, -0.5744, 1.4549,\n", | |||||
" 0.7546, -1.6165, -0.0085, 0.0029, 0.5130, -0.4745, 2.5306, 0.8594,\n", | |||||
" -0.3067, 0.0578, 0.6623, 0.2080, 0.6424, -0.5246, -0.0534, 1.1404,\n", | |||||
" -0.1370, -0.1836, 0.4546, -0.5096, -0.0255, -0.0286, 0.1805, -0.4483,\n", | |||||
" 0.4053, -0.3682]], grad_fn=<EmbeddingBackward>)\n", | |||||
"tensor([[ 0.1320, -0.2392, 0.1732, -0.2390, -0.0463, 0.0494, 0.0488, -0.0886,\n", | |||||
" 0.0224, -0.1300, 0.0369, 0.1800, 0.0750, -0.0183, 0.2264, 0.1628,\n", | |||||
" 0.1261, -0.1259, 0.1663, -0.1230, -0.1904, -0.0532, 0.1397, -0.0259,\n", | |||||
" -0.1799, 0.0226, 0.1858, 0.1981, 0.1338, 0.2394, 0.0248, 0.0203,\n", | |||||
" -0.1722, -0.1683, -0.1892, 0.0874, 0.0562, -0.0394, 0.0306, -0.1761,\n", | |||||
" 0.1015, -0.0171, 0.1172, 0.1357, 0.1519, -0.0011, 0.1572, 0.1265,\n", | |||||
" -0.2391, -0.0258]], grad_fn=<EmbeddingBackward>)\n", | |||||
"tensor([[ 0.1318, -0.2552, -0.0679, 0.2619, -0.2616, 0.2357, 0.1308, -0.0118,\n", | |||||
" 1.7659, 0.2078, 0.2620, -0.1643, -0.8464, 0.0201, 0.0702, 0.3978,\n", | |||||
" 0.1528, -0.2021, -1.6184, -0.5433, -0.1786, 0.5389, 0.4987, -0.1017,\n", | |||||
" 0.6626, -1.7051, 0.0572, -0.3241, -0.6683, 0.2665, 2.8420, 0.2684,\n", | |||||
" -0.5954, -0.5004, 1.5199, 0.0396, 1.6659, 0.9976, -0.5597, -0.7049,\n", | |||||
" -0.0309, -0.2830, -0.1356, 0.6429, 0.4149, 1.2362, 0.7659, 0.9780,\n", | |||||
" 0.5851, -0.3018]], grad_fn=<EmbeddingBackward>)\n", | |||||
"tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||||
" 0., 0.]], grad_fn=<EmbeddingBackward>)\n", | |||||
"tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||||
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", | |||||
" 0., 0.]], grad_fn=<EmbeddingBackward>)\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"import torch\n", | |||||
"from fastNLP.embeddings import StaticEmbedding\n", | |||||
"from fastNLP import Vocabulary\n", | |||||
"\n", | |||||
"vocab = Vocabulary()\n", | |||||
"vocab.add_word('train')\n", | |||||
"vocab.add_word('only_in_train') # 仅在train出现,但肯定在预训练词表中不存在\n", | |||||
"vocab.add_word('test', no_create_entry=True) # 该词只在dev或test中出现\n", | |||||
"vocab.add_word('only_in_test', no_create_entry=True) # 这个词在预训练的词表中找不到\n", | |||||
"\n", | |||||
"embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d')\n", | |||||
"print(embed(torch.LongTensor([vocab.to_index('train')])))\n", | |||||
"print(embed(torch.LongTensor([vocab.to_index('only_in_train')])))\n", | |||||
"print(embed(torch.LongTensor([vocab.to_index('test')])))\n", | |||||
"print(embed(torch.LongTensor([vocab.to_index('only_in_test')])))\n", | |||||
"print(embed(torch.LongTensor([vocab.unknown_idx])))" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
} | |||||
], | |||||
"metadata": { | |||||
"kernelspec": { | |||||
"display_name": "Python Now", | |||||
"language": "python", | |||||
"name": "now" | |||||
}, | |||||
"language_info": { | |||||
"codemirror_mode": { | |||||
"name": "ipython", | |||||
"version": 3 | |||||
}, | |||||
"file_extension": ".py", | |||||
"mimetype": "text/x-python", | |||||
"name": "python", | |||||
"nbconvert_exporter": "python", | |||||
"pygments_lexer": "ipython3", | |||||
"version": "3.8.0" | |||||
} | |||||
}, | |||||
"nbformat": 4, | |||||
"nbformat_minor": 2 | |||||
} |
@@ -0,0 +1,524 @@ | |||||
{ | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 1, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Found 5 out of 7 words in the pre-training embedding.\n", | |||||
"torch.Size([1, 5, 50])\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"import torch\n", | |||||
"from fastNLP.embeddings import StaticEmbedding\n", | |||||
"from fastNLP import Vocabulary\n", | |||||
"\n", | |||||
"vocab = Vocabulary()\n", | |||||
"vocab.add_word_lst(\"this is a demo .\".split())\n", | |||||
"\n", | |||||
"embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d')\n", | |||||
"\n", | |||||
"words = torch.LongTensor([[vocab.to_index(word) for word in \"this is a demo .\".split()]]) # 将文本转为index\n", | |||||
"print(embed(words).size()) # StaticEmbedding的使用和pytorch的nn.Embedding是类似的" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 2, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"torch.Size([1, 5, 30])\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP.embeddings import StaticEmbedding\n", | |||||
"from fastNLP import Vocabulary\n", | |||||
"\n", | |||||
"vocab = Vocabulary()\n", | |||||
"vocab.add_word_lst(\"this is a demo .\".split())\n", | |||||
"\n", | |||||
"embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=30)\n", | |||||
"\n", | |||||
"words = torch.LongTensor([[vocab.to_index(word) for word in \"this is a demo .\".split()]])\n", | |||||
"print(embed(words).size())" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 3, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"22 out of 22 characters were found in pretrained elmo embedding.\n", | |||||
"torch.Size([1, 5, 256])\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP.embeddings import ElmoEmbedding\n", | |||||
"from fastNLP import Vocabulary\n", | |||||
"\n", | |||||
"vocab = Vocabulary()\n", | |||||
"vocab.add_word_lst(\"this is a demo .\".split())\n", | |||||
"\n", | |||||
"embed = ElmoEmbedding(vocab, model_dir_or_name='en-small', requires_grad=False)\n", | |||||
"words = torch.LongTensor([[vocab.to_index(word) for word in \"this is a demo .\".split()]])\n", | |||||
"print(embed(words).size())" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 4, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"22 out of 22 characters were found in pretrained elmo embedding.\n", | |||||
"torch.Size([1, 5, 512])\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"embed = ElmoEmbedding(vocab, model_dir_or_name='en-small', requires_grad=False, layers='1,2')\n", | |||||
"print(embed(words).size())" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 5, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"22 out of 22 characters were found in pretrained elmo embedding.\n", | |||||
"torch.Size([1, 5, 256])\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"embed = ElmoEmbedding(vocab, model_dir_or_name='en-small', requires_grad=True, layers='mix')\n", | |||||
"print(embed(words).size()) # 三层输出按照权重element-wise的加起来" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 6, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n", | |||||
"Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n", | |||||
"Start to generate word pieces for word.\n", | |||||
"Found(Or segment into word pieces) 7 words out of 7.\n", | |||||
"torch.Size([1, 5, 768])\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP.embeddings import BertEmbedding\n", | |||||
"from fastNLP import Vocabulary\n", | |||||
"\n", | |||||
"vocab = Vocabulary()\n", | |||||
"vocab.add_word_lst(\"this is a demo .\".split())\n", | |||||
"\n", | |||||
"embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased')\n", | |||||
"words = torch.LongTensor([[vocab.to_index(word) for word in \"this is a demo .\".split()]])\n", | |||||
"print(embed(words).size())" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 7, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n", | |||||
"Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n", | |||||
"Start to generate word pieces for word.\n", | |||||
"Found(Or segment into word pieces) 7 words out of 7.\n", | |||||
"torch.Size([1, 5, 1536])\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 使用后面两层的输出\n", | |||||
"embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', layers='10,11')\n", | |||||
"print(embed(words).size()) # 结果将是在最后一维做拼接" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 8, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n", | |||||
"Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n", | |||||
"Start to generate word pieces for word.\n", | |||||
"Found(Or segment into word pieces) 7 words out of 7.\n", | |||||
"torch.Size([1, 7, 768])\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', layers='-1', include_cls_sep=True)\n", | |||||
"print(embed(words).size()) # 结果将在序列维度上增加2\n", | |||||
"# 取出句子的cls表示\n", | |||||
"cls_reps = embed(words)[:, 0] # shape: [batch_size, 768]" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 9, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n", | |||||
"Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n", | |||||
"Start to generate word pieces for word.\n", | |||||
"Found(Or segment into word pieces) 7 words out of 7.\n", | |||||
"torch.Size([1, 5, 768])\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', layers='-1', pool_method='max')\n", | |||||
"print(embed(words).size())" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 10, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n", | |||||
"Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n", | |||||
"Start to generate word pieces for word.\n", | |||||
"Found(Or segment into word pieces) 10 words out of 10.\n", | |||||
"torch.Size([1, 9, 768])\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"vocab = Vocabulary()\n", | |||||
"vocab.add_word_lst(\"this is a demo . [SEP] another sentence .\".split())\n", | |||||
"\n", | |||||
"embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', layers='-1', pool_method='max')\n", | |||||
"words = torch.LongTensor([[vocab.to_index(word) for word in \"this is a demo . [SEP] another sentence .\".split()]])\n", | |||||
"print(embed(words).size())" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 11, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Start constructing character vocabulary.\n", | |||||
"In total, there are 8 distinct characters.\n", | |||||
"torch.Size([1, 5, 64])\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP.embeddings import CNNCharEmbedding\n", | |||||
"from fastNLP import Vocabulary\n", | |||||
"\n", | |||||
"vocab = Vocabulary()\n", | |||||
"vocab.add_word_lst(\"this is a demo .\".split())\n", | |||||
"\n", | |||||
"# character的embedding维度大小为50,返回的embedding结果维度大小为64。\n", | |||||
"embed = CNNCharEmbedding(vocab, embed_size=64, char_emb_size=50)\n", | |||||
"words = torch.LongTensor([[vocab.to_index(word) for word in \"this is a demo .\".split()]])\n", | |||||
"print(embed(words).size())" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 12, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Start constructing character vocabulary.\n", | |||||
"In total, there are 8 distinct characters.\n", | |||||
"torch.Size([1, 5, 64])\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP.embeddings import LSTMCharEmbedding\n", | |||||
"from fastNLP import Vocabulary\n", | |||||
"\n", | |||||
"vocab = Vocabulary()\n", | |||||
"vocab.add_word_lst(\"this is a demo .\".split())\n", | |||||
"\n", | |||||
"# character的embedding维度大小为50,返回的embedding结果维度大小为64。\n", | |||||
"embed = LSTMCharEmbedding(vocab, embed_size=64, char_emb_size=50)\n", | |||||
"words = torch.LongTensor([[vocab.to_index(word) for word in \"this is a demo .\".split()]])\n", | |||||
"print(embed(words).size())" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 13, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Found 5 out of 7 words in the pre-training embedding.\n", | |||||
"50\n", | |||||
"Start constructing character vocabulary.\n", | |||||
"In total, there are 8 distinct characters.\n", | |||||
"30\n", | |||||
"22 out of 22 characters were found in pretrained elmo embedding.\n", | |||||
"256\n", | |||||
"22 out of 22 characters were found in pretrained elmo embedding.\n", | |||||
"512\n", | |||||
"loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n", | |||||
"Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n", | |||||
"Start to generate word pieces for word.\n", | |||||
"Found(Or segment into word pieces) 7 words out of 7.\n", | |||||
"768\n", | |||||
"loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n", | |||||
"Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n", | |||||
"Start to generate word pieces for word.\n", | |||||
"Found(Or segment into word pieces) 7 words out of 7.\n", | |||||
"1536\n", | |||||
"80\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP.embeddings import *\n", | |||||
"\n", | |||||
"vocab = Vocabulary()\n", | |||||
"vocab.add_word_lst(\"this is a demo .\".split())\n", | |||||
"\n", | |||||
"static_embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d')\n", | |||||
"print(static_embed.embedding_dim) # 50\n", | |||||
"char_embed = CNNCharEmbedding(vocab, embed_size=30)\n", | |||||
"print(char_embed.embedding_dim) # 30\n", | |||||
"elmo_embed_1 = ElmoEmbedding(vocab, model_dir_or_name='en-small', layers='2')\n", | |||||
"print(elmo_embed_1.embedding_dim) # 256\n", | |||||
"elmo_embed_2 = ElmoEmbedding(vocab, model_dir_or_name='en-small', layers='1,2')\n", | |||||
"print(elmo_embed_2.embedding_dim) # 512\n", | |||||
"bert_embed_1 = BertEmbedding(vocab, layers='-1', model_dir_or_name='en-base-cased')\n", | |||||
"print(bert_embed_1.embedding_dim) # 768\n", | |||||
"bert_embed_2 = BertEmbedding(vocab, layers='2,-1', model_dir_or_name='en-base-cased')\n", | |||||
"print(bert_embed_2.embedding_dim) # 1536\n", | |||||
"stack_embed = StackEmbedding([static_embed, char_embed])\n", | |||||
"print(stack_embed.embedding_dim) # 80" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 14, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n", | |||||
"Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n", | |||||
"Start to generate word pieces for word.\n", | |||||
"Found(Or segment into word pieces) 7 words out of 7.\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP.embeddings import *\n", | |||||
"\n", | |||||
"vocab = Vocabulary()\n", | |||||
"vocab.add_word_lst(\"this is a demo .\".split())\n", | |||||
"\n", | |||||
"embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', requires_grad=True) # 初始化时设定为需要更新\n", | |||||
"embed.requires_grad = False # 修改BertEmbedding的权重为不更新" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 15, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"tensor([[ 0.3633, -0.2091, -0.0353, -0.3771, -0.5193]],\n", | |||||
" grad_fn=<EmbeddingBackward>)\n", | |||||
"tensor([[ 0.0926, -0.4812, -0.7744, 0.4836, -0.5475]],\n", | |||||
" grad_fn=<EmbeddingBackward>)\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP.embeddings import StaticEmbedding\n", | |||||
"from fastNLP import Vocabulary\n", | |||||
"\n", | |||||
"vocab = Vocabulary().add_word_lst(\"The the a A\".split())\n", | |||||
"# 下面用随机的StaticEmbedding演示,但与使用预训练词向量时效果是一致的\n", | |||||
"embed = StaticEmbedding(vocab, model_name_or_dir=None, embedding_dim=5)\n", | |||||
"print(embed(torch.LongTensor([vocab.to_index('The')])))\n", | |||||
"print(embed(torch.LongTensor([vocab.to_index('the')])))" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 16, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"All word in the vocab have been lowered. There are 6 words, 4 unique lowered words.\n", | |||||
"tensor([[ 0.4530, -0.1558, -0.1941, 0.3203, 0.0355]],\n", | |||||
" grad_fn=<EmbeddingBackward>)\n", | |||||
"tensor([[ 0.4530, -0.1558, -0.1941, 0.3203, 0.0355]],\n", | |||||
" grad_fn=<EmbeddingBackward>)\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP.embeddings import StaticEmbedding\n", | |||||
"from fastNLP import Vocabulary\n", | |||||
"\n", | |||||
"vocab = Vocabulary().add_word_lst(\"The the a A\".split())\n", | |||||
"# 下面用随机的StaticEmbedding演示,但与使用预训练时效果是一致的\n", | |||||
"embed = StaticEmbedding(vocab, model_name_or_dir=None, embedding_dim=5, lower=True)\n", | |||||
"print(embed(torch.LongTensor([vocab.to_index('The')])))\n", | |||||
"print(embed(torch.LongTensor([vocab.to_index('the')])))" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 17, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"1 out of 4 words have frequency less than 2.\n", | |||||
"tensor([[ 0.4724, -0.7277, -0.6350, -0.5258, -0.6063]],\n", | |||||
" grad_fn=<EmbeddingBackward>)\n", | |||||
"tensor([[ 0.7638, -0.0552, 0.1625, -0.2210, 0.4993]],\n", | |||||
" grad_fn=<EmbeddingBackward>)\n", | |||||
"tensor([[ 0.7638, -0.0552, 0.1625, -0.2210, 0.4993]],\n", | |||||
" grad_fn=<EmbeddingBackward>)\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP.embeddings import StaticEmbedding\n", | |||||
"from fastNLP import Vocabulary\n", | |||||
"\n", | |||||
"vocab = Vocabulary().add_word_lst(\"the the the a\".split())\n", | |||||
"# 下面用随机的StaticEmbedding演示,但与使用预训练时效果是一致的\n", | |||||
"embed = StaticEmbedding(vocab, model_name_or_dir=None, embedding_dim=5, min_freq=2)\n", | |||||
"print(embed(torch.LongTensor([vocab.to_index('the')])))\n", | |||||
"print(embed(torch.LongTensor([vocab.to_index('a')])))\n", | |||||
"print(embed(torch.LongTensor([vocab.unknown_idx])))" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 18, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"0 out of 5 words have frequency less than 2.\n", | |||||
"All word in the vocab have been lowered. There are 5 words, 4 unique lowered words.\n", | |||||
"tensor([[ 0.1943, 0.3739, 0.2769, -0.4746, -0.3181]],\n", | |||||
" grad_fn=<EmbeddingBackward>)\n", | |||||
"tensor([[ 0.5892, -0.6916, 0.7319, -0.3803, 0.4979]],\n", | |||||
" grad_fn=<EmbeddingBackward>)\n", | |||||
"tensor([[ 0.5892, -0.6916, 0.7319, -0.3803, 0.4979]],\n", | |||||
" grad_fn=<EmbeddingBackward>)\n", | |||||
"tensor([[-0.1348, -0.2172, -0.0071, 0.5704, -0.2607]],\n", | |||||
" grad_fn=<EmbeddingBackward>)\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP.embeddings import StaticEmbedding\n", | |||||
"from fastNLP import Vocabulary\n", | |||||
"\n", | |||||
"vocab = Vocabulary().add_word_lst(\"the the the a A\".split())\n", | |||||
"# 下面用随机的StaticEmbedding演示,但与使用预训练时效果是一致的\n", | |||||
"embed = StaticEmbedding(vocab, model_name_or_dir=None, embedding_dim=5, min_freq=2, lower=True)\n", | |||||
"print(embed(torch.LongTensor([vocab.to_index('the')])))\n", | |||||
"print(embed(torch.LongTensor([vocab.to_index('a')])))\n", | |||||
"print(embed(torch.LongTensor([vocab.to_index('A')])))\n", | |||||
"print(embed(torch.LongTensor([vocab.unknown_idx])))" | |||||
] | |||||
} | |||||
], | |||||
"metadata": { | |||||
"kernelspec": { | |||||
"display_name": "Python Now", | |||||
"language": "python", | |||||
"name": "now" | |||||
}, | |||||
"language_info": { | |||||
"codemirror_mode": { | |||||
"name": "ipython", | |||||
"version": 3 | |||||
}, | |||||
"file_extension": ".py", | |||||
"mimetype": "text/x-python", | |||||
"name": "python", | |||||
"nbconvert_exporter": "python", | |||||
"pygments_lexer": "ipython3", | |||||
"version": "3.8.0" | |||||
} | |||||
}, | |||||
"nbformat": 4, | |||||
"nbformat_minor": 2 | |||||
} |
@@ -0,0 +1,309 @@ | |||||
{ | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"# 使用Loader和Pipe加载并处理数据集\n", | |||||
"\n", | |||||
"这一部分是关于如何加载数据集的教程\n", | |||||
"\n", | |||||
"## Part I: 数据集容器DataBundle\n", | |||||
"\n", | |||||
"而由于对于同一个任务,训练集,验证集和测试集会共用同一个词表以及具有相同的目标值,所以在fastNLP中我们使用了 DataBundle 来承载同一个任务的多个数据集 DataSet 以及它们的词表 Vocabulary 。下面会有例子介绍 DataBundle 的相关使用。\n", | |||||
"\n", | |||||
"DataBundle 在fastNLP中主要在各个 Loader 和 Pipe 中被使用。 下面我们先介绍一下 Loader 和 Pipe 。\n", | |||||
"\n", | |||||
"## Part II: 加载的各种数据集的Loader\n", | |||||
"\n", | |||||
"在fastNLP中,所有的 Loader 都可以通过其文档判断其支持读取的数据格式,以及读取之后返回的 DataSet 的格式, 例如 ChnSentiCorpLoader \n", | |||||
"\n", | |||||
"- download() 函数:自动将该数据集下载到缓存地址,默认缓存地址为~/.fastNLP/datasets/。由于版权等原因,不是所有的Loader都实现了该方法。该方法会返回下载后文件所处的缓存地址。\n", | |||||
"\n", | |||||
"- _load() 函数:从一个数据文件中读取数据,返回一个 DataSet 。返回的DataSet的格式可从Loader文档判断。\n", | |||||
"\n", | |||||
"- load() 函数:从文件或者文件夹中读取数据为 DataSet 并将它们组装成 DataBundle。支持接受的参数类型有以下的几种\n", | |||||
"\n", | |||||
" - None, 将尝试读取自动缓存的数据,仅支持提供了自动下载数据的Loader\n", | |||||
" - 文件夹路径, 默认将尝试在该文件夹下匹配文件名中含有 train , test , dev 的文件,如果有多个文件含有相同的关键字,将无法通过该方式读取\n", | |||||
" - dict, 例如{'train':\"/path/to/tr.conll\", 'dev':\"/to/validate.conll\", \"test\":\"/to/te.conll\"}。" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 1, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"In total 3 datasets:\n", | |||||
"\ttest has 1944 instances.\n", | |||||
"\ttrain has 17196 instances.\n", | |||||
"\tdev has 1858 instances.\n", | |||||
"\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP.io import CWSLoader\n", | |||||
"\n", | |||||
"loader = CWSLoader(dataset_name='pku')\n", | |||||
"data_bundle = loader.load()\n", | |||||
"print(data_bundle)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"这里表示一共有3个数据集。其中:\n", | |||||
"\n", | |||||
" 3个数据集的名称分别为train、dev、test,分别有17223、1831、1944个instance\n", | |||||
"\n", | |||||
"也可以取出DataSet,并打印DataSet中的具体内容" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 2, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"+----------------------------------------------------------------+\n", | |||||
"| raw_words |\n", | |||||
"+----------------------------------------------------------------+\n", | |||||
"| 迈向 充满 希望 的 新 世纪 —— 一九九八年 新年 讲话 ... |\n", | |||||
"| 中共中央 总书记 、 国家 主席 江 泽民 |\n", | |||||
"+----------------------------------------------------------------+\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"tr_data = data_bundle.get_dataset('train')\n", | |||||
"print(tr_data[:2])" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## Part III: 使用Pipe对数据集进行预处理\n", | |||||
"\n", | |||||
"通过 Loader 可以将文本数据读入,但并不能直接被神经网络使用,还需要进行一定的预处理。\n", | |||||
"\n", | |||||
"在fastNLP中,我们使用 Pipe 的子类作为数据预处理的类, Loader 和 Pipe 一般具备一一对应的关系,该关系可以从其名称判断, 例如 CWSLoader 与 CWSPipe 是一一对应的。一般情况下Pipe处理包含以下的几个过程,\n", | |||||
"1. 将raw_words或 raw_chars进行tokenize以切分成不同的词或字; \n", | |||||
"2. 再建立词或字的 Vocabulary , 并将词或字转换为index; \n", | |||||
"3. 将target 列建立词表并将target列转为index;\n", | |||||
"\n", | |||||
"所有的Pipe都可通过其文档查看该Pipe支持处理的 DataSet 以及返回的 DataBundle 中的Vocabulary的情况; 如 OntoNotesNERPipe\n", | |||||
"\n", | |||||
"各种数据集的Pipe当中,都包含了以下的两个函数:\n", | |||||
"\n", | |||||
"- process() 函数:对输入的 DataBundle 进行处理, 然后返回处理之后的 DataBundle 。process函数的文档中包含了该Pipe支持处理的DataSet的格式。\n", | |||||
"- process_from_file() 函数:输入数据集所在文件夹,使用对应的Loader读取数据(所以该函数支持的参数类型是由于其对应的Loader的load函数决定的),然后调用相对应的process函数对数据进行预处理。相当于是把Load和process放在一个函数中执行。\n", | |||||
"\n", | |||||
"接着上面 CWSLoader 的例子,我们展示一下 CWSPipe 的功能:" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 3, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"In total 3 datasets:\n", | |||||
"\ttest has 1944 instances.\n", | |||||
"\ttrain has 17196 instances.\n", | |||||
"\tdev has 1858 instances.\n", | |||||
"In total 2 vocabs:\n", | |||||
"\tchars has 4777 entries.\n", | |||||
"\ttarget has 4 entries.\n", | |||||
"\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP.io import CWSPipe\n", | |||||
"\n", | |||||
"data_bundle = CWSPipe().process(data_bundle)\n", | |||||
"print(data_bundle)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"表示一共有3个数据集和2个词表。其中:\n", | |||||
"\n", | |||||
"- 3个数据集的名称分别为train、dev、test,分别有17223、1831、1944个instance\n", | |||||
"- 2个词表分别为chars词表与target词表。其中chars词表为句子文本所构建的词表,一共有4777个不同的字;target词表为目标标签所构建的词表,一共有4种标签。\n", | |||||
"\n", | |||||
"相较于之前CWSLoader读取的DataBundle,新增了两个Vocabulary。 我们可以打印一下处理之后的DataSet" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 4, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"+---------------------+---------------------+---------------------+---------+\n", | |||||
"| raw_words | chars | target | seq_len |\n", | |||||
"+---------------------+---------------------+---------------------+---------+\n", | |||||
"| 迈向 充满 希望... | [1224, 178, 674,... | [0, 1, 0, 1, 0, ... | 29 |\n", | |||||
"| 中共中央 总书记... | [11, 212, 11, 33... | [0, 3, 3, 1, 0, ... | 15 |\n", | |||||
"+---------------------+---------------------+---------------------+---------+\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"tr_data = data_bundle.get_dataset('train')\n", | |||||
"print(tr_data[:2])" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"可以看到有两列为int的field: chars和target。这两列的名称同时也是DataBundle中的Vocabulary的名称。可以通过下列的代码获取并查看Vocabulary的 信息" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 5, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Vocabulary(['B', 'E', 'S', 'M']...)\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"vocab = data_bundle.get_vocab('target')\n", | |||||
"print(vocab)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## Part IV: fastNLP封装好的Loader和Pipe\n", | |||||
"\n", | |||||
"fastNLP封装了多种任务/数据集的 Loader 和 Pipe 并提供自动下载功能,具体参见文档 [数据集](https://docs.qq.com/sheet/DVnpkTnF6VW9UeXdh?c=A1A0A0)\n", | |||||
"\n", | |||||
"## Part V: 不同格式类型的基础Loader\n", | |||||
"\n", | |||||
"除了上面提到的针对具体任务的Loader,我们还提供了CSV格式和JSON格式的Loader\n", | |||||
"\n", | |||||
"**CSVLoader** 读取CSV类型的数据集文件。例子如下:\n", | |||||
"\n", | |||||
"```python\n", | |||||
"from fastNLP.io.loader import CSVLoader\n", | |||||
"data_set_loader = CSVLoader(\n", | |||||
" headers=('raw_words', 'target'), sep='\\t'\n", | |||||
")\n", | |||||
"```\n", | |||||
"\n", | |||||
"表示将CSV文件中每一行的第一项将填入'raw_words' field,第二项填入'target' field。其中项之间由'\\t'分割开来\n", | |||||
"\n", | |||||
"```python\n", | |||||
"data_set = data_set_loader._load('path/to/your/file')\n", | |||||
"```\n", | |||||
"\n", | |||||
"文件内容样例如下\n", | |||||
"\n", | |||||
"```csv\n", | |||||
"But it does not leave you with much . 1\n", | |||||
"You could hate it for the same reason . 1\n", | |||||
"The performances are an absolute joy . 4\n", | |||||
"```\n", | |||||
"\n", | |||||
"读取之后的DataSet具有以下的field\n", | |||||
"\n", | |||||
"| raw_words | target |\n", | |||||
"| --------------------------------------- | ------ |\n", | |||||
"| But it does not leave you with much . | 1 |\n", | |||||
"| You could hate it for the same reason . | 1 |\n", | |||||
"| The performances are an absolute joy . | 4 |\n" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"**JsonLoader** 读取Json类型的数据集文件,数据必须按行存储,每行是一个包含各类属性的Json对象。例子如下\n", | |||||
"\n", | |||||
"```python\n", | |||||
"from fastNLP.io.loader import JsonLoader\n", | |||||
"loader = JsonLoader(\n", | |||||
" fields={'sentence1': 'raw_words1', 'sentence2': 'raw_words2', 'gold_label': 'target'}\n", | |||||
")\n", | |||||
"```\n", | |||||
"\n", | |||||
"表示将Json对象中'sentence1'、'sentence2'和'gold_label'对应的值赋给'raw_words1'、'raw_words2'、'target'这三个fields\n", | |||||
"\n", | |||||
"```python\n", | |||||
"data_set = loader._load('path/to/your/file')\n", | |||||
"```\n", | |||||
"\n", | |||||
"数据集内容样例如下\n", | |||||
"```\n", | |||||
"{\"annotator_labels\": [\"neutral\"], \"captionID\": \"3416050480.jpg#4\", \"gold_label\": \"neutral\", ... }\n", | |||||
"{\"annotator_labels\": [\"contradiction\"], \"captionID\": \"3416050480.jpg#4\", \"gold_label\": \"contradiction\", ... }\n", | |||||
"{\"annotator_labels\": [\"entailment\"], \"captionID\": \"3416050480.jpg#4\", \"gold_label\": \"entailment\", ... }\n", | |||||
"```\n", | |||||
"\n", | |||||
"读取之后的DataSet具有以下的field\n", | |||||
"\n", | |||||
"| raw_words0 | raw_words1 | target |\n", | |||||
"| ------------------------------------------------------ | ------------------------------------------------- | ------------- |\n", | |||||
"| A person on a horse jumps over a broken down airplane. | A person is training his horse for a competition. | neutral |\n", | |||||
"| A person on a horse jumps over a broken down airplane. | A person is at a diner, ordering an omelette. | contradiction |\n", | |||||
"| A person on a horse jumps over a broken down airplane. | A person is outdoors, on a horse. | entailment |" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
} | |||||
], | |||||
"metadata": { | |||||
"kernelspec": { | |||||
"display_name": "Python Now", | |||||
"language": "python", | |||||
"name": "now" | |||||
}, | |||||
"language_info": { | |||||
"codemirror_mode": { | |||||
"name": "ipython", | |||||
"version": 3 | |||||
}, | |||||
"file_extension": ".py", | |||||
"mimetype": "text/x-python", | |||||
"name": "python", | |||||
"nbconvert_exporter": "python", | |||||
"pygments_lexer": "ipython3", | |||||
"version": "3.8.0" | |||||
} | |||||
}, | |||||
"nbformat": 4, | |||||
"nbformat_minor": 2 | |||||
} |
@@ -0,0 +1,603 @@ | |||||
{ | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"# 使用Trainer和Tester快速训练和测试" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## 数据读入和处理" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 1, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stderr", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"/remote-home/ynzheng/anaconda3/envs/now/lib/python3.8/site-packages/FastNLP-0.5.0-py3.8.egg/fastNLP/io/loader/classification.py:340: UserWarning: SST2's test file has no target.\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"In total 3 datasets:\n", | |||||
"\ttest has 1821 instances.\n", | |||||
"\ttrain has 67349 instances.\n", | |||||
"\tdev has 872 instances.\n", | |||||
"In total 2 vocabs:\n", | |||||
"\twords has 16292 entries.\n", | |||||
"\ttarget has 2 entries.\n", | |||||
"\n", | |||||
"+-----------------------------------+--------+-----------------------------------+---------+\n", | |||||
"| raw_words | target | words | seq_len |\n", | |||||
"+-----------------------------------+--------+-----------------------------------+---------+\n", | |||||
"| hide new secretions from the p... | 1 | [4110, 97, 12009, 39, 2, 6843,... | 7 |\n", | |||||
"+-----------------------------------+--------+-----------------------------------+---------+\n", | |||||
"Vocabulary(['hide', 'new', 'secretions', 'from', 'the']...)\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP.io import SST2Pipe\n", | |||||
"\n", | |||||
"pipe = SST2Pipe()\n", | |||||
"databundle = pipe.process_from_file()\n", | |||||
"vocab = databundle.get_vocab('words')\n", | |||||
"print(databundle)\n", | |||||
"print(databundle.get_dataset('train')[0])\n", | |||||
"print(databundle.get_vocab('words'))" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 2, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"4925 872 75\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"train_data = databundle.get_dataset('train')[:5000]\n", | |||||
"train_data, test_data = train_data.split(0.015)\n", | |||||
"dev_data = databundle.get_dataset('dev')\n", | |||||
"print(len(train_data),len(dev_data),len(test_data))" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 3, | |||||
"metadata": { | |||||
"scrolled": false | |||||
}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"+-------------+-----------+--------+-------+---------+\n", | |||||
"| field_names | raw_words | target | words | seq_len |\n", | |||||
"+-------------+-----------+--------+-------+---------+\n", | |||||
"| is_input | False | False | True | True |\n", | |||||
"| is_target | False | True | False | False |\n", | |||||
"| ignore_type | | False | False | False |\n", | |||||
"| pad_value | | 0 | 0 | 0 |\n", | |||||
"+-------------+-----------+--------+-------+---------+\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"<prettytable.PrettyTable at 0x7f49ec540160>" | |||||
] | |||||
}, | |||||
"execution_count": 3, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"train_data.print_field_meta()" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## 使用内置模型训练" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 4, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"from fastNLP.models import CNNText\n", | |||||
"\n", | |||||
"#词嵌入的维度\n", | |||||
"EMBED_DIM = 100\n", | |||||
"\n", | |||||
"#使用CNNText的时候第一个参数输入一个tuple,作为模型定义embedding的参数\n", | |||||
"#还可以传入 kernel_nums, kernel_sizes, padding, dropout的自定义值\n", | |||||
"model_cnn = CNNText((len(vocab),EMBED_DIM), num_classes=2, dropout=0.1)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 5, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"from fastNLP import AccuracyMetric\n", | |||||
"from fastNLP import Const\n", | |||||
"\n", | |||||
"# metrics=AccuracyMetric() 在本例中与下面这行代码等价\n", | |||||
"metrics=AccuracyMetric(pred=Const.OUTPUT, target=Const.TARGET)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 6, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"from fastNLP import CrossEntropyLoss\n", | |||||
"\n", | |||||
"# loss = CrossEntropyLoss() 在本例中与下面这行代码等价\n", | |||||
"loss = CrossEntropyLoss(pred=Const.OUTPUT, target=Const.TARGET)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 7, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"# 这表示构建了一个损失函数类,由func计算损失函数,其中将从模型返回值或者DataSet的target=True的field\n", | |||||
"# 当中找到一个参数名为`pred`的参数传入func一个参数名为`input`的参数;找到一个参数名为`label`的参数\n", | |||||
"# 传入func作为一个名为`target`的参数\n", | |||||
"#下面自己构建了一个交叉熵函数,和之后直接使用fastNLP中的交叉熵函数是一个效果\n", | |||||
"import torch\n", | |||||
"from fastNLP import LossFunc\n", | |||||
"func = torch.nn.functional.cross_entropy\n", | |||||
"loss_func = LossFunc(func, input=Const.OUTPUT, target=Const.TARGET)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 8, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"import torch.optim as optim\n", | |||||
"\n", | |||||
"#使用 torch.optim 定义优化器\n", | |||||
"optimizer=optim.RMSprop(model_cnn.parameters(), lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 9, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"input fields after batch(if batch size is 2):\n", | |||||
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 4]) \n", | |||||
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", | |||||
"target fields after batch(if batch size is 2):\n", | |||||
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", | |||||
"\n", | |||||
"training epochs started 2020-02-27-11-31-25\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=3080.0), HTML(value='')), layout=Layout(d…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.75 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 1/10. Step:308/3080: \n", | |||||
"\r", | |||||
"AccuracyMetric: acc=0.751147\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.83 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 2/10. Step:616/3080: \n", | |||||
"\r", | |||||
"AccuracyMetric: acc=0.755734\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 1.32 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 3/10. Step:924/3080: \n", | |||||
"\r", | |||||
"AccuracyMetric: acc=0.758028\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.88 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 4/10. Step:1232/3080: \n", | |||||
"\r", | |||||
"AccuracyMetric: acc=0.741972\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.96 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 5/10. Step:1540/3080: \n", | |||||
"\r", | |||||
"AccuracyMetric: acc=0.728211\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.87 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 6/10. Step:1848/3080: \n", | |||||
"\r", | |||||
"AccuracyMetric: acc=0.755734\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 1.04 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 7/10. Step:2156/3080: \n", | |||||
"\r", | |||||
"AccuracyMetric: acc=0.732798\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.57 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 8/10. Step:2464/3080: \n", | |||||
"\r", | |||||
"AccuracyMetric: acc=0.747706\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.48 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 9/10. Step:2772/3080: \n", | |||||
"\r", | |||||
"AccuracyMetric: acc=0.732798\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.48 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 10/10. Step:3080/3080: \n", | |||||
"\r", | |||||
"AccuracyMetric: acc=0.740826\n", | |||||
"\n", | |||||
"\r\n", | |||||
"In Epoch:3/Step:924, got best dev performance:\n", | |||||
"AccuracyMetric: acc=0.758028\n", | |||||
"Reloaded the best model.\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"{'best_eval': {'AccuracyMetric': {'acc': 0.758028}},\n", | |||||
" 'best_epoch': 3,\n", | |||||
" 'best_step': 924,\n", | |||||
" 'seconds': 160.58}" | |||||
] | |||||
}, | |||||
"execution_count": 9, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP import Trainer\n", | |||||
"\n", | |||||
"#训练的轮数和batch size\n", | |||||
"N_EPOCHS = 10\n", | |||||
"BATCH_SIZE = 16\n", | |||||
"\n", | |||||
"#如果在定义trainer的时候没有传入optimizer参数,模型默认的优化器为torch.optim.Adam且learning rate为lr=4e-3\n", | |||||
"#这里只使用了loss作为损失函数输入,感兴趣可以尝试其他损失函数(如之前自定义的loss_func)作为输入\n", | |||||
"trainer = Trainer(model=model_cnn, train_data=train_data, dev_data=dev_data, loss=loss, metrics=metrics,\n", | |||||
"optimizer=optimizer,n_epochs=N_EPOCHS, batch_size=BATCH_SIZE)\n", | |||||
"trainer.train()" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 10, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5.0), HTML(value='')), layout=Layout(disp…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.43 seconds!\n", | |||||
"[tester] \n", | |||||
"AccuracyMetric: acc=0.773333\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"{'AccuracyMetric': {'acc': 0.773333}}" | |||||
] | |||||
}, | |||||
"execution_count": 10, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP import Tester\n", | |||||
"\n", | |||||
"tester = Tester(test_data, model_cnn, metrics=AccuracyMetric())\n", | |||||
"tester.test()" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
} | |||||
], | |||||
"metadata": { | |||||
"kernelspec": { | |||||
"display_name": "Python Now", | |||||
"language": "python", | |||||
"name": "now" | |||||
}, | |||||
"language_info": { | |||||
"codemirror_mode": { | |||||
"name": "ipython", | |||||
"version": 3 | |||||
}, | |||||
"file_extension": ".py", | |||||
"mimetype": "text/x-python", | |||||
"name": "python", | |||||
"nbconvert_exporter": "python", | |||||
"pygments_lexer": "ipython3", | |||||
"version": "3.8.0" | |||||
} | |||||
}, | |||||
"nbformat": 4, | |||||
"nbformat_minor": 2 | |||||
} |
@@ -0,0 +1,681 @@ | |||||
{ | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"# 使用Trainer和Tester快速训练和测试" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## 数据读入和处理" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 1, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stderr", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"/remote-home/ynzheng/anaconda3/envs/now/lib/python3.8/site-packages/FastNLP-0.5.0-py3.8.egg/fastNLP/io/loader/classification.py:340: UserWarning: SST2's test file has no target.\n" | |||||
] | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"In total 3 datasets:\n", | |||||
"\ttest has 1821 instances.\n", | |||||
"\ttrain has 67349 instances.\n", | |||||
"\tdev has 872 instances.\n", | |||||
"In total 2 vocabs:\n", | |||||
"\twords has 16292 entries.\n", | |||||
"\ttarget has 2 entries.\n", | |||||
"\n", | |||||
"+-----------------------------------+--------+-----------------------------------+---------+\n", | |||||
"| raw_words | target | words | seq_len |\n", | |||||
"+-----------------------------------+--------+-----------------------------------+---------+\n", | |||||
"| hide new secretions from the p... | 1 | [4110, 97, 12009, 39, 2, 6843,... | 7 |\n", | |||||
"+-----------------------------------+--------+-----------------------------------+---------+\n", | |||||
"Vocabulary(['hide', 'new', 'secretions', 'from', 'the']...)\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP.io import SST2Pipe\n", | |||||
"\n", | |||||
"pipe = SST2Pipe()\n", | |||||
"databundle = pipe.process_from_file()\n", | |||||
"vocab = databundle.get_vocab('words')\n", | |||||
"print(databundle)\n", | |||||
"print(databundle.get_dataset('train')[0])\n", | |||||
"print(databundle.get_vocab('words'))" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 2, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"4925 872 75\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"train_data = databundle.get_dataset('train')[:5000]\n", | |||||
"train_data, test_data = train_data.split(0.015)\n", | |||||
"dev_data = databundle.get_dataset('dev')\n", | |||||
"print(len(train_data),len(dev_data),len(test_data))" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 3, | |||||
"metadata": { | |||||
"scrolled": false | |||||
}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"+-------------+-----------+--------+-------+---------+\n", | |||||
"| field_names | raw_words | target | words | seq_len |\n", | |||||
"+-------------+-----------+--------+-------+---------+\n", | |||||
"| is_input | False | False | True | True |\n", | |||||
"| is_target | False | True | False | False |\n", | |||||
"| ignore_type | | False | False | False |\n", | |||||
"| pad_value | | 0 | 0 | 0 |\n", | |||||
"+-------------+-----------+--------+-------+---------+\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"<prettytable.PrettyTable at 0x7f0db03d0640>" | |||||
] | |||||
}, | |||||
"execution_count": 3, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"train_data.print_field_meta()" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 4, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"from fastNLP import AccuracyMetric\n", | |||||
"from fastNLP import Const\n", | |||||
"\n", | |||||
"# metrics=AccuracyMetric() 在本例中与下面这行代码等价\n", | |||||
"metrics=AccuracyMetric(pred=Const.OUTPUT, target=Const.TARGET)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## DataSetIter初探" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 5, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"batch_x: {'words': tensor([[ 13, 830, 7746, 174, 3, 47, 6, 83, 5752, 15,\n", | |||||
" 2177, 15, 63, 57, 406, 84, 1009, 4973, 27, 17,\n", | |||||
" 13785, 3, 533, 3687, 15623, 39, 375, 8, 15624, 8,\n", | |||||
" 1323, 4398, 7],\n", | |||||
" [ 1045, 11113, 16, 104, 5, 4, 176, 1824, 1704, 3,\n", | |||||
" 2, 18, 11, 4, 1018, 432, 143, 33, 245, 308,\n", | |||||
" 7, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||||
" 0, 0, 0]]), 'seq_len': tensor([33, 21])}\n", | |||||
"batch_y: {'target': tensor([1, 0])}\n", | |||||
"batch_x: {'words': tensor([[ 14, 10, 4, 311, 5, 154, 1418, 609, 7],\n", | |||||
" [ 14, 10, 437, 32, 78, 3, 78, 437, 7]]), 'seq_len': tensor([9, 9])}\n", | |||||
"batch_y: {'target': tensor([0, 1])}\n", | |||||
"batch_x: {'words': tensor([[ 4, 277, 685, 18, 7],\n", | |||||
" [15618, 3204, 5, 1675, 0]]), 'seq_len': tensor([5, 4])}\n", | |||||
"batch_y: {'target': tensor([1, 1])}\n", | |||||
"batch_x: {'words': tensor([[ 2, 155, 3, 4426, 3, 239, 3, 739, 5, 1136,\n", | |||||
" 41, 43, 2427, 736, 2, 648, 10, 15620, 2285, 7],\n", | |||||
" [ 24, 95, 28, 46, 8, 336, 38, 239, 8, 2133,\n", | |||||
" 2, 18, 10, 15622, 1421, 6, 61, 5, 387, 7]]), 'seq_len': tensor([20, 20])}\n", | |||||
"batch_y: {'target': tensor([0, 0])}\n", | |||||
"batch_x: {'words': tensor([[ 879, 96, 8, 1026, 12, 8067, 11, 13623, 8, 15619,\n", | |||||
" 4, 673, 662, 15, 4, 1154, 240, 639, 417, 7],\n", | |||||
" [ 45, 752, 327, 180, 10, 15621, 16, 72, 8904, 9,\n", | |||||
" 1217, 7, 0, 0, 0, 0, 0, 0, 0, 0]]), 'seq_len': tensor([20, 12])}\n", | |||||
"batch_y: {'target': tensor([0, 1])}\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP import BucketSampler\n", | |||||
"from fastNLP import DataSetIter\n", | |||||
"\n", | |||||
"tmp_data = dev_data[:10]\n", | |||||
"# 定义一个Batch,传入DataSet,规定batch_size和去batch的规则。\n", | |||||
"# 顺序(Sequential),随机(Random),相似长度组成一个batch(Bucket)\n", | |||||
"sampler = BucketSampler(batch_size=2, seq_len_field_name='seq_len')\n", | |||||
"batch = DataSetIter(batch_size=2, dataset=tmp_data, sampler=sampler)\n", | |||||
"for batch_x, batch_y in batch:\n", | |||||
" print(\"batch_x: \",batch_x)\n", | |||||
" print(\"batch_y: \", batch_y)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 6, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"batch_x: {'words': tensor([[ 13, 830, 7746, 174, 3, 47, 6, 83, 5752, 15,\n", | |||||
" 2177, 15, 63, 57, 406, 84, 1009, 4973, 27, 17,\n", | |||||
" 13785, 3, 533, 3687, 15623, 39, 375, 8, 15624, 8,\n", | |||||
" 1323, 4398, 7],\n", | |||||
" [ 1045, 11113, 16, 104, 5, 4, 176, 1824, 1704, 3,\n", | |||||
" 2, 18, 11, 4, 1018, 432, 143, 33, 245, 308,\n", | |||||
" 7, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n", | |||||
" -1, -1, -1]]), 'seq_len': tensor([33, 21])}\n", | |||||
"batch_y: {'target': tensor([1, 0])}\n", | |||||
"batch_x: {'words': tensor([[ 14, 10, 4, 311, 5, 154, 1418, 609, 7],\n", | |||||
" [ 14, 10, 437, 32, 78, 3, 78, 437, 7]]), 'seq_len': tensor([9, 9])}\n", | |||||
"batch_y: {'target': tensor([0, 1])}\n", | |||||
"batch_x: {'words': tensor([[ 2, 155, 3, 4426, 3, 239, 3, 739, 5, 1136,\n", | |||||
" 41, 43, 2427, 736, 2, 648, 10, 15620, 2285, 7],\n", | |||||
" [ 24, 95, 28, 46, 8, 336, 38, 239, 8, 2133,\n", | |||||
" 2, 18, 10, 15622, 1421, 6, 61, 5, 387, 7]]), 'seq_len': tensor([20, 20])}\n", | |||||
"batch_y: {'target': tensor([0, 0])}\n", | |||||
"batch_x: {'words': tensor([[ 4, 277, 685, 18, 7],\n", | |||||
" [15618, 3204, 5, 1675, -1]]), 'seq_len': tensor([5, 4])}\n", | |||||
"batch_y: {'target': tensor([1, 1])}\n", | |||||
"batch_x: {'words': tensor([[ 879, 96, 8, 1026, 12, 8067, 11, 13623, 8, 15619,\n", | |||||
" 4, 673, 662, 15, 4, 1154, 240, 639, 417, 7],\n", | |||||
" [ 45, 752, 327, 180, 10, 15621, 16, 72, 8904, 9,\n", | |||||
" 1217, 7, -1, -1, -1, -1, -1, -1, -1, -1]]), 'seq_len': tensor([20, 12])}\n", | |||||
"batch_y: {'target': tensor([0, 1])}\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"tmp_data.set_pad_val('words',-1)\n", | |||||
"batch = DataSetIter(batch_size=2, dataset=tmp_data, sampler=sampler)\n", | |||||
"for batch_x, batch_y in batch:\n", | |||||
" print(\"batch_x: \",batch_x)\n", | |||||
" print(\"batch_y: \", batch_y)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 7, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"batch_x: {'words': tensor([[ 45, 752, 327, 180, 10, 15621, 16, 72, 8904, 9,\n", | |||||
" 1217, 7, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", | |||||
" [ 879, 96, 8, 1026, 12, 8067, 11, 13623, 8, 15619,\n", | |||||
" 4, 673, 662, 15, 4, 1154, 240, 639, 417, 7,\n", | |||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'seq_len': tensor([12, 20])}\n", | |||||
"batch_y: {'target': tensor([1, 0])}\n", | |||||
"batch_x: {'words': tensor([[ 13, 830, 7746, 174, 3, 47, 6, 83, 5752, 15,\n", | |||||
" 2177, 15, 63, 57, 406, 84, 1009, 4973, 27, 17,\n", | |||||
" 13785, 3, 533, 3687, 15623, 39, 375, 8, 15624, 8,\n", | |||||
" 1323, 4398, 7, 0, 0, 0, 0, 0, 0, 0],\n", | |||||
" [ 1045, 11113, 16, 104, 5, 4, 176, 1824, 1704, 3,\n", | |||||
" 2, 18, 11, 4, 1018, 432, 143, 33, 245, 308,\n", | |||||
" 7, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'seq_len': tensor([33, 21])}\n", | |||||
"batch_y: {'target': tensor([1, 0])}\n", | |||||
"batch_x: {'words': tensor([[ 14, 10, 4, 311, 5, 154, 1418, 609, 7, 0, 0, 0,\n", | |||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||||
" 0, 0, 0, 0],\n", | |||||
" [ 14, 10, 437, 32, 78, 3, 78, 437, 7, 0, 0, 0,\n", | |||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||||
" 0, 0, 0, 0]]), 'seq_len': tensor([9, 9])}\n", | |||||
"batch_y: {'target': tensor([0, 1])}\n", | |||||
"batch_x: {'words': tensor([[ 2, 155, 3, 4426, 3, 239, 3, 739, 5, 1136,\n", | |||||
" 41, 43, 2427, 736, 2, 648, 10, 15620, 2285, 7,\n", | |||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", | |||||
" [ 24, 95, 28, 46, 8, 336, 38, 239, 8, 2133,\n", | |||||
" 2, 18, 10, 15622, 1421, 6, 61, 5, 387, 7,\n", | |||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'seq_len': tensor([20, 20])}\n", | |||||
"batch_y: {'target': tensor([0, 0])}\n", | |||||
"batch_x: {'words': tensor([[ 4, 277, 685, 18, 7, 0, 0, 0, 0, 0,\n", | |||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", | |||||
" [15618, 3204, 5, 1675, 0, 0, 0, 0, 0, 0,\n", | |||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'seq_len': tensor([5, 4])}\n", | |||||
"batch_y: {'target': tensor([1, 1])}\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP.core.field import Padder\n", | |||||
"import numpy as np\n", | |||||
"class FixLengthPadder(Padder):\n", | |||||
" def __init__(self, pad_val=0, length=None):\n", | |||||
" super().__init__(pad_val=pad_val)\n", | |||||
" self.length = length\n", | |||||
" assert self.length is not None, \"Creating FixLengthPadder with no specific length!\"\n", | |||||
"\n", | |||||
" def __call__(self, contents, field_name, field_ele_dtype, dim):\n", | |||||
" #计算当前contents中的最大长度\n", | |||||
" max_len = max(map(len, contents))\n", | |||||
" #如果当前contents中的最大长度大于指定的padder length的话就报错\n", | |||||
" assert max_len <= self.length, \"Fixed padder length smaller than actual length! with length {}\".format(max_len)\n", | |||||
" array = np.full((len(contents), self.length), self.pad_val, dtype=field_ele_dtype)\n", | |||||
" for i, content_i in enumerate(contents):\n", | |||||
" array[i, :len(content_i)] = content_i\n", | |||||
" return array\n", | |||||
"\n", | |||||
"#设定FixLengthPadder的固定长度为40\n", | |||||
"tmp_padder = FixLengthPadder(pad_val=0,length=40)\n", | |||||
"#利用dataset的set_padder函数设定words field的padder\n", | |||||
"tmp_data.set_padder('words',tmp_padder)\n", | |||||
"batch = DataSetIter(batch_size=2, dataset=tmp_data, sampler=sampler)\n", | |||||
"for batch_x, batch_y in batch:\n", | |||||
" print(\"batch_x: \",batch_x)\n", | |||||
" print(\"batch_y: \", batch_y)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## 使用DataSetIter自己编写训练过程\n" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 8, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"-----start training-----\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 2.68 seconds!\n", | |||||
"Epoch 0 Avg Loss: 0.66 AccuracyMetric: acc=0.708716 29307ms\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.38 seconds!\n", | |||||
"Epoch 1 Avg Loss: 0.41 AccuracyMetric: acc=0.770642 52200ms\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.51 seconds!\n", | |||||
"Epoch 2 Avg Loss: 0.16 AccuracyMetric: acc=0.747706 70268ms\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.96 seconds!\n", | |||||
"Epoch 3 Avg Loss: 0.06 AccuracyMetric: acc=0.741972 90349ms\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 1.04 seconds!\n", | |||||
"Epoch 4 Avg Loss: 0.03 AccuracyMetric: acc=0.740826 114250ms\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.8 seconds!\n", | |||||
"Epoch 5 Avg Loss: 0.02 AccuracyMetric: acc=0.738532 134742ms\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.65 seconds!\n", | |||||
"Epoch 6 Avg Loss: 0.01 AccuracyMetric: acc=0.731651 154503ms\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.8 seconds!\n", | |||||
"Epoch 7 Avg Loss: 0.01 AccuracyMetric: acc=0.738532 175397ms\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.36 seconds!\n", | |||||
"Epoch 8 Avg Loss: 0.01 AccuracyMetric: acc=0.733945 192384ms\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=55.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.84 seconds!\n", | |||||
"Epoch 9 Avg Loss: 0.01 AccuracyMetric: acc=0.744266 214417ms\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5.0), HTML(value='')), layout=Layout(disp…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.04 seconds!\n", | |||||
"[tester] \n", | |||||
"AccuracyMetric: acc=0.786667\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"{'AccuracyMetric': {'acc': 0.786667}}" | |||||
] | |||||
}, | |||||
"execution_count": 8, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP import BucketSampler\n", | |||||
"from fastNLP import DataSetIter\n", | |||||
"from fastNLP.models import CNNText\n", | |||||
"from fastNLP import Tester\n", | |||||
"import torch\n", | |||||
"import time\n", | |||||
"\n", | |||||
"embed_dim = 100\n", | |||||
"model = CNNText((len(vocab),embed_dim), num_classes=2, dropout=0.1)\n", | |||||
"\n", | |||||
"def train(epoch, data, devdata):\n", | |||||
" optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", | |||||
" lossfunc = torch.nn.CrossEntropyLoss()\n", | |||||
" batch_size = 32\n", | |||||
"\n", | |||||
" # 定义一个Batch,传入DataSet,规定batch_size和去batch的规则。\n", | |||||
" # 顺序(Sequential),随机(Random),相似长度组成一个batch(Bucket)\n", | |||||
" train_sampler = BucketSampler(batch_size=batch_size, seq_len_field_name='seq_len')\n", | |||||
" train_batch = DataSetIter(batch_size=batch_size, dataset=data, sampler=train_sampler)\n", | |||||
"\n", | |||||
" start_time = time.time()\n", | |||||
" print(\"-\"*5+\"start training\"+\"-\"*5)\n", | |||||
" for i in range(epoch):\n", | |||||
" loss_list = []\n", | |||||
" for batch_x, batch_y in train_batch:\n", | |||||
" optimizer.zero_grad()\n", | |||||
" output = model(batch_x['words'])\n", | |||||
" loss = lossfunc(output['pred'], batch_y['target'])\n", | |||||
" loss.backward()\n", | |||||
" optimizer.step()\n", | |||||
" loss_list.append(loss.item())\n", | |||||
"\n", | |||||
" #这里verbose如果为0,在调用Tester对象的test()函数时不输出任何信息,返回评估信息; 如果为1,打印出验证结果,返回评估信息\n", | |||||
" #在调用过Tester对象的test()函数后,调用其_format_eval_results(res)函数,结构化输出验证结果\n", | |||||
" tester_tmp = Tester(devdata, model, metrics=AccuracyMetric(), verbose=0)\n", | |||||
" res=tester_tmp.test()\n", | |||||
"\n", | |||||
" print('Epoch {:d} Avg Loss: {:.2f}'.format(i, sum(loss_list) / len(loss_list)),end=\" \")\n", | |||||
" print(tester_tmp._format_eval_results(res),end=\" \")\n", | |||||
" print('{:d}ms'.format(round((time.time()-start_time)*1000)))\n", | |||||
" loss_list.clear()\n", | |||||
"\n", | |||||
"train(10, train_data, dev_data)\n", | |||||
"#使用tester进行快速测试\n", | |||||
"tester = Tester(test_data, model, metrics=AccuracyMetric())\n", | |||||
"tester.test()" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
} | |||||
], | |||||
"metadata": { | |||||
"kernelspec": { | |||||
"display_name": "Python Now", | |||||
"language": "python", | |||||
"name": "now" | |||||
}, | |||||
"language_info": { | |||||
"codemirror_mode": { | |||||
"name": "ipython", | |||||
"version": 3 | |||||
}, | |||||
"file_extension": ".py", | |||||
"mimetype": "text/x-python", | |||||
"name": "python", | |||||
"nbconvert_exporter": "python", | |||||
"pygments_lexer": "ipython3", | |||||
"version": "3.8.0" | |||||
} | |||||
}, | |||||
"nbformat": 4, | |||||
"nbformat_minor": 2 | |||||
} |
@@ -0,0 +1,622 @@ | |||||
{ | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"# 使用 Callback 自定义你的训练过程" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"- 什么是 Callback\n", | |||||
"- 使用 Callback \n", | |||||
"- 一些常用的 Callback\n", | |||||
"- 自定义实现 Callback" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"什么是Callback\n", | |||||
"------\n", | |||||
"\n", | |||||
"Callback 是与 Trainer 紧密结合的模块,利用 Callback 可以在 Trainer 训练时,加入自定义的操作,比如梯度裁剪,学习率调节,测试模型的性能等。定义的 Callback 会在训练的特定阶段被调用。\n", | |||||
"\n", | |||||
"fastNLP 中提供了很多常用的 Callback ,开箱即用。" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"使用 Callback\n", | |||||
" ------\n", | |||||
"\n", | |||||
"使用 Callback 很简单,将需要的 callback 按 list 存储,以对应参数 ``callbacks`` 传入对应的 Trainer。Trainer 在训练时就会自动执行这些 Callback 指定的操作了。" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 4, | |||||
"metadata": { | |||||
"ExecuteTime": { | |||||
"end_time": "2019-09-17T07:34:46.465871Z", | |||||
"start_time": "2019-09-17T07:34:30.648758Z" | |||||
} | |||||
}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"In total 3 datasets:\n", | |||||
"\ttest has 1200 instances.\n", | |||||
"\ttrain has 9600 instances.\n", | |||||
"\tdev has 1200 instances.\n", | |||||
"In total 2 vocabs:\n", | |||||
"\tchars has 4409 entries.\n", | |||||
"\ttarget has 2 entries.\n", | |||||
"\n", | |||||
"training epochs started 2019-09-17-03-34-34\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=900), HTML(value='')), layout=Layout(display=…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Evaluate data in 0.1 seconds!\n", | |||||
"Evaluation on dev at Epoch 1/3. Step:300/900: \n", | |||||
"AccuracyMetric: acc=0.863333\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Evaluate data in 0.11 seconds!\n", | |||||
"Evaluation on dev at Epoch 2/3. Step:600/900: \n", | |||||
"AccuracyMetric: acc=0.886667\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Evaluate data in 0.1 seconds!\n", | |||||
"Evaluation on dev at Epoch 3/3. Step:900/900: \n", | |||||
"AccuracyMetric: acc=0.890833\n", | |||||
"\n", | |||||
"\r\n", | |||||
"In Epoch:3/Step:900, got best dev performance:\n", | |||||
"AccuracyMetric: acc=0.890833\n", | |||||
"Reloaded the best model.\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP import (Callback, EarlyStopCallback,\n", | |||||
" Trainer, CrossEntropyLoss, AccuracyMetric)\n", | |||||
"from fastNLP.models import CNNText\n", | |||||
"import torch.cuda\n", | |||||
"\n", | |||||
"# prepare data\n", | |||||
"def get_data():\n", | |||||
" from fastNLP.io import ChnSentiCorpPipe as pipe\n", | |||||
" data = pipe().process_from_file()\n", | |||||
" print(data)\n", | |||||
" data.rename_field('chars', 'words')\n", | |||||
" train_data = data.datasets['train']\n", | |||||
" dev_data = data.datasets['dev']\n", | |||||
" test_data = data.datasets['test']\n", | |||||
" vocab = data.vocabs['words']\n", | |||||
" tgt_vocab = data.vocabs['target']\n", | |||||
" return train_data, dev_data, test_data, vocab, tgt_vocab\n", | |||||
"\n", | |||||
"# prepare model\n", | |||||
"train_data, dev_data, _, vocab, tgt_vocab = get_data()\n", | |||||
"device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n", | |||||
"model = CNNText((len(vocab),50), num_classes=len(tgt_vocab))\n", | |||||
"\n", | |||||
"# define callback\n", | |||||
"callbacks=[EarlyStopCallback(5)]\n", | |||||
"\n", | |||||
"# pass callbacks to Trainer\n", | |||||
"def train_with_callback(cb_list):\n", | |||||
" trainer = Trainer(\n", | |||||
" device=device,\n", | |||||
" n_epochs=3,\n", | |||||
" model=model, \n", | |||||
" train_data=train_data, \n", | |||||
" dev_data=dev_data, \n", | |||||
" loss=CrossEntropyLoss(), \n", | |||||
" metrics=AccuracyMetric(), \n", | |||||
" callbacks=cb_list, \n", | |||||
" check_code_level=-1\n", | |||||
" )\n", | |||||
" trainer.train()\n", | |||||
"\n", | |||||
"train_with_callback(callbacks)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"fastNLP 中的 Callback\n", | |||||
"-------\n", | |||||
"fastNLP 中提供了很多常用的 Callback,如梯度裁剪,训练时早停和测试验证集,fitlog 等等。具体 Callback 请参考 fastNLP.core.callbacks" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 5, | |||||
"metadata": { | |||||
"ExecuteTime": { | |||||
"end_time": "2019-09-17T07:35:02.182727Z", | |||||
"start_time": "2019-09-17T07:34:49.443863Z" | |||||
} | |||||
}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"training epochs started 2019-09-17-03-34-49\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=900), HTML(value='')), layout=Layout(display=…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Evaluate data in 0.13 seconds!\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Evaluate data in 0.12 seconds!\n", | |||||
"Evaluation on data-test:\n", | |||||
"AccuracyMetric: acc=0.890833\n", | |||||
"Evaluation on dev at Epoch 1/3. Step:300/900: \n", | |||||
"AccuracyMetric: acc=0.890833\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Evaluate data in 0.09 seconds!\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Evaluate data in 0.09 seconds!\n", | |||||
"Evaluation on data-test:\n", | |||||
"AccuracyMetric: acc=0.8875\n", | |||||
"Evaluation on dev at Epoch 2/3. Step:600/900: \n", | |||||
"AccuracyMetric: acc=0.8875\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Evaluate data in 0.11 seconds!\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Evaluate data in 0.1 seconds!\n", | |||||
"Evaluation on data-test:\n", | |||||
"AccuracyMetric: acc=0.885\n", | |||||
"Evaluation on dev at Epoch 3/3. Step:900/900: \n", | |||||
"AccuracyMetric: acc=0.885\n", | |||||
"\n", | |||||
"\r\n", | |||||
"In Epoch:1/Step:300, got best dev performance:\n", | |||||
"AccuracyMetric: acc=0.890833\n", | |||||
"Reloaded the best model.\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP import EarlyStopCallback, GradientClipCallback, EvaluateCallback\n", | |||||
"callbacks = [\n", | |||||
" EarlyStopCallback(5),\n", | |||||
" GradientClipCallback(clip_value=5, clip_type='value'),\n", | |||||
" EvaluateCallback(dev_data)\n", | |||||
"]\n", | |||||
"\n", | |||||
"train_with_callback(callbacks)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"自定义 Callback\n", | |||||
"------\n", | |||||
"\n", | |||||
"这里我们以一个简单的 Callback作为例子,它的作用是打印每一个 Epoch 平均训练 loss。\n", | |||||
"\n", | |||||
"#### 创建 Callback\n", | |||||
" \n", | |||||
"要自定义 Callback,我们要实现一个类,继承 fastNLP.Callback。\n", | |||||
"\n", | |||||
"这里我们定义 MyCallBack ,继承 fastNLP.Callback 。\n", | |||||
"\n", | |||||
"#### 指定 Callback 调用的阶段\n", | |||||
" \n", | |||||
"Callback 中所有以 on_ 开头的类方法会在 Trainer 的训练中在特定阶段调用。 如 on_train_begin() 会在训练开始时被调用,on_epoch_end() 会在每个 epoch 结束时调用。 具体有哪些类方法,参见 Callback 文档。\n", | |||||
"\n", | |||||
"这里, MyCallBack 在求得loss时调用 on_backward_begin() 记录当前 loss ,在每一个 epoch 结束时调用 on_epoch_end() ,求当前 epoch 平均loss并输出。\n", | |||||
"\n", | |||||
"#### 使用 Callback 的属性访问 Trainer 的内部信息\n", | |||||
" \n", | |||||
"为了方便使用,可以使用 Callback 的属性,访问 Trainer 中的对应信息,如 optimizer, epoch, n_epochs,分别对应训练时的优化器,当前 epoch 数,和总 epoch 数。 具体可访问的属性,参见文档 Callback 。\n", | |||||
"\n", | |||||
"这里, MyCallBack 为了求平均 loss ,需要知道当前 epoch 的总步数,可以通过 self.step 属性得到当前训练了多少步。\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 8, | |||||
"metadata": { | |||||
"ExecuteTime": { | |||||
"end_time": "2019-09-17T07:43:10.907139Z", | |||||
"start_time": "2019-09-17T07:42:58.488177Z" | |||||
} | |||||
}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"training epochs started 2019-09-17-03-42-58\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=900), HTML(value='')), layout=Layout(display=…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Evaluate data in 0.11 seconds!\n", | |||||
"Evaluation on dev at Epoch 1/3. Step:300/900: \n", | |||||
"AccuracyMetric: acc=0.883333\n", | |||||
"\n", | |||||
"Avg loss at epoch 1, 0.100254\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Evaluate data in 0.1 seconds!\n", | |||||
"Evaluation on dev at Epoch 2/3. Step:600/900: \n", | |||||
"AccuracyMetric: acc=0.8775\n", | |||||
"\n", | |||||
"Avg loss at epoch 2, 0.183511\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Evaluate data in 0.13 seconds!\n", | |||||
"Evaluation on dev at Epoch 3/3. Step:900/900: \n", | |||||
"AccuracyMetric: acc=0.875833\n", | |||||
"\n", | |||||
"Avg loss at epoch 3, 0.257103\n", | |||||
"\r\n", | |||||
"In Epoch:1/Step:300, got best dev performance:\n", | |||||
"AccuracyMetric: acc=0.883333\n", | |||||
"Reloaded the best model.\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP import Callback\n", | |||||
"from fastNLP import logger\n", | |||||
"\n", | |||||
"class MyCallBack(Callback):\n", | |||||
" \"\"\"Print average loss in each epoch\"\"\"\n", | |||||
" def __init__(self):\n", | |||||
" super().__init__()\n", | |||||
" self.total_loss = 0\n", | |||||
" self.start_step = 0\n", | |||||
" \n", | |||||
" def on_backward_begin(self, loss):\n", | |||||
" self.total_loss += loss.item()\n", | |||||
" \n", | |||||
" def on_epoch_end(self):\n", | |||||
" n_steps = self.step - self.start_step\n", | |||||
" avg_loss = self.total_loss / n_steps\n", | |||||
" logger.info('Avg loss at epoch %d, %.6f', self.epoch, avg_loss)\n", | |||||
" self.start_step = self.step\n", | |||||
"\n", | |||||
"callbacks = [MyCallBack()]\n", | |||||
"train_with_callback(callbacks)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
} | |||||
], | |||||
"metadata": { | |||||
"kernelspec": { | |||||
"display_name": "Python 3", | |||||
"language": "python", | |||||
"name": "python3" | |||||
}, | |||||
"language_info": { | |||||
"codemirror_mode": { | |||||
"name": "ipython", | |||||
"version": 3 | |||||
}, | |||||
"file_extension": ".py", | |||||
"mimetype": "text/x-python", | |||||
"name": "python", | |||||
"nbconvert_exporter": "python", | |||||
"pygments_lexer": "ipython3", | |||||
"version": "3.7.3" | |||||
}, | |||||
"varInspector": { | |||||
"cols": { | |||||
"lenName": 16, | |||||
"lenType": 16, | |||||
"lenVar": 40 | |||||
}, | |||||
"kernels_config": { | |||||
"python": { | |||||
"delete_cmd_postfix": "", | |||||
"delete_cmd_prefix": "del ", | |||||
"library": "var_list.py", | |||||
"varRefreshCmd": "print(var_dic_list())" | |||||
}, | |||||
"r": { | |||||
"delete_cmd_postfix": ") ", | |||||
"delete_cmd_prefix": "rm(", | |||||
"library": "var_list.r", | |||||
"varRefreshCmd": "cat(var_dic_list()) " | |||||
} | |||||
}, | |||||
"types_to_exclude": [ | |||||
"module", | |||||
"function", | |||||
"builtin_function_or_method", | |||||
"instance", | |||||
"_Feature" | |||||
], | |||||
"window_display": false | |||||
} | |||||
}, | |||||
"nbformat": 4, | |||||
"nbformat_minor": 4 | |||||
} |
@@ -0,0 +1,912 @@ | |||||
{ | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"# 序列标注\n", | |||||
"\n", | |||||
"这一部分的内容主要展示如何使用fastNLP实现序列标注(Sequence labeling)任务。您可以使用fastNLP的各个组件快捷,方便地完成序列标注任务,达到出色的效果。 在阅读这篇教程前,希望您已经熟悉了fastNLP的基础使用,尤其是数据的载入以及模型的构建,通过这个小任务的能让您进一步熟悉fastNLP的使用。\n", | |||||
"\n", | |||||
"## 命名实体识别(name entity recognition, NER)\n", | |||||
"\n", | |||||
"命名实体识别任务是从文本中抽取出具有特殊意义或者指代性非常强的实体,通常包括人名、地名、机构名和时间等。 如下面的例子中\n", | |||||
"\n", | |||||
"*我来自复旦大学*\n", | |||||
"\n", | |||||
"其中“复旦大学”就是一个机构名,命名实体识别就是要从中识别出“复旦大学”这四个字是一个整体,且属于机构名这个类别。这个问题在实际做的时候会被 转换为序列标注问题\n", | |||||
"\n", | |||||
"针对\"我来自复旦大学\"这句话,我们的预测目标将是[O, O, O, B-ORG, I-ORG, I-ORG, I-ORG],其中O表示out,即不是一个实体,B-ORG是ORG( organization的缩写)这个类别的开头(Begin),I-ORG是ORG类别的中间(Inside)。\n", | |||||
"\n", | |||||
"在本tutorial中我们将通过fastNLP尝试写出一个能够执行以上任务的模型。\n", | |||||
"\n", | |||||
"## 载入数据\n", | |||||
"\n", | |||||
"fastNLP的数据载入主要是由Loader与Pipe两个基类衔接完成的,您可以通过《使用Loader和Pipe处理数据》了解如何使用fastNLP提供的数据加载函数。下面我们以微博命名实体任务来演示一下在fastNLP进行序列标注任务。" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 1, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"+-----------------------------------+-----------------------------------+-----------------------------------+---------+\n", | |||||
"| raw_chars | target | chars | seq_len |\n", | |||||
"+-----------------------------------+-----------------------------------+-----------------------------------+---------+\n", | |||||
"| ['科', '技', '全', '方', '位',... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... | [792, 1015, 156, 198, 291, 714... | 26 |\n", | |||||
"| ['对', ',', '输', '给', '一',... | [0, 0, 0, 0, 0, 0, 3, 1, 0, 0,... | [123, 2, 1205, 115, 8, 24, 101... | 15 |\n", | |||||
"+-----------------------------------+-----------------------------------+-----------------------------------+---------+\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP.io import WeiboNERPipe\n", | |||||
"data_bundle = WeiboNERPipe().process_from_file()\n", | |||||
"print(data_bundle.get_dataset('train')[:2])" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## 模型构建\n", | |||||
"\n", | |||||
"首先选择需要使用的Embedding类型。关于Embedding的相关说明可以参见《使用Embedding模块将文本转成向量》。 在这里我们使用通过word2vec预训练的中文汉字embedding。" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 2, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Found 3321 out of 3471 words in the pre-training embedding.\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP.embeddings import StaticEmbedding\n", | |||||
"\n", | |||||
"embed = StaticEmbedding(vocab=data_bundle.get_vocab('chars'), model_dir_or_name='cn-char-fastnlp-100d')" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"选择好Embedding之后,我们可以使用fastNLP中自带的 fastNLP.models.BiLSTMCRF 作为模型。" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 3, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"from fastNLP.models import BiLSTMCRF\n", | |||||
"\n", | |||||
"data_bundle.rename_field('chars', 'words') # 这是由于BiLSTMCRF模型的forward函数接受的words,而不是chars,所以需要把这一列重新命名\n", | |||||
"model = BiLSTMCRF(embed=embed, num_classes=len(data_bundle.get_vocab('target')), num_layers=1, hidden_size=200, dropout=0.5,\n", | |||||
" target_vocab=data_bundle.get_vocab('target'))" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## 进行训练\n", | |||||
"下面我们选择用来评估模型的metric,以及优化用到的优化函数。" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 4, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"from fastNLP import SpanFPreRecMetric\n", | |||||
"from torch.optim import Adam\n", | |||||
"from fastNLP import LossInForward\n", | |||||
"\n", | |||||
"metric = SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))\n", | |||||
"optimizer = Adam(model.parameters(), lr=1e-2)\n", | |||||
"loss = LossInForward()" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"使用Trainer进行训练, 您可以通过修改 device 的值来选择显卡。" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 5, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"input fields after batch(if batch size is 2):\n", | |||||
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) \n", | |||||
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", | |||||
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) \n", | |||||
"target fields after batch(if batch size is 2):\n", | |||||
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) \n", | |||||
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", | |||||
"\n", | |||||
"training epochs started 2020-02-27-13-53-24\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=430.0), HTML(value='')), layout=Layout(di…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.89 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 1/10. Step:43/430: \n", | |||||
"\r", | |||||
"SpanFPreRecMetric: f=0.067797, pre=0.192771, rec=0.041131\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.9 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 2/10. Step:86/430: \n", | |||||
"\r", | |||||
"SpanFPreRecMetric: f=0.344086, pre=0.568047, rec=0.246787\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.88 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 3/10. Step:129/430: \n", | |||||
"\r", | |||||
"SpanFPreRecMetric: f=0.446701, pre=0.653465, rec=0.339332\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.81 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 4/10. Step:172/430: \n", | |||||
"\r", | |||||
"SpanFPreRecMetric: f=0.479871, pre=0.642241, rec=0.383033\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.91 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 5/10. Step:215/430: \n", | |||||
"\r", | |||||
"SpanFPreRecMetric: f=0.486312, pre=0.650862, rec=0.388175\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.87 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 6/10. Step:258/430: \n", | |||||
"\r", | |||||
"SpanFPreRecMetric: f=0.541401, pre=0.711297, rec=0.437018\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.86 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 7/10. Step:301/430: \n", | |||||
"\r", | |||||
"SpanFPreRecMetric: f=0.430335, pre=0.685393, rec=0.313625\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.82 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 8/10. Step:344/430: \n", | |||||
"\r", | |||||
"SpanFPreRecMetric: f=0.477759, pre=0.665138, rec=0.372751\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.81 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 9/10. Step:387/430: \n", | |||||
"\r", | |||||
"SpanFPreRecMetric: f=0.500759, pre=0.611111, rec=0.424165\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=9.0), HTML(value='')), layout=Layout(disp…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.8 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 10/10. Step:430/430: \n", | |||||
"\r", | |||||
"SpanFPreRecMetric: f=0.496025, pre=0.65, rec=0.401028\n", | |||||
"\n", | |||||
"\r\n", | |||||
"In Epoch:6/Step:258, got best dev performance:\n", | |||||
"SpanFPreRecMetric: f=0.541401, pre=0.711297, rec=0.437018\n", | |||||
"Reloaded the best model.\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"{'best_eval': {'SpanFPreRecMetric': {'f': 0.541401,\n", | |||||
" 'pre': 0.711297,\n", | |||||
" 'rec': 0.437018}},\n", | |||||
" 'best_epoch': 6,\n", | |||||
" 'best_step': 258,\n", | |||||
" 'seconds': 121.39}" | |||||
] | |||||
}, | |||||
"execution_count": 5, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP import Trainer\n", | |||||
"import torch\n", | |||||
"\n", | |||||
"device= 0 if torch.cuda.is_available() else 'cpu'\n", | |||||
"trainer = Trainer(data_bundle.get_dataset('train'), model, loss=loss, optimizer=optimizer,\n", | |||||
" dev_data=data_bundle.get_dataset('dev'), metrics=metric, device=device)\n", | |||||
"trainer.train()" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## 进行测试\n", | |||||
"训练结束之后过,可以通过 Tester 测试其在测试集上的性能" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 6, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=17.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 1.54 seconds!\n", | |||||
"[tester] \n", | |||||
"SpanFPreRecMetric: f=0.439024, pre=0.685279, rec=0.322967\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"{'SpanFPreRecMetric': {'f': 0.439024, 'pre': 0.685279, 'rec': 0.322967}}" | |||||
] | |||||
}, | |||||
"execution_count": 6, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP import Tester\n", | |||||
"tester = Tester(data_bundle.get_dataset('test'), model, metrics=metric)\n", | |||||
"tester.test()" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## 使用更强的Bert做序列标注\n", | |||||
"\n", | |||||
"在fastNLP使用Bert进行任务,您只需要把fastNLP.embeddings.StaticEmbedding 切换为 fastNLP.embeddings.BertEmbedding(可修改 device 选择显卡)。" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 8, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-chinese-wwm/vocab.txt\n", | |||||
"Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-chinese-wwm/chinese_wwm_pytorch.bin.\n", | |||||
"Start to generate word pieces for word.\n", | |||||
"Found(Or segment into word pieces) 3384 words out of 3471.\n", | |||||
"input fields after batch(if batch size is 2):\n", | |||||
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) \n", | |||||
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", | |||||
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) \n", | |||||
"target fields after batch(if batch size is 2):\n", | |||||
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 26]) \n", | |||||
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", | |||||
"\n", | |||||
"training epochs started 2020-02-27-13-58-51\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1130.0), HTML(value='')), layout=Layout(d…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Evaluate data in 2.7 seconds!\n", | |||||
"Evaluation on dev at Epoch 1/10. Step:113/1130: \n", | |||||
"SpanFPreRecMetric: f=0.008114, pre=0.019231, rec=0.005141\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Evaluate data in 2.49 seconds!\n", | |||||
"Evaluation on dev at Epoch 2/10. Step:226/1130: \n", | |||||
"SpanFPreRecMetric: f=0.467866, pre=0.467866, rec=0.467866\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Evaluate data in 2.6 seconds!\n", | |||||
"Evaluation on dev at Epoch 3/10. Step:339/1130: \n", | |||||
"SpanFPreRecMetric: f=0.566879, pre=0.482821, rec=0.686375\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Evaluate data in 2.56 seconds!\n", | |||||
"Evaluation on dev at Epoch 4/10. Step:452/1130: \n", | |||||
"SpanFPreRecMetric: f=0.651972, pre=0.59408, rec=0.722365\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 2.69 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 5/10. Step:565/1130: \n", | |||||
"\r", | |||||
"SpanFPreRecMetric: f=0.640909, pre=0.574338, rec=0.724936\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Evaluate data in 2.52 seconds!\n", | |||||
"Evaluation on dev at Epoch 6/10. Step:678/1130: \n", | |||||
"SpanFPreRecMetric: f=0.661836, pre=0.624146, rec=0.70437\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Evaluate data in 2.67 seconds!\n", | |||||
"Evaluation on dev at Epoch 7/10. Step:791/1130: \n", | |||||
"SpanFPreRecMetric: f=0.683429, pre=0.615226, rec=0.768638\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 2.37 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 8/10. Step:904/1130: \n", | |||||
"\r", | |||||
"SpanFPreRecMetric: f=0.674699, pre=0.634921, rec=0.719794\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Evaluate data in 2.42 seconds!\n", | |||||
"Evaluation on dev at Epoch 9/10. Step:1017/1130: \n", | |||||
"SpanFPreRecMetric: f=0.693878, pre=0.650901, rec=0.742931\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=23.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 2.46 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 10/10. Step:1130/1130: \n", | |||||
"\r", | |||||
"SpanFPreRecMetric: f=0.686845, pre=0.62766, rec=0.758355\n", | |||||
"\n", | |||||
"\r\n", | |||||
"In Epoch:9/Step:1017, got best dev performance:\n", | |||||
"SpanFPreRecMetric: f=0.693878, pre=0.650901, rec=0.742931\n", | |||||
"Reloaded the best model.\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=17.0), HTML(value='')), layout=Layout(dis…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 1.96 seconds!\n", | |||||
"[tester] \n", | |||||
"SpanFPreRecMetric: f=0.626561, pre=0.596112, rec=0.660287\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"{'SpanFPreRecMetric': {'f': 0.626561, 'pre': 0.596112, 'rec': 0.660287}}" | |||||
] | |||||
}, | |||||
"execution_count": 8, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"\n", | |||||
"from fastNLP.io import WeiboNERPipe\n", | |||||
"data_bundle = WeiboNERPipe().process_from_file()\n", | |||||
"data_bundle.rename_field('chars', 'words')\n", | |||||
"\n", | |||||
"from fastNLP.embeddings import BertEmbedding\n", | |||||
"embed = BertEmbedding(vocab=data_bundle.get_vocab('words'), model_dir_or_name='cn')\n", | |||||
"model = BiLSTMCRF(embed=embed, num_classes=len(data_bundle.get_vocab('target')), num_layers=1, hidden_size=200, dropout=0.5,\n", | |||||
" target_vocab=data_bundle.get_vocab('target'))\n", | |||||
"\n", | |||||
"from fastNLP import SpanFPreRecMetric\n", | |||||
"from torch.optim import Adam\n", | |||||
"from fastNLP import LossInForward\n", | |||||
"metric = SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))\n", | |||||
"optimizer = Adam(model.parameters(), lr=2e-5)\n", | |||||
"loss = LossInForward()\n", | |||||
"\n", | |||||
"from fastNLP import Trainer\n", | |||||
"import torch\n", | |||||
"device= 5 if torch.cuda.is_available() else 'cpu'\n", | |||||
"trainer = Trainer(data_bundle.get_dataset('train'), model, loss=loss, optimizer=optimizer, batch_size=12,\n", | |||||
" dev_data=data_bundle.get_dataset('dev'), metrics=metric, device=device)\n", | |||||
"trainer.train()\n", | |||||
"\n", | |||||
"from fastNLP import Tester\n", | |||||
"tester = Tester(data_bundle.get_dataset('test'), model, metrics=metric)\n", | |||||
"tester.test()" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
} | |||||
], | |||||
"metadata": { | |||||
"kernelspec": { | |||||
"display_name": "Python Now", | |||||
"language": "python", | |||||
"name": "now" | |||||
}, | |||||
"language_info": { | |||||
"codemirror_mode": { | |||||
"name": "ipython", | |||||
"version": 3 | |||||
}, | |||||
"file_extension": ".py", | |||||
"mimetype": "text/x-python", | |||||
"name": "python", | |||||
"nbconvert_exporter": "python", | |||||
"pygments_lexer": "ipython3", | |||||
"version": "3.8.0" | |||||
} | |||||
}, | |||||
"nbformat": 4, | |||||
"nbformat_minor": 2 | |||||
} |
@@ -0,0 +1,834 @@ | |||||
{ | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## 文本分类(Text classification)\n", | |||||
"文本分类任务是将一句话或一段话划分到某个具体的类别。比如垃圾邮件识别,文本情绪分类等。\n", | |||||
"\n", | |||||
"Example:: \n", | |||||
"1,商务大床房,房间很大,床有2M宽,整体感觉经济实惠不错!\n", | |||||
"\n", | |||||
"\n", | |||||
"其中开头的1是只这条评论的标签,表示是正面的情绪。我们将使用到的数据可以通过http://dbcloud.irocn.cn:8989/api/public/dl/dataset/chn_senti_corp.zip 下载并解压,当然也可以通过fastNLP自动下载该数据。\n", | |||||
"\n", | |||||
"数据中的内容如下图所示。接下来,我们将用fastNLP在这个数据上训练一个分类网络。" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"![jupyter](./cn_cls_example.png)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## 步骤\n", | |||||
"一共有以下的几个步骤 \n", | |||||
"(1) 读取数据 \n", | |||||
"(2) 预处理数据 \n", | |||||
"(3) 选择预训练词向量 \n", | |||||
"(4) 创建模型 \n", | |||||
"(5) 训练模型 " | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### (1) 读取数据\n", | |||||
"fastNLP提供多种数据的自动下载与自动加载功能,对于这里我们要用到的数据,我们可以用\\ref{Loader}自动下载并加载该数据。更多有关Loader的使用可以参考\\ref{Loader}" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 1, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"from fastNLP.io import ChnSentiCorpLoader\n", | |||||
"\n", | |||||
"loader = ChnSentiCorpLoader() # 初始化一个中文情感分类的loader\n", | |||||
"data_dir = loader.download() # 这一行代码将自动下载数据到默认的缓存地址, 并将该地址返回\n", | |||||
"data_bundle = loader.load(data_dir) # 这一行代码将从{data_dir}处读取数据至DataBundle" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"DataBundle的相关介绍,可以参考\\ref{}。我们可以打印该data_bundle的基本信息。" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 2, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"In total 3 datasets:\n", | |||||
"\tdev has 1200 instances.\n", | |||||
"\ttrain has 9600 instances.\n", | |||||
"\ttest has 1200 instances.\n", | |||||
"In total 0 vocabs:\n", | |||||
"\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"print(data_bundle)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"可以看出,该data_bundle中一个含有三个\\ref{DataSet}。通过下面的代码,我们可以查看DataSet的基本情况" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 6, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"DataSet({'raw_chars': 选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般 type=str,\n", | |||||
"'target': 1 type=str},\n", | |||||
"{'raw_chars': 15.4寸笔记本的键盘确实爽,基本跟台式机差不多了,蛮喜欢数字小键盘,输数字特方便,样子也很美观,做工也相当不错 type=str,\n", | |||||
"'target': 1 type=str})\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"print(data_bundle.get_dataset('train')[:2]) # 查看Train集前两个sample" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### (2) 预处理数据\n", | |||||
"在NLP任务中,预处理一般包括: (a)将一整句话切分成汉字或者词; (b)将文本转换为index \n", | |||||
"\n", | |||||
"fastNLP中也提供了多种数据集的处理类,这里我们直接使用fastNLP的ChnSentiCorpPipe。更多关于Pipe的说明可以参考\\ref{Pipe}。" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 3, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"from fastNLP.io import ChnSentiCorpPipe\n", | |||||
"\n", | |||||
"pipe = ChnSentiCorpPipe()\n", | |||||
"data_bundle = pipe.process(data_bundle) # 所有的Pipe都实现了process()方法,且输入输出都为DataBundle类型" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 4, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"In total 3 datasets:\n", | |||||
"\tdev has 1200 instances.\n", | |||||
"\ttrain has 9600 instances.\n", | |||||
"\ttest has 1200 instances.\n", | |||||
"In total 2 vocabs:\n", | |||||
"\tchars has 4409 entries.\n", | |||||
"\ttarget has 2 entries.\n", | |||||
"\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"print(data_bundle) # 打印data_bundle,查看其变化" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"可以看到除了之前已经包含的3个\\ref{DataSet}, 还新增了两个\\ref{Vocabulary}。我们可以打印DataSet中的内容" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 5, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"DataSet({'raw_chars': 选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般 type=str,\n", | |||||
"'target': 1 type=int,\n", | |||||
"'chars': [338, 464, 1400, 784, 468, 739, 3, 289, 151, 21, 5, 88, 143, 2, 9, 81, 134, 2573, 766, 233, 196, 23, 536, 342, 297, 2, 405, 698, 132, 281, 74, 744, 1048, 74, 420, 387, 74, 412, 433, 74, 2021, 180, 8, 219, 1929, 213, 4, 34, 31, 96, 363, 8, 230, 2, 66, 18, 229, 331, 768, 4, 11, 1094, 479, 17, 35, 593, 3, 1126, 967, 2, 151, 245, 12, 44, 2, 6, 52, 260, 263, 635, 5, 152, 162, 4, 11, 336, 3, 154, 132, 5, 236, 443, 3, 2, 18, 229, 761, 700, 4, 11, 48, 59, 653, 2, 8, 230] type=list,\n", | |||||
"'seq_len': 106 type=int},\n", | |||||
"{'raw_chars': 15.4寸笔记本的键盘确实爽,基本跟台式机差不多了,蛮喜欢数字小键盘,输数字特方便,样子也很美观,做工也相当不错 type=str,\n", | |||||
"'target': 1 type=int,\n", | |||||
"'chars': [50, 133, 20, 135, 945, 520, 343, 24, 3, 301, 176, 350, 86, 785, 2, 456, 24, 461, 163, 443, 128, 109, 6, 47, 7, 2, 916, 152, 162, 524, 296, 44, 301, 176, 2, 1384, 524, 296, 259, 88, 143, 2, 92, 67, 26, 12, 277, 269, 2, 188, 223, 26, 228, 83, 6, 63] type=list,\n", | |||||
"'seq_len': 56 type=int})\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"print(data_bundle.get_dataset('train')[:2])" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"新增了一列为数字列表的chars,以及变为数字的target列。可以看出这两列的名称和刚好与data_bundle中两个Vocabulary的名称是一致的,我们可以打印一下Vocabulary看一下里面的内容。" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 6, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Vocabulary(['选', '择', '珠', '江', '花']...)\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"char_vocab = data_bundle.get_vocab('chars')\n", | |||||
"print(char_vocab)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"Vocabulary是一个记录着词语与index之间映射关系的类,比如" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 7, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"'选'的index是338\n", | |||||
"index:338对应的汉字是选\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"index = char_vocab.to_index('选')\n", | |||||
"print(\"'选'的index是{}\".format(index)) # 这个值与上面打印出来的第一个instance的chars的第一个index是一致的\n", | |||||
"print(\"index:{}对应的汉字是{}\".format(index, char_vocab.to_word(index))) " | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### (3) 选择预训练词向量 \n", | |||||
"由于Word2vec, Glove, Elmo, Bert等预训练模型可以增强模型的性能,所以在训练具体任务前,选择合适的预训练词向量非常重要。在fastNLP中我们提供了多种Embedding使得加载这些预训练模型的过程变得更加便捷。更多关于Embedding的说明可以参考\\ref{Embedding}。这里我们先给出一个使用word2vec的中文汉字预训练的示例,之后再给出一个使用Bert的文本分类。这里使用的预训练词向量为'cn-fastnlp-100d',fastNLP将自动下载该embedding至本地缓存,fastNLP支持使用名字指定的Embedding以及相关说明可以参见\\ref{Embedding}" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 8, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Found 4321 out of 4409 words in the pre-training embedding.\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP.embeddings import StaticEmbedding\n", | |||||
"\n", | |||||
"word2vec_embed = StaticEmbedding(char_vocab, model_dir_or_name='cn-char-fastnlp-100d')" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### (4) 创建模型\n", | |||||
"这里我们使用到的模型结构如下所示,补图" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 9, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"from torch import nn\n", | |||||
"from fastNLP.modules import LSTM\n", | |||||
"import torch\n", | |||||
"\n", | |||||
"# 定义模型\n", | |||||
"class BiLSTMMaxPoolCls(nn.Module):\n", | |||||
" def __init__(self, embed, num_classes, hidden_size=400, num_layers=1, dropout=0.3):\n", | |||||
" super().__init__()\n", | |||||
" self.embed = embed\n", | |||||
" \n", | |||||
" self.lstm = LSTM(self.embed.embedding_dim, hidden_size=hidden_size//2, num_layers=num_layers, \n", | |||||
" batch_first=True, bidirectional=True)\n", | |||||
" self.dropout_layer = nn.Dropout(dropout)\n", | |||||
" self.fc = nn.Linear(hidden_size, num_classes)\n", | |||||
" \n", | |||||
" def forward(self, chars, seq_len): # 这里的名称必须和DataSet中相应的field对应,比如之前我们DataSet中有chars,这里就必须为chars\n", | |||||
" # chars:[batch_size, max_len]\n", | |||||
" # seq_len: [batch_size, ]\n", | |||||
" chars = self.embed(chars)\n", | |||||
" outputs, _ = self.lstm(chars, seq_len)\n", | |||||
" outputs = self.dropout_layer(outputs)\n", | |||||
" outputs, _ = torch.max(outputs, dim=1)\n", | |||||
" outputs = self.fc(outputs)\n", | |||||
" \n", | |||||
" return {'pred':outputs} # [batch_size,], 返回值必须是dict类型,且预测值的key建议设为pred\n", | |||||
"\n", | |||||
"# 初始化模型\n", | |||||
"model = BiLSTMMaxPoolCls(word2vec_embed, len(data_bundle.get_vocab('target')))" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### (5) 训练模型\n", | |||||
"fastNLP提供了Trainer对象来组织训练过程,包括完成loss计算(所以在初始化Trainer的时候需要指定loss类型),梯度更新(所以在初始化Trainer的时候需要提供优化器optimizer)以及在验证集上的性能验证(所以在初始化时需要提供一个Metric)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 10, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"input fields after batch(if batch size is 2):\n", | |||||
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", | |||||
"\tchars: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 106]) \n", | |||||
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", | |||||
"target fields after batch(if batch size is 2):\n", | |||||
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", | |||||
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", | |||||
"\n", | |||||
"Evaluate data in 0.01 seconds!\n", | |||||
"training epochs started 2019-09-03-23-57-10\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=3000), HTML(value='')), layout=Layout(display…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.43 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 1/10. Step:300/3000: \n", | |||||
"\r", | |||||
"AccuracyMetric: acc=0.81\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.44 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 2/10. Step:600/3000: \n", | |||||
"\r", | |||||
"AccuracyMetric: acc=0.8675\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.44 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 3/10. Step:900/3000: \n", | |||||
"\r", | |||||
"AccuracyMetric: acc=0.878333\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.43 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 4/10. Step:1200/3000: \n", | |||||
"\r", | |||||
"AccuracyMetric: acc=0.873333\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.44 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 5/10. Step:1500/3000: \n", | |||||
"\r", | |||||
"AccuracyMetric: acc=0.878333\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.42 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 6/10. Step:1800/3000: \n", | |||||
"\r", | |||||
"AccuracyMetric: acc=0.895833\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.44 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 7/10. Step:2100/3000: \n", | |||||
"\r", | |||||
"AccuracyMetric: acc=0.8975\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.43 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 8/10. Step:2400/3000: \n", | |||||
"\r", | |||||
"AccuracyMetric: acc=0.894167\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.48 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 9/10. Step:2700/3000: \n", | |||||
"\r", | |||||
"AccuracyMetric: acc=0.8875\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.43 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 10/10. Step:3000/3000: \n", | |||||
"\r", | |||||
"AccuracyMetric: acc=0.895833\n", | |||||
"\n", | |||||
"\r\n", | |||||
"In Epoch:7/Step:2100, got best dev performance:\n", | |||||
"AccuracyMetric: acc=0.8975\n", | |||||
"Reloaded the best model.\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=19), HTML(value='')), layout=Layout(display='…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 0.34 seconds!\n", | |||||
"[tester] \n", | |||||
"AccuracyMetric: acc=0.8975\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"{'AccuracyMetric': {'acc': 0.8975}}" | |||||
] | |||||
}, | |||||
"execution_count": 10, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"from fastNLP import Trainer\n", | |||||
"from fastNLP import CrossEntropyLoss\n", | |||||
"from torch.optim import Adam\n", | |||||
"from fastNLP import AccuracyMetric\n", | |||||
"\n", | |||||
"loss = CrossEntropyLoss()\n", | |||||
"optimizer = Adam(model.parameters(), lr=0.001)\n", | |||||
"metric = AccuracyMetric()\n", | |||||
"device = 0 if torch.cuda.is_available() else 'cpu' # 如果有gpu的话在gpu上运行,训练速度会更快\n", | |||||
"\n", | |||||
"trainer = Trainer(train_data=data_bundle.get_dataset('train'), model=model, loss=loss, \n", | |||||
" optimizer=optimizer, batch_size=32, dev_data=data_bundle.get_dataset('dev'),\n", | |||||
" metrics=metric, device=device)\n", | |||||
"trainer.train() # 开始训练,训练完成之后默认会加载在dev上表现最好的模型\n", | |||||
"\n", | |||||
"# 在测试集上测试一下模型的性能\n", | |||||
"from fastNLP import Tester\n", | |||||
"print(\"Performance on test is:\")\n", | |||||
"tester = Tester(data=data_bundle.get_dataset('test'), model=model, metrics=metric, batch_size=64, device=device)\n", | |||||
"tester.test()" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### 使用Bert进行文本分类" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 12, | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"loading vocabulary file /home/yh/.fastNLP/embedding/bert-chinese-wwm/vocab.txt\n", | |||||
"Load pre-trained BERT parameters from file /home/yh/.fastNLP/embedding/bert-chinese-wwm/chinese_wwm_pytorch.bin.\n", | |||||
"Start to generating word pieces for word.\n", | |||||
"Found(Or segment into word pieces) 4286 words out of 4409.\n", | |||||
"input fields after batch(if batch size is 2):\n", | |||||
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", | |||||
"\tchars: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 106]) \n", | |||||
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", | |||||
"target fields after batch(if batch size is 2):\n", | |||||
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", | |||||
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", | |||||
"\n", | |||||
"Evaluate data in 0.05 seconds!\n", | |||||
"training epochs started 2019-09-04-00-02-37\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=3600), HTML(value='')), layout=Layout(display…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=150), HTML(value='')), layout=Layout(display=…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 15.89 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 1/3. Step:1200/3600: \n", | |||||
"\r", | |||||
"AccuracyMetric: acc=0.9\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=150), HTML(value='')), layout=Layout(display=…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 15.92 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 2/3. Step:2400/3600: \n", | |||||
"\r", | |||||
"AccuracyMetric: acc=0.904167\n", | |||||
"\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=150), HTML(value='')), layout=Layout(display=…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 15.91 seconds!\n", | |||||
"\r", | |||||
"Evaluation on dev at Epoch 3/3. Step:3600/3600: \n", | |||||
"\r", | |||||
"AccuracyMetric: acc=0.918333\n", | |||||
"\n", | |||||
"\r\n", | |||||
"In Epoch:3/Step:3600, got best dev performance:\n", | |||||
"AccuracyMetric: acc=0.918333\n", | |||||
"Reloaded the best model.\n", | |||||
"Performance on test is:\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=19), HTML(value='')), layout=Layout(display='…" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"\r", | |||||
"Evaluate data in 29.24 seconds!\n", | |||||
"[tester] \n", | |||||
"AccuracyMetric: acc=0.919167\n" | |||||
] | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"{'AccuracyMetric': {'acc': 0.919167}}" | |||||
] | |||||
}, | |||||
"execution_count": 12, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"# 只需要切换一下Embedding即可\n", | |||||
"from fastNLP.embeddings import BertEmbedding\n", | |||||
"\n", | |||||
"# 这里为了演示一下效果,所以默认Bert不更新权重\n", | |||||
"bert_embed = BertEmbedding(char_vocab, model_dir_or_name='cn', auto_truncate=True, requires_grad=False)\n", | |||||
"model = BiLSTMMaxPoolCls(bert_embed, len(data_bundle.get_vocab('target')), )\n", | |||||
"\n", | |||||
"\n", | |||||
"import torch\n", | |||||
"from fastNLP import Trainer\n", | |||||
"from fastNLP import CrossEntropyLoss\n", | |||||
"from torch.optim import Adam\n", | |||||
"from fastNLP import AccuracyMetric\n", | |||||
"\n", | |||||
"loss = CrossEntropyLoss()\n", | |||||
"optimizer = Adam(model.parameters(), lr=2e-5)\n", | |||||
"metric = AccuracyMetric()\n", | |||||
"device = 0 if torch.cuda.is_available() else 'cpu' # 如果有gpu的话在gpu上运行,训练速度会更快\n", | |||||
"\n", | |||||
"trainer = Trainer(train_data=data_bundle.get_dataset('train'), model=model, loss=loss, \n", | |||||
" optimizer=optimizer, batch_size=16, dev_data=data_bundle.get_dataset('test'),\n", | |||||
" metrics=metric, device=device, n_epochs=3)\n", | |||||
"trainer.train() # 开始训练,训练完成之后默认会加载在dev上表现最好的模型\n", | |||||
"\n", | |||||
"# 在测试集上测试一下模型的性能\n", | |||||
"from fastNLP import Tester\n", | |||||
"print(\"Performance on test is:\")\n", | |||||
"tester = Tester(data=data_bundle.get_dataset('test'), model=model, metrics=metric, batch_size=64, device=device)\n", | |||||
"tester.test()" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
} | |||||
], | |||||
"metadata": { | |||||
"kernelspec": { | |||||
"display_name": "Python 3", | |||||
"language": "python", | |||||
"name": "python3" | |||||
}, | |||||
"language_info": { | |||||
"codemirror_mode": { | |||||
"name": "ipython", | |||||
"version": 3 | |||||
}, | |||||
"file_extension": ".py", | |||||
"mimetype": "text/x-python", | |||||
"name": "python", | |||||
"nbconvert_exporter": "python", | |||||
"pygments_lexer": "ipython3", | |||||
"version": "3.6.7" | |||||
} | |||||
}, | |||||
"nbformat": 4, | |||||
"nbformat_minor": 2 | |||||
} |
@@ -226,4 +226,6 @@ Bert自从在 `BERT: Pre-training of Deep Bidirectional Transformers for Languag | |||||
代码下载 | 代码下载 | ||||
---------------------------------- | ---------------------------------- | ||||
`点击下载 IPython Notebook 文件 <https://sourcegraph.com/github.com/fastnlp/fastNLP@master/-/raw/tutorials/extend_1_bert_embedding.ipynb>`_) | |||||
.. raw:: html | |||||
<a href="../_static/notebooks/extend_1_bert_embedding.ipynb" download="extend_1_bert_embedding.ipynb">点击下载 IPython Notebook 文件</a><hr> |
@@ -167,4 +167,6 @@ fastNLP中field的命名习惯 | |||||
代码下载 | 代码下载 | ||||
---------------------------------- | ---------------------------------- | ||||
`点击下载 IPython Notebook 文件 <https://sourcegraph.com/github.com/fastnlp/fastNLP@master/-/raw/tutorials/tutorial_1_data_preprocess.ipynb>`_) | |||||
.. raw:: html | |||||
<a href="../_static/notebooks/tutorial_1_data_preprocess.ipynb" download="tutorial_1_data_preprocess.ipynb">点击下载 IPython Notebook 文件</a><hr> |
@@ -135,4 +135,6 @@ fastNLP中的Vocabulary | |||||
代码下载 | 代码下载 | ||||
---------------------------------- | ---------------------------------- | ||||
`点击下载 IPython Notebook 文件 <https://sourcegraph.com/github.com/fastnlp/fastNLP@master/-/raw/tutorials/tutorial_2_vocabulary.ipynb>`_) | |||||
.. raw:: html | |||||
<a href="../_static/notebooks/tutorial_2_vocabulary.ipynb" download="tutorial_2_vocabulary.ipynb">点击下载 IPython Notebook 文件</a><hr> |
@@ -457,4 +457,6 @@ fastNLP通过在 :class:`~fastNLP.embeddings.StaticEmbedding` 增加了一个min | |||||
代码下载 | 代码下载 | ||||
---------------------------------- | ---------------------------------- | ||||
`点击下载 IPython Notebook 文件 <https://sourcegraph.com/github.com/fastnlp/fastNLP@master/-/raw/tutorials/tutorial_3_embedding.ipynb>`_) | |||||
.. raw:: html | |||||
<a href="../_static/notebooks/tutorial_3_embedding.ipynb" download="tutorial_3_embedding.ipynb">点击下载 IPython Notebook 文件</a><hr> |
@@ -214,4 +214,6 @@ Part V: 不同格式类型的基础Loader | |||||
代码下载 | 代码下载 | ||||
---------------------------------- | ---------------------------------- | ||||
`点击下载 IPython Notebook 文件 <https://sourcegraph.com/github.com/fastnlp/fastNLP@master/-/raw/tutorials/tutorial_4_load_dataset.ipynb>`_) | |||||
.. raw:: html | |||||
<a href="../_static/notebooks/tutorial_4_load_dataset.ipynb" download="tutorial_4_load_dataset.ipynb">点击下载 IPython Notebook 文件</a><hr> |
@@ -243,4 +243,6 @@ | |||||
代码下载 | 代码下载 | ||||
---------------------------------- | ---------------------------------- | ||||
`点击下载 IPython Notebook 文件 <https://sourcegraph.com/github.com/fastnlp/fastNLP@master/-/raw/tutorials/tutorial_5_loss_optimizer.ipynb>`_) | |||||
.. raw:: html | |||||
<a href="../_static/notebooks/tutorial_5_loss_optimizer.ipynb" download="tutorial_5_loss_optimizer.ipynb">点击下载 IPython Notebook 文件</a><hr> |
@@ -418,4 +418,6 @@ Dataset个性化padding | |||||
代码下载 | 代码下载 | ||||
---------------------------------- | ---------------------------------- | ||||
`点击下载 IPython Notebook 文件 <https://sourcegraph.com/github.com/fastnlp/fastNLP@master/-/raw/tutorials/tutorial_6_datasetiter.ipynb>`_) | |||||
.. raw:: html | |||||
<a href="../_static/notebooks/tutorial_6_datasetiter.ipynb" download="tutorial_6_datasetiter.ipynb">点击下载 IPython Notebook 文件</a><hr> |
@@ -130,4 +130,6 @@ self.get_metric将统计当前的评价指标并返回评价结果, 返回值需 | |||||
代码下载 | 代码下载 | ||||
---------------------------------- | ---------------------------------- | ||||
`点击下载 IPython Notebook 文件 <https://sourcegraph.com/github.com/fastnlp/fastNLP@master/-/raw/tutorials/tutorial_7_metrics.ipynb>`_) | |||||
.. raw:: html | |||||
<a href="../_static/notebooks/tutorial_7_metrics.ipynb" download="tutorial_7_metrics.ipynb">点击下载 IPython Notebook 文件</a><hr> |
@@ -188,4 +188,6 @@ FastNLP 中包含的各种模块如下表,您可以点击具体的名称查看 | |||||
代码下载 | 代码下载 | ||||
---------------------------------- | ---------------------------------- | ||||
`点击下载 IPython Notebook 文件 <https://sourcegraph.com/github.com/fastnlp/fastNLP@master/-/raw/tutorials/tutorial_8_modules_models.ipynb>`_) | |||||
.. raw:: html | |||||
<a href="../_static/notebooks/tutorial_8_modules_models.ipynb" download="tutorial_8_modules_models.ipynb">点击下载 IPython Notebook 文件</a><hr> |
@@ -135,4 +135,6 @@ fastNLP 中提供了很多常用的 Callback,如梯度裁剪,训练时早停 | |||||
代码下载 | 代码下载 | ||||
---------------------------------- | ---------------------------------- | ||||
`点击下载 IPython Notebook 文件 <https://sourcegraph.com/github.com/fastnlp/fastNLP@master/-/raw/tutorials/tutorial_9_callback.ipynb>`_) | |||||
.. raw:: html | |||||
<a href="../_static/notebooks/tutorial_9_callback.ipynb" download="tutorial_9_callback.ipynb">点击下载 IPython Notebook 文件</a><hr> |
@@ -203,4 +203,6 @@ fastNLP的数据载入主要是由Loader与Pipe两个基类衔接完成的,您 | |||||
代码下载 | 代码下载 | ||||
---------------------------------- | ---------------------------------- | ||||
`点击下载 IPython Notebook 文件 <https://sourcegraph.com/github.com/fastnlp/fastNLP@master/-/raw/tutorials/%E5%BA%8F%E5%88%97%E6%A0%87%E6%B3%A8.ipynb>`_) | |||||
.. raw:: html | |||||
<a href="../_static/notebooks/%E5%BA%8F%E5%88%97%E6%A0%87%E6%B3%A8.ipynb" download="序列标注.ipynb">点击下载 IPython Notebook 文件</a><hr> |
@@ -373,4 +373,6 @@ fastNLP提供了Trainer对象来组织训练过程,包括完成loss计算(所 | |||||
代码下载 | 代码下载 | ||||
---------------------------------- | ---------------------------------- | ||||
`点击下载 IPython Notebook 文件 <https://sourcegraph.com/github.com/fastnlp/fastNLP@master/-/raw/tutorials/%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB.ipynb>`_) | |||||
.. raw:: html | |||||
<a href="../_static/notebooks/%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB.ipynb" download="文本分类.ipynb">点击下载 IPython Notebook 文件 </a> |