Browse Source

Create bert_embedding_tutorial.ipynb

tags/v0.4.10
Yige Xu 5 years ago
parent
commit
33cbb5b540
1 changed files with 470 additions and 0 deletions
  1. +470
    -0
      tutorials/bert_embedding_tutorial.ipynb

+ 470
- 0
tutorials/bert_embedding_tutorial.ipynb View File

@@ -0,0 +1,470 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# BertEmbedding的各种用法\n",
"fastNLP的BertEmbedding以pytorch-transformer.BertModel的代码为基础,是一个使用BERT对words进行编码的Embedding。\n",
"\n",
"使用BertEmbedding和fastNLP.models.bert里面模型可以搭建BERT应用到五种下游任务的模型。\n",
"\n",
"*预训练好的Embedding参数及数据集的介绍和自动下载功能见 [Embedding教程](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_3_embedding.html) 和 [数据处理教程](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_4_load_dataset.html)。*\n",
"\n",
"## 1. BERT for Squence Classification\n",
"在文本分类任务中,我们采用SST数据集作为例子来介绍BertEmbedding的使用方法。"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"import torch\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"In total 3 datasets:\n",
"\ttest has 2210 instances.\n",
"\ttrain has 8544 instances.\n",
"\tdev has 1101 instances.\n",
"In total 2 vocabs:\n",
"\twords has 21701 entries.\n",
"\ttarget has 5 entries."
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 载入数据集\n",
"from fastNLP.io import SSTPipe\n",
"data_bundle = SSTPipe(subtree=False, train_subtree=False, lower=False, tokenizer='raw').process_from_file()\n",
"data_bundle"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loading vocabulary file /remote-home/source/fastnlp_caches/embedding/bert-base-cased/vocab.txt\n",
"Load pre-trained BERT parameters from file /remote-home/source/fastnlp_caches/embedding/bert-base-cased/pytorch_model.bin.\n",
"Start to generate word pieces for word.\n",
"Found(Or segment into word pieces) 21701 words out of 21701.\n"
]
}
],
"source": [
"# 载入BertEmbedding\n",
"from fastNLP.embeddings import BertEmbedding\n",
"embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='en-base-cased', include_cls_sep=True)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# 载入模型\n",
"from fastNLP.models import BertForSequenceClassification\n",
"model = BertForSequenceClassification(embed, len(data_bundle.get_vocab('target')))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input fields after batch(if batch size is 2):\n",
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 37]) \n",
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"target fields after batch(if batch size is 2):\n",
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\n",
"training epochs started 2019-09-11-17-35-26\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=268), HTML(value='')), layout=Layout(display=…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=18), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 2.08 seconds!\n",
"Evaluation on dev at Epoch 1/2. Step:134/268: \n",
"AccuracyMetric: acc=0.459582\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=18), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 2.2 seconds!\n",
"Evaluation on dev at Epoch 2/2. Step:268/268: \n",
"AccuracyMetric: acc=0.468665\n",
"\n",
"\n",
"In Epoch:2/Step:268, got best dev performance:\n",
"AccuracyMetric: acc=0.468665\n",
"Reloaded the best model.\n"
]
},
{
"data": {
"text/plain": [
"{'best_eval': {'AccuracyMetric': {'acc': 0.468665}},\n",
" 'best_epoch': 2,\n",
" 'best_step': 268,\n",
" 'seconds': 114.5}"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 训练模型\n",
"from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam\n",
"trainer = Trainer(data_bundle.get_dataset('train'), model, \n",
" optimizer=Adam(model_params=model.parameters(), lr=2e-5), \n",
" loss=CrossEntropyLoss(), device=[0],\n",
" batch_size=64, dev_data=data_bundle.get_dataset('dev'), \n",
" metrics=AccuracyMetric(), n_epochs=2, print_every=1)\n",
"trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=18), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"Evaluate data in 4.52 seconds!\n",
"[tester] \n",
"AccuracyMetric: acc=0.504072\n"
]
},
{
"data": {
"text/plain": [
"{'AccuracyMetric': {'acc': 0.504072}}"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 测试结果并删除模型\n",
"from fastNLP import Tester\n",
"tester = Tester(data_bundle.get_dataset('test'), model, batch_size=128, metrics=AccuracyMetric())\n",
"tester.test()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"## 2. BERT for Sentence Matching\n",
"在Matching任务中,我们采用RTE数据集作为例子来介绍BertEmbedding的使用方法。"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"In total 3 datasets:\n",
"\ttest has 3000 instances.\n",
"\ttrain has 2490 instances.\n",
"\tdev has 277 instances.\n",
"In total 2 vocabs:\n",
"\twords has 41281 entries.\n",
"\ttarget has 2 entries."
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 载入数据集\n",
"from fastNLP.io import RTEBertPipe\n",
"data_bundle = RTEBertPipe(lower=False, tokenizer='raw').process_from_file()\n",
"data_bundle"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loading vocabulary file /remote-home/source/fastnlp_caches/embedding/bert-base-cased/vocab.txt\n",
"Load pre-trained BERT parameters from file /remote-home/source/fastnlp_caches/embedding/bert-base-cased/pytorch_model.bin.\n",
"Start to generate word pieces for word.\n",
"Found(Or segment into word pieces) 41279 words out of 41281.\n"
]
}
],
"source": [
"# 载入BertEmbedding\n",
"from fastNLP.embeddings import BertEmbedding\n",
"embed = BertEmbedding(data_bundle.get_vocab('words'), model_dir_or_name='en-base-cased', include_cls_sep=True)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# 载入模型\n",
"from fastNLP.models import BertForSentenceMatching\n",
"model = BertForSentenceMatching(embed, len(data_bundle.get_vocab('target')))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input fields after batch(if batch size is 2):\n",
"\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 45]) \n",
"\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"target fields after batch(if batch size is 2):\n",
"\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
"\n",
"training epochs started 2019-09-11-17-37-36\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=312), HTML(value='')), layout=Layout(display=…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=18), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 1.72 seconds!\n",
"Evaluation on dev at Epoch 1/2. Step:156/312: \n",
"AccuracyMetric: acc=0.624549\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=18), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 1.74 seconds!\n",
"Evaluation on dev at Epoch 2/2. Step:312/312: \n",
"AccuracyMetric: acc=0.649819\n",
"\n",
"\n",
"In Epoch:2/Step:312, got best dev performance:\n",
"AccuracyMetric: acc=0.649819\n",
"Reloaded the best model.\n"
]
},
{
"data": {
"text/plain": [
"{'best_eval': {'AccuracyMetric': {'acc': 0.649819}},\n",
" 'best_epoch': 2,\n",
" 'best_step': 312,\n",
" 'seconds': 109.87}"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 训练模型\n",
"from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam\n",
"trainer = Trainer(data_bundle.get_dataset('train'), model, \n",
" optimizer=Adam(model_params=model.parameters(), lr=2e-5), \n",
" loss=CrossEntropyLoss(), device=[0],\n",
" batch_size=16, dev_data=data_bundle.get_dataset('dev'), \n",
" metrics=AccuracyMetric(), n_epochs=2, print_every=1)\n",
"trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

Loading…
Cancel
Save