{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 使用Modules和Models快速搭建自定义模型\n", "\n", "modules 和 models 用于构建 fastNLP 所需的神经网络模型,它可以和 torch.nn 中的模型一起使用。 下面我们会分三节介绍编写构建模型的具体方法。\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "我们首先准备好和上篇教程一样的基础实验代码" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from fastNLP import AccuracyMetric, CrossEntropyLoss, Trainer\n", "from fastNLP.io import SST2Pipe\n", "\n", "# Preprocess SST-2 with fastNLP's pipe and fetch the vocabulary of the 'words' field.\n", "databundle = SST2Pipe().process_from_file()\n", "vocab = databundle.get_vocab('words')\n", "\n", "# Use only the first 5000 training instances, then split them with ratio 0.015.\n", "train_data, test_data = databundle.get_dataset('train')[:5000].split(0.015)\n", "dev_data = databundle.get_dataset('dev')\n", "\n", "# Loss, metric and device shared by every Trainer in this notebook.\n", "loss, metric = CrossEntropyLoss(), AccuracyMetric()\n", "device = 0 if torch.cuda.is_available() else 'cpu'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 使用 models 中的模型\n", "\n", "fastNLP 在 models 模块中内置了如 CNNText 、 SeqLabeling 等完整的模型,以供用户直接使用。 以文本分类的任务为例,我们从 models 中导入 CNNText 模型,用它进行训练。" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "input fields after batch(if batch size is 2):\n", "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 41]) \n", "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", "target fields after batch(if batch size is 2):\n", "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", "\n", "training epochs started 2020-02-28-00-56-04\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1540.0), HTML(value='')), layout=Layout(d…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", 
"version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Evaluate data in 0.22 seconds!\n", "\r", "Evaluation on dev at Epoch 1/10. Step:154/1540: \n", "\r", "AccuracyMetric: acc=0.760321\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Evaluate data in 0.29 seconds!\n", "\r", "Evaluation on dev at Epoch 2/10. Step:308/1540: \n", "\r", "AccuracyMetric: acc=0.727064\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Evaluate data in 0.48 seconds!\n", "\r", "Evaluation on dev at Epoch 3/10. Step:462/1540: \n", "\r", "AccuracyMetric: acc=0.758028\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Evaluate data in 0.24 seconds!\n", "\r", "Evaluation on dev at Epoch 4/10. 
Step:616/1540: \n", "\r", "AccuracyMetric: acc=0.759174\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Evaluate data in 0.47 seconds!\n", "\r", "Evaluation on dev at Epoch 5/10. Step:770/1540: \n", "\r", "AccuracyMetric: acc=0.743119\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Evaluate data in 0.22 seconds!\n", "\r", "Evaluation on dev at Epoch 6/10. Step:924/1540: \n", "\r", "AccuracyMetric: acc=0.756881\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Evaluate data in 0.21 seconds!\n", "\r", "Evaluation on dev at Epoch 7/10. 
Step:1078/1540: \n", "\r", "AccuracyMetric: acc=0.752294\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Evaluate data in 0.21 seconds!\n", "\r", "Evaluation on dev at Epoch 8/10. Step:1232/1540: \n", "\r", "AccuracyMetric: acc=0.756881\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Evaluate data in 0.15 seconds!\n", "\r", "Evaluation on dev at Epoch 9/10. Step:1386/1540: \n", "\r", "AccuracyMetric: acc=0.75344\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Evaluate data in 0.12 seconds!\n", "\r", "Evaluation on dev at Epoch 10/10. 
Step:1540/1540: \n", "\r", "AccuracyMetric: acc=0.752294\n", "\n", "\r\n", "In Epoch:1/Step:154, got best dev performance:\n", "AccuracyMetric: acc=0.760321\n", "Reloaded the best model.\n" ] }, { "data": { "text/plain": [ "{'best_eval': {'AccuracyMetric': {'acc': 0.760321}},\n", " 'best_epoch': 1,\n", " 'best_step': 154,\n", " 'seconds': 29.3}" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from fastNLP.models import CNNText\n", "\n", "model_cnn = CNNText((len(vocab),100), num_classes=2, dropout=0.1)\n", "\n", "trainer = Trainer(train_data=train_data, dev_data=dev_data, metrics=metric,\n", " loss=loss, device=device, model=model_cnn)\n", "trainer.train()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "在 iPython 环境输入 model_cnn ,我们可以看到 model_cnn 的网络结构" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "CNNText(\n", " (embed): Embedding(\n", " (embed): Embedding(16292, 100)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " )\n", " (conv_pool): ConvMaxpool(\n", " (convs): ModuleList(\n", " (0): Conv1d(100, 30, kernel_size=(1,), stride=(1,), bias=False)\n", " (1): Conv1d(100, 40, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n", " (2): Conv1d(100, 50, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", " )\n", " )\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (fc): Linear(in_features=120, out_features=2, bias=True)\n", ")" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_cnn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 使用 torch.nn 编写模型\n", "\n", "fastNLP 完全支持使用 PyTorch 编写的模型,但与 PyTorch 中编写模型的常见方法不同, 用于 fastNLP 的模型中 forward 函数需要返回一个字典,字典中至少需要包含 pred 这个字段。\n", "\n", "下面是使用 PyTorch 中的 torch.nn 模块编写的文本分类模型,注意观察代码中标注的向量维度。 由于 PyTorch 使用了约定俗成的维度设置,使得 forward 中需要多次处理维度顺序" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ 
"import torch\n", "import torch.nn as nn\n", "\n", "class LSTMText(nn.Module):\n", "    \"\"\"Bidirectional LSTM text classifier written with plain torch.nn layers.\n", "\n", "    forward() takes `words` of shape (batch_size, seq_len) and returns a dict\n", "    whose \"pred\" entry has shape (batch_size, output_dim).\n", "    \"\"\"\n", "\n", "    def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n", "        super().__init__()\n", "\n", "        self.embedding = nn.Embedding(vocab_size, embedding_dim)\n", "        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=True, dropout=dropout)\n", "        self.fc = nn.Linear(hidden_dim * 2, output_dim)\n", "        self.dropout = nn.Dropout(dropout)\n", "\n", "    def forward(self, words):\n", "        # (input) words : (batch_size, seq_len)\n", "        words = words.permute(1,0)\n", "        # words : (seq_len, batch_size)\n", "\n", "        embedded = self.dropout(self.embedding(words))\n", "        # embedded : (seq_len, batch_size, embedding_dim)\n", "        _, (hidden, _) = self.lstm(embedded)\n", "        # hidden: (num_layers * 2, batch_size, hidden_dim)\n", "\n", "        # join the last two entries of hidden along the feature axis\n", "        hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)\n", "        hidden = self.dropout(hidden)\n", "        # hidden: (batch_size, hidden_dim * 2)\n", "\n", "        # hidden is already 2-D; squeezing dim 0 here would drop the batch axis when batch_size == 1\n", "        pred = self.fc(hidden)\n", "        # pred: (batch_size, output_dim)\n", "        return {\"pred\":pred}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "我们同样可以在 iPython 环境中查看这个模型的网络结构" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "LSTMText(\n", " (embedding): Embedding(16292, 100)\n", " (lstm): LSTM(100, 64, num_layers=2, dropout=0.5, bidirectional=True)\n", " (fc): Linear(in_features=128, out_features=2, bias=True)\n", " (dropout): Dropout(p=0.5, inplace=False)\n", ")" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_lstm = LSTMText(len(vocab), 100, 2)\n", "model_lstm " ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ 
"input fields after batch(if batch size is 2):\n", "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 41]) \n", "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", "target fields after batch(if batch size is 2):\n", "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", "\n", "training epochs started 2020-02-28-00-56-34\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1540.0), HTML(value='')), layout=Layout(d…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Evaluate data in 0.36 seconds!\n", "\r", "Evaluation on dev at Epoch 1/10. Step:154/1540: \n", "\r", "AccuracyMetric: acc=0.59289\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Evaluate data in 0.35 seconds!\n", "\r", "Evaluation on dev at Epoch 2/10. 
Step:308/1540: \n", "\r", "AccuracyMetric: acc=0.674312\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Evaluate data in 0.21 seconds!\n", "\r", "Evaluation on dev at Epoch 3/10. Step:462/1540: \n", "\r", "AccuracyMetric: acc=0.724771\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Evaluate data in 0.4 seconds!\n", "\r", "Evaluation on dev at Epoch 4/10. Step:616/1540: \n", "\r", "AccuracyMetric: acc=0.748853\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Evaluate data in 0.24 seconds!\n", "\r", "Evaluation on dev at Epoch 5/10. 
Step:770/1540: \n", "\r", "AccuracyMetric: acc=0.756881\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Evaluate data in 0.29 seconds!\n", "\r", "Evaluation on dev at Epoch 6/10. Step:924/1540: \n", "\r", "AccuracyMetric: acc=0.741972\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Evaluate data in 0.32 seconds!\n", "\r", "Evaluation on dev at Epoch 7/10. Step:1078/1540: \n", "\r", "AccuracyMetric: acc=0.754587\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Evaluate data in 0.24 seconds!\n", "\r", "Evaluation on dev at Epoch 8/10. 
Step:1232/1540: \n", "\r", "AccuracyMetric: acc=0.756881\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Evaluate data in 0.28 seconds!\n", "\r", "Evaluation on dev at Epoch 9/10. Step:1386/1540: \n", "\r", "AccuracyMetric: acc=0.740826\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Evaluate data in 0.23 seconds!\n", "\r", "Evaluation on dev at Epoch 10/10. 
Step:1540/1540: \n", "\r", "AccuracyMetric: acc=0.751147\n", "\n", "\r\n", "In Epoch:5/Step:770, got best dev performance:\n", "AccuracyMetric: acc=0.756881\n", "Reloaded the best model.\n" ] }, { "data": { "text/plain": [ "{'best_eval': {'AccuracyMetric': {'acc': 0.756881}},\n", " 'best_epoch': 5,\n", " 'best_step': 770,\n", " 'seconds': 45.69}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer = Trainer(train_data=train_data, dev_data=dev_data, metrics=metric,\n", " loss=loss, device=device, model=model_lstm)\n", "trainer.train()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 使用 modules 编写模型\n", "\n", "下面我们使用 fastNLP.modules 中的组件来构建同样的网络。由于 fastNLP 统一把 batch_size 放在第一维, 在编写代码的过程中会有一定的便利。" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "MyText(\n", " (embedding): Embedding(\n", " (embed): Embedding(16292, 100)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " )\n", " (lstm): LSTM(\n", " (lstm): LSTM(100, 64, num_layers=2, batch_first=True, bidirectional=True)\n", " )\n", " (mlp): MLP(\n", " (hiddens): ModuleList()\n", " (output): Linear(in_features=128, out_features=2, bias=True)\n", " (dropout): Dropout(p=0.5, inplace=False)\n", " )\n", ")" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from fastNLP.modules import LSTM, MLP\n", "from fastNLP.embeddings import Embedding\n", "\n", "\n", "class MyText(nn.Module):\n", "    \"\"\"Bidirectional LSTM text classifier assembled from fastNLP components.\n", "\n", "    forward() returns a dict holding the classifier output under the key \"pred\".\n", "    \"\"\"\n", "\n", "    def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n", "        super().__init__()\n", "\n", "        self.embedding = Embedding((vocab_size, embedding_dim))\n", "        self.lstm = LSTM(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=True)\n", "        self.mlp = MLP([hidden_dim*2,output_dim], dropout=dropout)\n", "\n", "    def forward(self, words):\n", "        # words: (batch_size, seq_len); the batch axis stays first, so no permute is needed\n", "        embedded = self.embedding(words)\n", "        _,(hidden,_) = self.lstm(embedded)\n", 
"        # concatenate in the same order as LSTMText: hidden[-2] first, then hidden[-1]\n", "        pred = self.mlp(torch.cat((hidden[-2],hidden[-1]),dim=1))\n", "        return {\"pred\":pred}\n", "\n", "model_text = MyText(len(vocab), 100, 2)\n", "model_text" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "input fields after batch(if batch size is 2):\n", "\twords: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 41]) \n", "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", "target fields after batch(if batch size is 2):\n", "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n", "\n", "training epochs started 2020-02-28-00-57-19\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "16a35f2b0ef0457dae15c5f240a19a3a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1540.0), HTML(value='')), layout=Layout(d…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Evaluate data in 0.38 seconds!\n", "\r", "Evaluation on dev at Epoch 1/10. Step:154/1540: \n", "\r", "AccuracyMetric: acc=0.767202\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=28.0), HTML(value='')), layout=Layout(dis…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Evaluate data in 0.22 seconds!\n", "\r", "Evaluation on dev at Epoch 2/10. 
Step:308/1540: \n", "\r", "AccuracyMetric: acc=0.743119\n", "\n" ] } ], "source": [ "# Train the modules-based MyText instance (model_text), not the earlier model_lstm.\n", "trainer = Trainer(train_data=train_data, dev_data=dev_data, metrics=metric,\n", "                  loss=loss, device=device, model=model_text)\n", "trainer.train()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python Now", "language": "python", "name": "now" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.0" } }, "nbformat": 4, "nbformat_minor": 2 }