diff --git a/tutorials/fastnlp_tutorial_3.ipynb b/tutorials/fastnlp_tutorial_3.ipynb index 7566c02a..ff8151b9 100644 --- a/tutorials/fastnlp_tutorial_3.ipynb +++ b/tutorials/fastnlp_tutorial_3.ipynb @@ -17,7 +17,7 @@ "\n", " 2.1 collator 的概念与使用\n", "\n", - " 2.2 sampler 的概念与使用" + " 2.2 结合 datasets 框架" ] }, { @@ -71,8 +71,8 @@ "| `batch_size` | √ | √ | 指定`dataloader`的`batch`大小 | 默认`16` |\n", "| `shuffle` | √ | √ | 指定`dataloader`的数据是否打乱 | 默认`False` |\n", "| `collate_fn` | √ | √ | 指定`dataloader`的`batch`打包方法 | 视框架而定 |\n", - "| `sampler` | √ | √ | ? | 默认`None` |\n", - "| `batch_sampler` | √ | √ | ? | 默认`None` |\n", + "| `sampler` | √ | √ | 指定`dataloader`的`__len__`和`__iter__`函数的实现 | 默认`None` |\n", + "| `batch_sampler` | √ | √ | 指定`dataloader`的`__len__`和`__iter__`函数的实现 | 默认`None` |\n", "| `drop_last` | √ | √ | 指定`dataloader`划分`batch`时是否丢弃剩余的 | 默认`False` |\n", "| `cur_batch_indices` | | √ | 记录`dataloader`当前遍历批量序号 | |\n", "| `num_workers` | √ | √ | 指定`dataloader`开启子进程数量 | 默认`0` |\n", @@ -95,7 +95,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 1, "id": "aca72b49", "metadata": { "pycharm": { @@ -103,6 +103,26 @@ } }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[38;5;2m[i 0604 15:44:29.773860 92 log.cc:351] Load log_sync: 1\u001b[m\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { "application/vnd.jupyter.widget-view+json": { @@ -149,14 +169,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "+------------+------------------+-----------+------------------+--------------------+--------------------+\n", - "| SentenceId | Sentence | Sentiment | input_ids | token_type_ids | attention_mask |\n", - "+------------+------------------+-----------+------------------+--------------------+--------------------+\n", - "| 5 | A comedy-dram... | positive | [101, 1037, 4... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n", - "| 2 | This quiet , ... | positive | [101, 2023, 4... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n", - "| 1 | A series of e... | negative | [101, 1037, 2... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n", - "| 6 | The Importanc... | neutral | [101, 1996, 5... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n", - "+------------+------------------+-----------+------------------+--------------------+--------------------+\n" + "+------------+----------------+-----------+----------------+--------------------+--------------------+--------+\n", + "| SentenceId | Sentence | Sentiment | input_ids | token_type_ids | attention_mask | target |\n", + "+------------+----------------+-----------+----------------+--------------------+--------------------+--------+\n", + "| 1 | A series of... | negative | [101, 1037,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 1 |\n", + "| 4 | A positivel... | neutral | [101, 1037,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 2 |\n", + "| 3 | Even fans o... | negative | [101, 2130,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 1 |\n", + "| 5 | A comedy-dr... | positive | [101, 1037,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 0 |\n", + "+------------+----------------+-----------+----------------+--------------------+--------------------+--------+\n" ] } ], @@ -200,7 +220,9 @@ " \n", "pipe = PipeDemo(tokenizer='bert-base-uncased')\n", "\n", - "data_bundle = pipe.process_from_file('./data/test4dataset.tsv')" + "data_bundle = pipe.process_from_file('./data/test4dataset.tsv')\n", + "\n", + "print(data_bundle.get_dataset('train'))" ] }, { @@ -214,15 +236,65 @@ "\n", " 例如下方的`prepare_torch_dataloader`函数,指定必要参数,读取数据集,生成对应`dataloader`\n", "\n", - " 类型为`TorchDataLoader`,只能适用于`pytorch`框架,因此对应`trainer`初始化时`driver='torch'`" + " 类型为`TorchDataLoader`,只能适用于`pytorch`框架,因此对应`trainer`初始化时`driver='torch'`\n", + "\n", + "同时我们看还可以发现,在`fastNLP 0.8`中,**`batch`表示为字典`dict`类型**,**`key`值就是原先数据集中各个字段**\n", + "\n", + " **除去经过`DataBundle.set_ignore`函数隐去的部分**,而`value`值为`pytorch`框架对应的`torch.Tensor`类型" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 2, "id": "5fd60e42", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/6000 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import sys\n", + "sys.path.append('..')\n", "\n", - "# databundle = SST2Pipe(tokenizer='raw').process_from_file()\n", + "from fastNLP import DataSet\n", "\n", - "# dataset = databundle.get_dataset('train')[:6000]\n", + "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n", "\n", - "# dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'target': ins['label']}, \n", - "# progress_bar=\"tqdm\")\n", - "# dataset.delete_field('sentence')\n", - "# dataset.delete_field('label')\n", - "# dataset.delete_field('idx')\n", + "dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'target': ins['label']}, \n", + " progress_bar=\"tqdm\")\n", + "dataset.delete_field('sentence')\n", + "dataset.delete_field('label')\n", + "dataset.delete_field('idx')\n", "\n", - "# from fastNLP import Vocabulary\n", + "from fastNLP import Vocabulary\n", "\n", - "# vocab = Vocabulary()\n", - "# vocab.from_dataset(dataset, field_name='words')\n", - "# vocab.index_dataset(dataset, field_name='words')\n", + "vocab = Vocabulary()\n", + "vocab.from_dataset(dataset, field_name='words')\n", + "vocab.index_dataset(dataset, field_name='words')\n", "\n", - "# train_dataset, evaluate_dataset = dataset.split(ratio=0.85)" + "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)" + ] + }, + { + "cell_type": "markdown", + "id": "96380c67", + "metadata": {}, + "source": [ + " 然后,使用`tutorial-3`中的知识,**通过`prepare_torch_dataloader`处理数据集得到`dataloader`**" ] }, { "cell_type": "code", - "execution_count": null, - "id": "50960476", + "execution_count": 3, + "id": "b9dd1273", "metadata": {}, "outputs": [], "source": [ - "# from fastNLP import prepare_torch_dataloader\n", + "from fastNLP import prepare_torch_dataloader\n", + "\n", + "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n", + "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)" + ] + }, + { + "cell_type": "markdown", + "id": "eb75aaba", + "metadata": {}, + "source": [ + "模型使用方面,这里使用`Embedding`、`LSTM`、`MLP`等模块搭建模型,方法类似`pytorch`,结构如下所示\n", "\n", - "# train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n", - "# evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)" + "```\n", + "ClsByModules(\n", + " (embedding): Embedding(\n", + " (embed): Embedding(8458, 100)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (lstm): LSTM(\n", + " (lstm): LSTM(100, 64, num_layers=2, batch_first=True, bidirectional=True)\n", + " )\n", + " (mlp): MLP(\n", + " (hiddens): ModuleList()\n", + " (output): Linear(in_features=128, out_features=2, bias=True)\n", + " (dropout): Dropout(p=0.5, inplace=False)\n", + " )\n", + " (loss_fn): CrossEntropyLoss()\n", + ")\n", + "```" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "0b25b25c", "metadata": {}, "outputs": [], "source": [ - "# import torch\n", - "# import torch.nn as nn\n", + "import torch\n", + "import torch.nn as nn\n", "\n", - "# from fastNLP.modules.torch import LSTM, MLP # 没有 MLP\n", - "# from fastNLP import Embedding, CrossEntropyLoss\n", + "from fastNLP.modules.torch import LSTM, MLP\n", + "from fastNLP.embeddings.torch import Embedding\n", "\n", "\n", - "# class ClsByModules(nn.Module):\n", - "# def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n", - "# nn.Module.__init__(self)\n", + "class ClsByModules(nn.Module):\n", + " def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n", + " nn.Module.__init__(self)\n", "\n", - "# self.embedding = Embedding((vocab_size, embedding_dim))\n", - "# self.lstm = LSTM(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=True)\n", - "# self.mlp = MLP([hidden_dim * 2, output_dim], dropout=dropout)\n", + " self.embedding = Embedding((vocab_size, embedding_dim))\n", + " self.lstm = LSTM(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=True)\n", + " self.mlp = MLP([hidden_dim * 2, output_dim], dropout=dropout)\n", " \n", - "# self.loss_fn = CrossEntropyLoss()\n", + " self.loss_fn = nn.CrossEntropyLoss()\n", "\n", - "# def forward(self, words):\n", - "# output = self.embedding(words)\n", - "# output, (hidden, cell) = self.lstm(output)\n", - "# output = self.mlp(torch.cat((hidden[-1], hidden[-2]), dim=1))\n", - "# return output\n", + " def forward(self, words):\n", + " output = self.embedding(words)\n", + " output, (hidden, cell) = self.lstm(output)\n", + " output = self.mlp(torch.cat((hidden[-1], hidden[-2]), dim=1))\n", + " return output\n", " \n", - "# def train_step(self, words, target):\n", - "# pred = self(words)\n", - "# return {\"loss\": self.loss_fn(pred, target)}\n", - "\n", - "# def evaluate_step(self, words, target):\n", - "# pred = self(words)\n", - "# pred = torch.max(pred, dim=-1)[1]\n", - "# return {\"pred\": pred, \"target\": target}" + " def train_step(self, words, target):\n", + " pred = self(words)\n", + " return {\"loss\": self.loss_fn(pred, target)}\n", + "\n", + " def evaluate_step(self, words, target):\n", + " pred = self(words)\n", + " pred = torch.max(pred, dim=-1)[1]\n", + " return {\"pred\": pred, \"target\": target}" + ] + }, + { + "cell_type": "markdown", + "id": "4890de5a", + "metadata": {}, + "source": [ + " 接着,初始化模型`model`实例,同时,使用`torch.optim.AdamW`初始化`optimizer`实例" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "9dbbf50d", "metadata": {}, "outputs": [], "source": [ - "# model = ClsByModules(vocab_size=len(vocabulary), embedding_dim=100, output_dim=2)\n", + "model = ClsByModules(vocab_size=len(vocab), embedding_dim=100, output_dim=2)\n", "\n", - "# from torch.optim import AdamW\n", + "from torch.optim import AdamW\n", "\n", - "# optimizers = AdamW(params=model.parameters(), lr=5e-5)" + "optimizers = AdamW(params=model.parameters(), lr=5e-5)" + ] + }, + { + "cell_type": "markdown", + "id": "054538f5", + "metadata": {}, + "source": [ + " 最后,使用`trainer`模块,集成`model`、`optimizer`、`dataloader`、`metric`训练" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "7a93432f", "metadata": {}, "outputs": [], "source": [ - "# from fastNLP import Trainer, Accuracy\n", + "from fastNLP import Trainer, Accuracy\n", "\n", - "# trainer = Trainer(\n", - "# model=model,\n", - "# driver='torch',\n", - "# device=0, # 'cuda'\n", - "# n_epochs=10,\n", - "# optimizers=optimizers,\n", - "# train_dataloader=train_dataloader,\n", - "# evaluate_dataloaders=evaluate_dataloader,\n", - "# metrics={'acc': Accuracy()}\n", - "# )" + "trainer = Trainer(\n", + " model=model,\n", + " driver='torch',\n", + " device=0, # 'cuda'\n", + " n_epochs=10,\n", + " optimizers=optimizers,\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=evaluate_dataloader,\n", + " metrics={'acc': Accuracy()}\n", + ")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "31102e0f", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
[16:20:10] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[16:20:10]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=908530;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=864197;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n", + "\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n", + "\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.525,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 84.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.525\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m84.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.54375,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 87.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.54375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m87.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.55,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 88.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.55\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m88.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.625,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 100.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.65,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 104.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.65\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m104.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.69375,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 111.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.69375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m111.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.675,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 108.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.675\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m108.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.66875,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 107.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.66875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m107.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.675,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 108.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.675\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m108.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.68125,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 109.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.68125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m109.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "# trainer.run(num_eval_batch_per_dl=10)" + "trainer.run(num_eval_batch_per_dl=10)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "id": "8bc4bfb2", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'acc#acc': 0.712222, 'total#acc': 900.0, 'correct#acc': 641.0}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# trainer.evaluator.run()" + "trainer.evaluator.run()" + ] + }, + { + "cell_type": "markdown", + "id": "07538876", + "metadata": {}, + "source": [ + " 注:此处使用`gc`模块删除相关变量,释放内存,为接下来新的模型训练预留存储空间,下同" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "1b52eafd", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "383" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import gc\n", + "\n", + "del model\n", + "del trainer\n", + "del dataset\n", + "del sst2data\n", + "\n", + "gc.collect()" ] }, { @@ -255,6 +1050,8 @@ "\n", " 本示例使用`fastNLP 0.8`中预定义模型`models`中的`CNNText`模型,实现`SST-2`文本二分类任务\n", "\n", + "数据使用方面,此处沿用在上个示例中展示的`SST-2`数据集,数据加载过程相同且已经执行过了,因此简略\n", + "\n", "模型使用方面,如上所述,这里使用**基于卷积神经网络`CNN`的预定义文本分类模型`CNNText`**,结构如下所示\n", "\n", " 首先是内置的`100`维嵌入层、`dropout`层、紧接着是三个一维卷积,将`100`维嵌入特征,分别通过\n", @@ -281,95 +1078,7 @@ ")\n", "```\n", "\n", - "数据使用方面,此处**使用`datasets`模块中的`load_dataset`函数**,以如下形式,指定`SST-2`数据集自动加载\n", - "\n", - " 首次下载后会保存至`~/.cache/huggingface/modules/datasets_modules/datasets/glue/`目录下" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1aa5cf6d", - "metadata": {}, - "outputs": [], - "source": [ - "from datasets import load_dataset\n", - "\n", - "sst2data = load_dataset('glue', 'sst2')" - ] - }, - { - "cell_type": "markdown", - "id": "c476abe7", - "metadata": {}, - "source": [ - "紧接着,使用`tutorial-1`和`tutorial-2`中的知识,将数据集转化为`fastNLP`中的`DataSet`格式\n", - "\n", - " **使用`apply_more`函数、`Vocabulary`模块的`from_/index_dataset`函数预处理数据**\n", - "\n", - " 并结合`delete_field`函数删除字段调整格式,`split`函数划分测试集和验证集\n", - "\n", - " **仅保留`'words'`字段表示输入文本单词序号序列、`'target'`字段表示文本对应预测输出结果**\n", - "\n", - " 两者**对应到`CNNText`中`train_step`函数和`evaluate_step`函数的签名/输入参数**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "357ea748", - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "sys.path.append('..')\n", - "\n", - "from fastNLP import DataSet\n", - "\n", - "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n", - "\n", - "dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'target': ins['label']}, \n", - " progress_bar=\"tqdm\")\n", - "dataset.delete_field('sentence')\n", - "dataset.delete_field('label')\n", - "dataset.delete_field('idx')\n", - "\n", - "from fastNLP import Vocabulary\n", - "\n", - "vocab = Vocabulary()\n", - "vocab.from_dataset(dataset, field_name='words')\n", - "vocab.index_dataset(dataset, field_name='words')\n", - "\n", - "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)" - ] - }, - { - "cell_type": "markdown", - "id": "96380c67", - "metadata": {}, - "source": [ - "然后,使用`tutorial-3`中的知识,**通过`prepare_torch_dataloader`处理数据集得到`dataloader`**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b9dd1273", - "metadata": {}, - "outputs": [], - "source": [ - "from fastNLP import prepare_torch_dataloader\n", - "\n", - "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n", - "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)" - ] - }, - { - "cell_type": "markdown", - "id": "96941b63", - "metadata": {}, - "source": [ - "接着,**从`fastNLP.models.torch`路径下导入`CNNText`**,初始化`CNNText`实例以及`optimizer`实例\n", + "对应到代码上,**从`fastNLP.models.torch`路径下导入`CNNText`**,初始化`CNNText`和`optimizer`实例\n", "\n", " 注意:初始化`CNNText`时,**二元组参数`embed`、分类数量`num_classes`是必须传入的**,其中\n", "\n", @@ -378,7 +1087,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "id": "f6e76e2e", "metadata": {}, "outputs": [], @@ -397,12 +1106,12 @@ "id": "0cc5ca10", "metadata": {}, "source": [ - "最后,使用`trainer`模块,集成`model`、`optimizer`、`dataloader`、`metric`训练" + " 最后,使用`trainer`模块,集成`model`、`optimizer`、`dataloader`、`metric`训练" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "id": "50a13ee5", "metadata": {}, "outputs": [], @@ -423,45 +1132,634 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "id": "28903a7d", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
[16:21:57] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[16:21:57]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=813103;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=271516;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.654444,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 589.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.654444\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m589.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.767778,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 691.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.767778\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m691.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.797778,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 718.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.797778\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m718.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.803333,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 723.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.803333\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m723.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.807778,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 727.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.807778\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m727.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.812222,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 731.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.812222\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m731.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.804444,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 724.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.804444\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m724.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.811111,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 730.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.811111\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m730.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.811111,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 730.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.811111\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m730.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.806667,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 726.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.806667\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m726.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "trainer.run()" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "id": "f47a6a35", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'acc#acc': 0.806667, 'total#acc': 900.0, 'correct#acc': 726.0}" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "trainer.evaluator.run()" ] }, { "cell_type": "markdown", - "id": "7c811257", + "id": "5b5c0446", "metadata": {}, "source": [ - " 注:此处使用`gc`模块删除相关变量,释放内存,为接下来新的模型训练预留存储空间" + " 注:此处使用`gc`模块删除相关变量,释放内存,为接下来新的模型训练预留存储空间,下同" ] }, { "cell_type": "code", - "execution_count": null, - "id": "c1a2e2ca", + "execution_count": 14, + "id": "e9e70f88", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "344" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "import gc\n", "\n", "del model\n", "del trainer\n", - "del dataset\n", - "del sst2data\n", "\n", "gc.collect()" ] @@ -502,10 +1800,32 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "id": "03e66686", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset conll2003 (/remote-home/xrliu/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/63f4ebd1bcb7148b1644497336fd74643d4ce70123334431a3c053b7ee4e96ee)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "593bc03ed5914953ab94268ff2f01710", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "from datasets import load_dataset\n", "\n", @@ -532,10 +1852,25 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "id": "1f88cad4", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/4000 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "import sys\n", "sys.path.append('..')\n", @@ -573,7 +1908,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "id": "7802a072", "metadata": {}, "outputs": [], @@ -598,7 +1933,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "id": "4e12c09f", "metadata": {}, "outputs": [], @@ -627,7 +1962,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "id": "cbd6c205", "metadata": {}, "outputs": [], @@ -648,20 +1983,600 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, "id": "0f8eff34", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
[16:23:41] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[16:23:41]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=565652;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=224849;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.169014,\n", + " \"pre#F1\": 0.170732,\n", + " \"rec#F1\": 0.167331\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.169014\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.170732\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.167331\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.361809,\n", + " \"pre#F1\": 0.312139,\n", + " \"rec#F1\": 0.430279\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.361809\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.312139\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.430279\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.525,\n", + " \"pre#F1\": 0.475728,\n", + " \"rec#F1\": 0.585657\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.525\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.475728\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.585657\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.627306,\n", + " \"pre#F1\": 0.584192,\n", + " \"rec#F1\": 0.677291\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.627306\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.584192\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.677291\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.710937,\n", + " \"pre#F1\": 0.697318,\n", + " \"rec#F1\": 0.7251\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.710937\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.697318\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.7251\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.739563,\n", + " \"pre#F1\": 0.738095,\n", + " \"rec#F1\": 0.741036\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.739563\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.738095\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.741036\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.748491,\n", + " \"pre#F1\": 0.756098,\n", + " \"rec#F1\": 0.741036\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.748491\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.756098\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.741036\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.716763,\n", + " \"pre#F1\": 0.69403,\n", + " \"rec#F1\": 0.741036\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.716763\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.69403\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.741036\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.768293,\n", + " \"pre#F1\": 0.784232,\n", + " \"rec#F1\": 0.752988\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.768293\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.784232\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.752988\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.757692,\n", + " \"pre#F1\": 0.732342,\n", + " \"rec#F1\": 0.784861\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.757692\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.732342\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.784861\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "trainer.run(num_eval_batch_per_dl=10)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "id": "37871d6b", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'f#F1': 0.766798, 'pre#F1': 0.741874, 'rec#F1': 0.793456}" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "trainer.evaluator.run()" ] diff --git a/tutorials/fastnlp_tutorial_5.ipynb b/tutorials/fastnlp_tutorial_5.ipynb index 63410113..3f2bbfa6 100644 --- a/tutorials/fastnlp_tutorial_5.ipynb +++ b/tutorials/fastnlp_tutorial_5.ipynb @@ -296,7 +296,7 @@ "\n", " 在`fastNLP v0.8`中,`update`函数的默认输入参数:`pred`,对应预测值;`target`,对应真实值\n", "\n", - " 此处刻意调整为:`pred`,对应预测值,和模型输出一致;`true`,对应真实值,数据集字段需要调整\n", + " 此处仍然沿用,因为接下来会需要使用`fastNLP`函数的与定义模型,其输入参数格式即使如此\n", "\n", " 在`get_metric`函数中,根据需求定义评测指标最终的计算,此处直接计算准确率,该函数必须返回一个字典\n", "\n", @@ -307,10 +307,24 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "08a872e9", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "import sys\n", "sys.path.append('..')\n", @@ -320,16 +334,16 @@ "class MyMetric(Metric):\n", "\n", " def __init__(self):\n", - " MyMetric.__init__(self)\n", + " Metric.__init__(self)\n", " self.total_num = 0\n", " self.right_num = 0\n", "\n", - " def update(self, pred, true):\n", + " def update(self, pred, target):\n", " self.total_num += target.size(0)\n", " self.right_num += target.eq(pred).sum().item()\n", "\n", " def get_metric(self, reset=True):\n", - " acc = self.acc_count / self.total_num\n", + " acc = self.right_num / self.total_num\n", " if reset:\n", " self.total_num = 0\n", " self.right_num = 0\n", @@ -346,14 +360,36 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "5ad81ac7", "metadata": { "pycharm": { "name": "#%%\n" } }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ef923b90b19847f4916cccda5d33fc36", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "from datasets import load_dataset\n", "\n", @@ -365,30 +401,43 @@ "id": "e9d81760", "metadata": {}, "source": [ - " 在数据预处理中,需要注意的是,由于`MyMetric`的`update`函数中,输入参数名称为`pred`和`true`\n", + " 在数据预处理中,需要注意的是,这里原本应该根据`metric`和`model`的输入参数格式,调整\n", "\n", - " 对应地,需要将数据集中表示预测目标的字段,调整为`true`(预定义的`metric`,应调整为`target`" + " 数据集中表示预测目标的字段,调整为`target`,在后文中会揭晓为什么,以及如何补救" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "cfb28b1b", "metadata": { "pycharm": { "name": "#%%\n" } }, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/6000 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "from fastNLP import DataSet\n", "\n", "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n", "\n", - "dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'true': ins['label']}, \n", - " progress_bar=\"tqdm\")\n", + "dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split()}, progress_bar=\"tqdm\")\n", "dataset.delete_field('sentence')\n", - "dataset.delete_field('label')\n", "dataset.delete_field('idx')\n", "\n", "from fastNLP import Vocabulary\n", @@ -415,7 +464,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "2fd210c5", "metadata": {}, "outputs": [], @@ -445,10 +494,10 @@ "|
[09:30:35] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[09:30:35]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=954293;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=366534;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n", + "\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n", + "\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.6875\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.6875\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.8125\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.80625\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.825\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.825\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.8125\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.80625\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.80625\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.8\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.80625\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.80625\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "trainer.run(num_eval_batch_per_dl=10)" ] diff --git a/tutorials/fastnlp_tutorial_6.ipynb b/tutorials/fastnlp_tutorial_6.ipynb index 552f73d9..63f7481e 100644 --- a/tutorials/fastnlp_tutorial_6.ipynb +++ b/tutorials/fastnlp_tutorial_6.ipynb @@ -19,15 +19,37 @@ "\n", " 2.2 使用 jittor 搭建并训练模型\n", "\n", - " 3 fastNLP 实现 paddle 与 pytorch 互转" + "" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "08752c5a", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6b13d42c39ba455eb370bf2caaa3a264", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "from datasets import load_dataset\n", "\n", @@ -36,10 +58,52 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "7e8cc210", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[38;5;2m[i 0604 21:01:38.510813 72 log.cc:351] Load log_sync: 1\u001b[m\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/6000 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "
[21:03:08] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[21:03:08]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=894986;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=567751;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n", + "\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n", + "\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/paddle/tensor/creation.py:\n", + "125: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To \n", + "silence this warning, use `object` by itself. Doing this will not modify any behavior and is \n", + "safe. \n", + "Deprecated in NumPy 1.20; for more details and guidance: \n", + "https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", + " if data.dtype == np.object:\n", + "\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/paddle/tensor/creation.py:\n", + "125: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To \n", + "silence this warning, use `object` by itself. Doing this will not modify any behavior and is \n", + "safe. \n", + "Deprecated in NumPy 1.20; for more details and guidance: \n", + "https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", + " if data.dtype == np.object:\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.78125,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 125.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.78125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m125.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.7875,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 126.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m126.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.8,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 128.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.79375,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 127.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.79375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m127.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.81875,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 131.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.8,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 128.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.80625,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 129.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m129.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.79375,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 127.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.79375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m127.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.7875,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 126.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m126.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.8,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 128.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "trainer.run(num_eval_batch_per_dl=10) # 然后卡了?" + "trainer.run(num_eval_batch_per_dl=10) " ] }, { @@ -214,7 +922,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "id": "c600191d", "metadata": {}, "outputs": [], @@ -231,37 +939,61 @@ " self.hidden_dim = hidden_dim\n", "\n", " self.embedding = nn.Embedding(num=vocab_size, dim=embedding_dim)\n", - " self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, \n", + " self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True, # 默认 batch_first=False\n", " num_layers=num_layers, bidirectional=True, dropout=dropout)\n", - " self.mlp = nn.Sequential([nn.Linear(hidden_dim * 2, hidden_dim * 2),\n", + " self.mlp = nn.Sequential([nn.Dropout(p=dropout),\n", + " nn.Linear(hidden_dim * 2, hidden_dim * 2),\n", " nn.ReLU(),\n", - " nn.Linear(hidden_dim * 2, output_dim)])\n", + " nn.Linear(hidden_dim * 2, output_dim),\n", + " nn.Sigmoid(),])\n", "\n", - " self.loss_fn = nn.BCELoss()\n", + " self.loss_fn = nn.MSELoss()\n", "\n", " def execute(self, words):\n", " output = self.embedding(words)\n", " output, (hidden, cell) = self.lstm(output)\n", - " # hidden = jittor.randn((2, words.shape[0], self.hidden_dim))\n", - " output = self.mlp(jittor.concat((hidden[-1], hidden[-2]), axis=1))\n", + " output = self.mlp(jittor.concat((hidden[-1], hidden[-2]), dim=1))\n", " return output\n", " \n", " def train_step(self, words, target):\n", " pred = self(words)\n", - " return {\"loss\": self.loss_fn(pred, target)}\n", + " target = jittor.stack((1 - target, target), dim=1)\n", + " return {'loss': self.loss_fn(pred, target)}\n", "\n", " def evaluate_step(self, words, target):\n", " pred = self(words)\n", - " pred = jittor.max(pred, axis=-1)[1]\n", - " return {\"pred\": pred, \"target\": target}" + " pred = jittor.argmax(pred, dim=-1)[0]\n", + " return {'pred': pred, 'target': target}" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "id": "a94ed8c4", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "ClsByJittor(\n", + " embedding: Embedding(8458, 100)\n", + " lstm: LSTM(100, 64, 2, bias=True, batch_first=True, dropout=0.5, bidirectional=True, proj_size=0)\n", + " mlp: Sequential(\n", + " 0: Dropout(0.5, is_train=False)\n", + " 1: Linear(128, 128, float32[128,], None)\n", + " 2: relu()\n", + " 3: Linear(128, 2, float32[2,], None)\n", + " 4: Sigmoid()\n", + " )\n", + " loss_fn: MSELoss(mean)\n", + ")" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "model = ClsByJittor(vocab_size=len(vocab), embedding_dim=100, output_dim=2)\n", "\n", @@ -270,34 +1002,34 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "id": "6d15ebc1", "metadata": {}, "outputs": [], "source": [ "from jittor.optim import AdamW\n", "\n", - "optimizers = AdamW(params=model.parameters(), lr=1e-2)" + "optimizers = AdamW(params=model.parameters(), lr=5e-3)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "id": "95d8d09e", "metadata": {}, "outputs": [], "source": [ "from fastNLP import prepare_jittor_dataloader\n", "\n", - "# train_dataloader = prepare_jittor_dataloader(train_dataset, batch_size=16, shuffle=True)\n", - "# evaluate_dataloader = prepare_jittor_dataloader(evaluate_dataset, batch_size=16)\n", + "train_dataloader = prepare_jittor_dataloader(train_dataset, batch_size=16, shuffle=True)\n", + "evaluate_dataloader = prepare_jittor_dataloader(evaluate_dataset, batch_size=16)\n", "\n", - "dl_bundle = prepare_jittor_dataloader(data_bundle, batch_size=16, shuffle=True)" + "# dl_bundle = prepare_jittor_dataloader(data_bundle, batch_size=16, shuffle=True)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "id": "917eab81", "metadata": {}, "outputs": [], @@ -307,24 +1039,587 @@ "trainer = Trainer(\n", " model=model,\n", " driver='jittor',\n", - " device='gpu', # 'cpu', 'gpu', 'cuda'\n", + " device='gpu', # 'cpu', 'gpu', 'cuda'\n", " n_epochs=10,\n", " optimizers=optimizers,\n", - " train_dataloader=dl_bundle['train'], # train_dataloader,\n", - " evaluate_dataloaders=dl_bundle['dev'], # evaluate_dataloader,\n", + " train_dataloader=train_dataloader, # dl_bundle['train'],\n", + " evaluate_dataloaders=evaluate_dataloader, # dl_bundle['dev'],\n", " metrics={'acc': Accuracy()}\n", ")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "id": "f7c4ac5a", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
[21:05:51] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[21:05:51]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=69759;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=202322;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "Compiling Operators(5/6) used: 8.31s eta: 1.66s 6/6) used: 9.33s eta: 0s \n", + "\n", + "Compiling Operators(31/31) used: 7.31s eta: 0s \n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.61875,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 99\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.61875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m99\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.7,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 112\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m112\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.725,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 116\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.725\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m116\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.74375,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 119\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.74375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m119\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.75625,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 121\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.75625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m121\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.75625,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 121\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.75625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m121\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.73125,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 117\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.73125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m117\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.7625,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 122\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m122\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.74375,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 119\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.74375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m119\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.7625,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 122\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m122\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "trainer.run(num_eval_batch_per_dl=10)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3df5f425", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/tutorials/fastnlp_tutorial_7.ipynb b/tutorials/fastnlp_tutorial_7.ipynb deleted file mode 100644 index 0a7d6922..00000000 --- a/tutorials/fastnlp_tutorial_7.ipynb +++ /dev/null @@ -1,59 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "fdd7ff16", - "metadata": {}, - "source": [ - "# T7. callback 自定义训练过程\n", - "\n", - " 1 \n", - " \n", - " 1.1 \n", - "\n", - " 1.2 \n", - "\n", - " 2 \n", - "\n", - " 2.1 \n", - "\n", - " 2.2 \n", - "\n", - " 3 \n", - "\n", - " 3.1 \n", - "\n", - " 3.2 " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "08752c5a", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.13" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/tutorials/fastnlp_tutorial_8.ipynb b/tutorials/fastnlp_tutorial_8.ipynb deleted file mode 100644 index 0664bc41..00000000 --- a/tutorials/fastnlp_tutorial_8.ipynb +++ /dev/null @@ -1,59 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "fdd7ff16", - "metadata": {}, - "source": [ - "# T8. fastNLP 中的文件读取模块\n", - "\n", - " 1 fastNLP 中的 EmbedLoader 模块\n", - " \n", - " 1.1 \n", - "\n", - " 1.2 \n", - "\n", - " 2 fastNLP 中的 Loader 模块\n", - "\n", - " 2.1 \n", - "\n", - " 2.2 \n", - "\n", - " 3 fastNLP 中的 Pipe 模块\n", - "\n", - " 3.1 \n", - "\n", - " 3.2 " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "08752c5a", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.13" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}