diff --git a/docs/source/tutorials/fastnlp_tutorial_0.ipynb b/docs/source/tutorials/fastnlp_tutorial_0.ipynb
new file mode 100644
index 00000000..09667794
--- /dev/null
+++ b/docs/source/tutorials/fastnlp_tutorial_0.ipynb
@@ -0,0 +1,1352 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "aec0fde7",
+   "metadata": {},
+   "source": [
+    "# T0. trainer 和 evaluator 的基本使用\n",
+    "\n",
+    "  1   trainer 和 evaluator 的基本关系\n",
+    " \n",
+    "    1.1   trainer 和 evaluator 的初始化\n",
+    "\n",
+    "    1.2   driver 的含义与使用要求\n",
+    "\n",
+    "    1.3   trainer 内部初始化 evaluator\n",
+    "\n",
+    "  2   使用 fastNLP 搭建 argmax 模型\n",
+    "\n",
+    "    2.1   train_step 和 evaluate_step\n",
+    "\n",
+    "    2.2   trainer 和 evaluator 的参数匹配\n",
+    "\n",
+    "    2.3   示例:argmax 模型的搭建\n",
+    "\n",
+    "  3   使用 fastNLP 训练 argmax 模型\n",
+    " \n",
+    "    3.1   trainer 外部初始化的 evaluator\n",
+    "\n",
+    "    3.2   trainer 内部初始化的 evaluator "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "09ea669a",
+   "metadata": {},
+   "source": [
+    "## 1. trainer 和 evaluator 的基本关系\n",
+    "\n",
+    "### 1.1 trainer 和 evaluator 的初始化\n",
+    "\n",
+    "在`fastNLP 1.0`中,`Trainer`模块和`Evaluator`模块分别表示 **“训练器”和“评测器”**\n",
+    "\n",
+    "  对应于之前的`fastNLP`版本中的`Trainer`模块和`Tester`模块,其定义方法如下所示\n",
+    "\n",
+    "在`fastNLP 1.0`中,需要注意的是,如果在同个`python`脚本中先使用`Trainer`训练,然后使用`Evaluator`评测\n",
+    "\n",
+    "  那么非常关键的问题在于**如何正确设置二者的 driver**。这就引入了另一个问题:什么是 `driver`?\n",
+    "\n",
+    "\n",
+    "```python\n",
+    "trainer = Trainer(\n",
+    "    model=model,                          # 模型基于 torch.nn.Module\n",
+    "    train_dataloader=train_dataloader,    # 加载模块基于 torch.utils.data.DataLoader\n",
+    "    optimizers=optimizer,                 # 优化模块基于 torch.optim.*\n",
+    "    ...\n",
+    "    driver=\"torch\",                       # 使用 pytorch 模块进行训练\n",
+    "    device='cuda',                        # 使用 GPU:0 显卡执行训练\n",
+    "    ...\n",
+    ")\n",
+    "...\n",
+    "evaluator = Evaluator(\n",
+    "    model=model,                          # 模型基于 torch.nn.Module\n",
+    "    dataloaders=evaluate_dataloader,      # 加载模块基于 torch.utils.data.DataLoader\n",
+    "    metrics={'acc': Accuracy()},          # 测评方法使用 fastNLP.core.metrics.Accuracy\n",
+    "    ...\n",
+    "    driver=trainer.driver,                # 保持同 trainer 的 driver 一致\n",
+    "    device=None,\n",
+    "    ...\n",
+    ")\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3c11fe1a",
+   "metadata": {},
+   "source": [
+    "### 1.2 driver 的含义与使用要求\n",
+    "\n",
+    "在`fastNLP 1.0`中,**driver**这一概念被用来表示**控制具体训练的各个步骤的最终执行部分**\n",
+    "\n",
+    "  例如神经网络前向、后向传播的具体执行、网络参数的优化和数据在设备间的迁移等\n",
+    "\n",
+    "在`fastNLP 1.0`中,**Trainer 和 Evaluator 都依赖于具体的 driver 来完成整体的工作流程**\n",
+    "\n",
+    "  具体`driver`与`Trainer`以及`Evaluator`之间的关系,将在之后的`tutorial 4`中详细介绍\n",
+    "\n",
+    "注:这里给出一条建议:**在同一脚本中**,**所有的** Trainer **和** Evaluator **使用的** driver **应当保持一致**\n",
+    "\n",
+    "  尽量不要出现之前使用单卡的`driver`、后面又使用多卡的`driver`的情况,这是因为,当脚本执行至\n",
+    "\n",
+    "  多卡`driver`处时,会重启一个进程执行之前的所有内容,如此一来可能会造成一些意想不到的麻烦"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2cac4a1a",
+   "metadata": {},
+   "source": [
+    "### 1.3 Trainer 内部初始化 Evaluator\n",
+    "\n",
+    "在`fastNLP 1.0`中,如果在**初始化 Trainer 时**,**传入参数 evaluate_dataloaders 和 metrics**\n",
+    "\n",
+    "  则在`Trainer`内部,也会初始化单独的`Evaluator`来帮助训练过程中对验证集的评测\n",
+    "\n",
+    "```python\n",
+    "trainer = Trainer(\n",
+    "    model=model,\n",
+    "    train_dataloader=train_dataloader,\n",
+    "    optimizers=optimizer,\n",
+    "    ...\n",
+    "    driver=\"torch\",\n",
+    "    device='cuda',\n",
+    "    ...\n",
+    "    evaluate_dataloaders=evaluate_dataloader,     # 传入参数 evaluate_dataloaders\n",
+    "    metrics={'acc': Accuracy()},                  # 传入参数 metrics\n",
+    "    ...\n",
+    ")\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0c9c7dda",
+   "metadata": {},
+   "source": [
+    "## 2. 使用 fastNLP 搭建 argmax 模型"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "524ac200",
+   "metadata": {},
+   "source": [
+    "### 2.1 train_step 和 evaluate_step\n",
+    "\n",
+    "在`fastNLP 1.0`中,使用`torch.nn.Module`搭建需要训练的模型,在搭建模型过程中,除了\n",
+    "\n",
+    "  添加`pytorch`要求的`forward`方法外,还需要添加 `train_step` 和 `evaluate_step` 这两个方法\n",
+    "\n",
+    "```python\n",
+    "class Model(torch.nn.Module):\n",
+    "    def __init__(self):\n",
+    "        super(Model, self).__init__()\n",
+    "        self.loss_fn = torch.nn.CrossEntropyLoss()\n",
+    "        pass\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        pass\n",
+    "\n",
+    "    def train_step(self, x, y):\n",
+    "        pred = self(x)\n",
+    "        return {\"loss\": self.loss_fn(pred, y)}\n",
+    "\n",
+    "    def evaluate_step(self, x, y):\n",
+    "        pred = self(x)\n",
+    "        pred = torch.max(pred, dim=-1)[1]\n",
+    "        return {\"pred\": pred, \"target\": y}\n",
+    "```\n",
+    "***\n",
+    "在`fastNLP 1.0`中,**函数 train_step 是 Trainer 中参数 train_fn 的默认值**\n",
+    "\n",
+    "  在`Trainer`训练时,**Trainer 通过参数 train_fn 对应的模型方法获得当前数据批次的损失值**\n",
+    "\n",
+    "  因此,在`Trainer`训练时,`Trainer`首先会寻找模型是否定义了`train_step`这一方法\n",
+    "\n",
+    "    如果没有找到,那么`Trainer`会默认使用模型的`forward`函数来进行训练的前向传播过程\n",
+    "\n",
+    "注:在`fastNLP 1.0`中,**Trainer 要求模型通过 train_step 来返回一个字典**,**满足如 {\"loss\": loss} 的形式**\n",
+    "\n",
+    "  此外,这里也可以通过传入`Trainer`的参数`output_mapping`来实现输出的转换,详见(trainer的详细讲解,待补充)\n",
+    "\n",
+    "同样,在`fastNLP 1.0`中,**函数 evaluate_step 是 Evaluator 中参数 evaluate_fn 的默认值**\n",
+    "\n",
+    "  在`Evaluator`测试时,**Evaluator 通过参数 evaluate_fn 对应的模型方法获得当前数据批次的评测结果**\n",
+    "\n",
+    "  从用户角度,模型通过`evaluate_step`方法来返回一个字典,内容与传入`Evaluator`的`metrics`一致\n",
+    "\n",
+    "  从模块角度,该字典的键值和`metric`中的`update`函数的签名一致,这样的机制在传参时被称为“**参数匹配**”\n",
+    "\n",
+    ""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "fb3272eb",
+   "metadata": {},
+   "source": [
+    "### 2.2 trainer 和 evaluator 的参数匹配\n",
+    "\n",
+    "在`fastNLP 1.0`中,参数匹配涉及到两个方面:\n",
+    "\n",
+    "  一方面,**在模型的前向传播中**,**dataloader 向 train_step 或 evaluate_step 函数传递 batch**\n",
+    "\n",
+    "  另一方面,**在模型的评测过程中**,**evaluate_dataloader 向 metric 的 update 函数传递 batch**\n",
+    "\n",
+    "对于前者,在`Trainer`和`Evaluator`中的参数`model_wo_auto_param_call`被设置为`False`时\n",
+    "\n",
+    "    **fastNLP 1.0 要求 dataloader 生成的每个 batch**,**满足如 {\"x\": x, \"y\": y} 的形式**\n",
+    "\n",
+    "  同时,`fastNLP 1.0`会查看模型的`train_step`和`evaluate_step`方法的参数签名,并为对应参数传入对应数值\n",
+    "\n",
+    "    **字典形式的定义**,**对应在 Dataset 定义的 \\_\\_getitem\\_\\_ 方法中**,例如下方的`ArgMaxDataset`\n",
+    "\n",
+    "  而在`Trainer`和`Evaluator`中的参数`model_wo_auto_param_call`被设置为`True`时\n",
+    "\n",
+    "    `fastNLP 1.0`会将`batch`直接传给模型的`train_step`、`evaluate_step`或`forward`函数\n",
+    "\n",
+    "```python\n",
+    "class Dataset(torch.utils.data.Dataset):\n",
+    "    def __init__(self, x, y):\n",
+    "        self.x = x\n",
+    "        self.y = y\n",
+    "\n",
+    "    def __len__(self):\n",
+    "        return len(self.x)\n",
+    "\n",
+    "    def __getitem__(self, item):\n",
+    "        return {\"x\": self.x[item], \"y\": self.y[item]}\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f5f1a6aa",
+   "metadata": {},
+   "source": [
+    "对于后者,首先要明确,在`Trainer`和`Evaluator`中,`metrics`的计算分为`update`和`get_metric`两步\n",
+    "\n",
+    "    **update 函数**,**针对一个 batch 的预测结果**,计算其累计的评价指标\n",
+    "\n",
+    "    **get_metric 函数**,**统计 update 函数累计的评价指标**,来计算最终的评价结果\n",
+    "\n",
+    "  例如对于`Accuracy`来说,`update`函数会累计一个`batch`中预测正确的数量`right_num`和样本总数`total_num`\n",
+    "\n",
+    "    而`get_metric`函数则会返回所有`batch`的评测值`right_num / total_num`\n",
+    "\n",
+    "  在此基础上,**fastNLP 1.0 要求 evaluate_dataloader 生成的每个 batch 传递给对应的 metric**\n",
+    "\n",
+    "    **以 {\"pred\": y_pred, \"target\": y_true} 的形式**,对应其`update`函数的函数签名,如下方的示意代码所示\n",
+    "\n",
+    ""
+   ]
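+  },
+  {
+   "cell_type": "markdown",
+   "id": "9e1d2c3b",
+   "metadata": {},
+   "source": [
+    "  为了更直观地理解 update 和 get_metric 的分工,下面给出一个以`Accuracy`为原型的简化草图\n",
+    "\n",
+    "    注意这仅是示意,并非`fastNLP.core.metrics.Accuracy`的真实实现;其中`update`的参数\n",
+    "\n",
+    "    `pred`和`target`恰好与上文`evaluate_step`返回字典的键一一对应,这正是“参数匹配”机制的体现"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8f0a1b2c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class SimpleAccuracy:\n",
+    "    \"\"\"仅为示意的简化版 accuracy 指标,展示 update / get_metric 两步的分工\"\"\"\n",
+    "    def __init__(self):\n",
+    "        self.right_num = 0    # 累计预测正确的数量\n",
+    "        self.total_num = 0    # 累计样本总数\n",
+    "\n",
+    "    def update(self, pred, target):\n",
+    "        # 针对一个 batch 的预测结果,累计评价指标\n",
+    "        self.right_num += (pred == target).sum().item()\n",
+    "        self.total_num += target.numel()\n",
+    "\n",
+    "    def get_metric(self):\n",
+    "        # 统计 update 累计的评价指标,计算最终的评价结果\n",
+    "        return {\"acc\": self.right_num / max(self.total_num, 1)}"
+   ]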
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f62b7bb1",
+   "metadata": {},
+   "source": [
+    "### 2.3 示例:argmax 模型的搭建\n",
+    "\n",
+    "下文将通过训练`argmax`模型,简单介绍`Trainer`模块的使用方式\n",
+    "\n",
+    "  首先,使用`torch.nn.Module`定义`argmax`模型,目标是输入一组固定维度的向量,输出其中数值最大的数的索引"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "5314482b",
+   "metadata": {
+    "pycharm": {
+     "is_executing": true
+    }
+   },
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "\n",
+    "class ArgMaxModel(nn.Module):\n",
+    "    def __init__(self, num_labels, feature_dimension):\n",
+    "        nn.Module.__init__(self)\n",
+    "        self.num_labels = num_labels\n",
+    "\n",
+    "        self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10)\n",
+    "        self.ac1 = nn.ReLU()\n",
+    "        self.linear2 = nn.Linear(in_features=10, out_features=10)\n",
+    "        self.ac2 = nn.ReLU()\n",
+    "        self.output = nn.Linear(in_features=10, out_features=num_labels)\n",
+    "        self.loss_fn = nn.CrossEntropyLoss()\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        pred = self.ac1(self.linear1(x))\n",
+    "        pred = self.ac2(self.linear2(pred))\n",
+    "        pred = self.output(pred)\n",
+    "        return pred\n",
+    "\n",
+    "    def train_step(self, x, y):\n",
+    "        pred = self(x)\n",
+    "        return {\"loss\": self.loss_fn(pred, y)}\n",
+    "\n",
+    "    def evaluate_step(self, x, y):\n",
+    "        pred = self(x)\n",
+    "        pred = torch.max(pred, dim=-1)[1]\n",
+    "        return {\"pred\": pred, \"target\": y}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "71f3fa6b",
+   "metadata": {},
+   "source": [
+    "  接着,使用`torch.utils.data.Dataset`定义`ArgMaxDataset`数据集\n",
+    "\n",
+    "    数据集包含三个参数:维度`feature_dimension`、数据量`data_num`和随机种子`seed`\n",
+    "\n",
+    "    数据集初始化时,自动生成指定维度的向量,并为每个向量标注出其中最大值的索引作为预测标签"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "fe612e61",
+   "metadata": {
+    "pycharm": {
+     "is_executing": false
+    }
+   },
+   "outputs": [],
+   "source": [
+    "from torch.utils.data import Dataset\n",
+    "\n",
+    "class ArgMaxDataset(Dataset):\n",
+    "    def __init__(self, feature_dimension, data_num=1000, seed=1000):\n",
+    "        self.num_labels = feature_dimension\n",
+    "        self.feature_dimension = feature_dimension\n",
+    "        self.data_num = data_num\n",
+    "        self.seed = seed\n",
+    "\n",
+    "        g = torch.Generator()\n",
+    "        g.manual_seed(seed)    # 使用传入的随机种子,而非忽略该参数\n",
+    "        self.x = torch.randint(low=-100, high=100, size=[data_num, feature_dimension], generator=g).float()\n",
+    "        self.y = torch.max(self.x, dim=-1)[1]\n",
+    "\n",
+    "    def __len__(self):\n",
+    "        return self.data_num\n",
+    "\n",
+    "    def __getitem__(self, item):\n",
+    "        return {\"x\": self.x[item], \"y\": self.y[item]}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2cb96332",
+   "metadata": {},
+   "source": [
+    "  然后,根据`ArgMaxModel`类初始化模型实例,保持输入维度`feature_dimension`和输出标签数量`num_labels`一致\n",
+    "\n",
+    "    再根据`ArgMaxDataset`类初始化两个数据集实例,分别用于模型训练和模型评测,数据量分别为1000笔和100笔"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "76172ef8",
+   "metadata": {
+    "pycharm": {
+     "is_executing": false
+    }
+   },
+   "outputs": [],
+   "source": [
+    "model = ArgMaxModel(num_labels=10, feature_dimension=10)\n",
+    "\n",
+    "train_dataset = ArgMaxDataset(feature_dimension=10, data_num=1000)\n",
+    "evaluate_dataset = ArgMaxDataset(feature_dimension=10, data_num=100)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "4e7d25ee",
+   "metadata": {},
+   "source": [
+    "  此外,使用`torch.utils.data.DataLoader`初始化两个数据加载模块,批量大小同为8,分别用于训练和测评"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "363b5b09",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from torch.utils.data import DataLoader\n",
+    "\n",
+    "train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)\n",
+    "evaluate_dataloader = DataLoader(evaluate_dataset, batch_size=8)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c8d4443f",
+   "metadata": {},
+   "source": [
+    "  最后,使用`torch.optim.SGD`初始化一个优化模块,基于随机梯度下降法"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "dc28a2d9",
+   "metadata": {
+    "pycharm": {
+     "is_executing": false
+    }
+   },
+   "outputs": [],
+   "source": [
+    "from torch.optim import SGD\n",
+    "\n",
+    "optimizer = SGD(model.parameters(), lr=0.001)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "eb8ca6cf",
+   "metadata": {},
+   "source": [
+    "## 3. 使用 fastNLP 1.0 训练 argmax 模型\n",
+    "\n",
+    "### 3.1 trainer 外部初始化的 evaluator"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "55145553",
+   "metadata": {},
+   "source": [
+    "通过从`fastNLP`库中导入`Trainer`类,初始化`trainer`实例,对模型进行训练\n",
+    "\n",
+    "  需要导入预先定义好的模型`model`、对应的数据加载模块`train_dataloader`、优化模块`optimizer`\n",
+    "\n",
+    "  通过`progress_bar`设定进度条格式,默认为`\"auto\"`,此外还有`\"rich\"`、`\"raw\"`和`None`\n",
+    "\n",
+    "    但对于`\"auto\"`和`\"rich\"`格式,在`jupyter`中,进度条在训练结束后会被丢弃\n",
+    "\n",
+    "  通过`n_epochs`设定优化迭代轮数,默认为20;`Trainer`的全部变量与函数可以通过`dir(trainer)`查询"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "b51b7a2d",
+   "metadata": {
+    "pycharm": {
+     "is_executing": false
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import sys\n", + "sys.path.append('..')\n", + "\n", + "from fastNLP import Trainer\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver=\"torch\",\n", + " device='cuda',\n", + " train_dataloader=train_dataloader,\n", + " optimizers=optimizer,\n", + " n_epochs=10, # 设定迭代轮数 \n", + " progress_bar=\"auto\" # 设定进度条格式\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "6e202d6e", + "metadata": {}, + "source": [ + "通过使用`Trainer`类的`run`函数,进行训练\n", + "\n", + "  其中,可以通过参数`num_train_batch_per_epoch`决定每个`epoch`运行多少个`batch`后停止,默认全部\n", + "\n", + "  `run`函数完成后在`jupyter`中没有输出保留,此外,通过`help(trainer.run)`可以查询`run`函数的详细内容" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "ba047ead", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n"
+      ],
+      "text/plain": [
+       "\n"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "trainer.run()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c16c5fa4",
+   "metadata": {},
+   "source": [
+    "通过从`fastNLP`库中导入`Evaluator`类,初始化`evaluator`实例,对模型进行评测\n",
+    "\n",
+    "  需要导入预先定义好的模型`model`、对应的数据加载模块`evaluate_dataloader`\n",
+    "\n",
+    "  需要注意的是评测方法`metrics`,应设定为形如`{'acc': fastNLP.core.metrics.Accuracy()}`的字典\n",
+    "\n",
+    "  类似地,也可以通过`progress_bar`设定进度条格式,默认为`\"auto\"`"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "1c6b6b36",
+   "metadata": {
+    "pycharm": {
+     "is_executing": true
+    }
+   },
+   "outputs": [],
+   "source": [
+    "from fastNLP import Evaluator\n",
+    "from fastNLP import Accuracy\n",
+    "\n",
+    "evaluator = Evaluator(\n",
+    "    model=model,\n",
+    "    driver=trainer.driver,    # 需要使用 trainer 已经启动的 driver\n",
+    "    device=None,\n",
+    "    dataloaders=evaluate_dataloader,\n",
+    "    metrics={'acc': Accuracy()}    # 需要严格使用此种形式的字典\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8157bb9b",
+   "metadata": {},
+   "source": [
+    "通过使用`Evaluator`类的`run`函数,进行评测\n",
+    "\n",
+    "  其中,可以通过参数`num_eval_batch_per_dl`决定每个`evaluate_dataloader`运行多少个`batch`后停止,默认全部\n",
+    "\n",
+    "  最终,输出形如`{'acc#acc': acc}`的字典,在`jupyter`中,进度条在评测结束后会被丢弃"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "f7cb0165",
+   "metadata": {
+    "pycharm": {
+     "is_executing": true
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
{'acc#acc': 0.31, 'total#acc': 100.0, 'correct#acc': 31.0}\n",
+       "
\n"
+      ],
+      "text/plain": [
+       "\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.31\u001b[0m, \u001b[32m'total#acc'\u001b[0m: \u001b[1;36m100.0\u001b[0m, \u001b[32m'correct#acc'\u001b[0m: \u001b[1;36m31.0\u001b[0m\u001b[1m}\u001b[0m\n"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'acc#acc': 0.31, 'total#acc': 100.0, 'correct#acc': 31.0}"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "evaluator.run()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "dd9f68fa",
+   "metadata": {},
+   "source": [
+    "### 3.2 trainer 内部初始化的 evaluator \n",
+    "\n",
+    "通过在初始化`trainer`实例时加入`evaluate_dataloaders`和`metrics`,可以实现在训练过程中进行评测\n",
+    "\n",
+    "  通过`progress_bar`同时设定训练和评估进度条格式,在`jupyter`中,进度条在训练结束后会被丢弃\n",
+    "\n",
+    "  但是中间的评估结果仍会保留;**通过 evaluate_every 设定评估频率**,可以为负数、正数或者函数:\n",
+    "\n",
+    "    **为负数时**,**表示每隔几个 epoch 评估一次**;**为正数时**,**则表示每隔几个 batch 评估一次**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "183c7d19",
+   "metadata": {
+    "pycharm": {
+     "is_executing": true
+    }
+   },
+   "outputs": [],
+   "source": [
+    "trainer = Trainer(\n",
+    "    model=model,\n",
+    "    driver=trainer.driver,    # 因为是在同个脚本中,这里的 driver 同样需要重用\n",
+    "    train_dataloader=train_dataloader,\n",
+    "    evaluate_dataloaders=evaluate_dataloader,\n",
+    "    metrics={'acc': Accuracy()},\n",
+    "    optimizers=optimizer,\n",
+    "    n_epochs=10, \n",
+    "    evaluate_every=-1,    # 表示每个 epoch 结束时进行评估\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "714cc404",
+   "metadata": {},
+   "source": [
+    "通过使用`Trainer`类的`run`函数,进行训练\n",
+    "\n",
+    "  还可以通过**参数 num_eval_sanity_batch 决定每次训练前运行多少个 evaluate_batch 进行评测**,**默认为 2**\n",
+    "\n",
+    "  之所以“先评测后训练”,是为了避免模型训练了很长时间后,才发现评测阶段的代码或数据出问题,故预先作此**试探性评测**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "2e4daa2c",
+   "metadata": {
+    "pycharm": {
+     "is_executing": true
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "
[18:28:25] INFO     Running evaluator sanity check for 2 batches.              trainer.py:592\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[18:28:25]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=549287;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=645362;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.31,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 31.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.31\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m31.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.33,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 33.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.33\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m33.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.34,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 34.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.34\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m34.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.36,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 36.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.36,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 36.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.36,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 36.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.36,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 36.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.36,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 36.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.37,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 37.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.37\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m37.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.4,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 40.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.4\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m40.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "c4e9c619", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'acc#acc': 0.4, 'total#acc': 100.0, 'correct#acc': 40.0}"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "trainer.evaluator.run()"
+   ]
+  },
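+  {
+   "cell_type": "markdown",
+   "id": "3a9c2f1e",
+   "metadata": {},
+   "source": [
+    "上文提到,`evaluate_every`除了负数和正数,还可以传入函数来自定义评估的时机\n",
+    "\n",
+    "  下面是一个简化的示意(这里假设该函数接收`trainer`实例并返回布尔值,且`trainer`上存在\n",
+    "\n",
+    "  记录累计批次数的`global_forward_batches`属性;具体接口请以官方文档为准)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7d4e5f6a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# 仅为示意:每累计 100 个 batch 评估一次\n",
+    "def evaluate_every_100_batches(trainer):\n",
+    "    # 假设 trainer.global_forward_batches 记录了已运行的 batch 总数\n",
+    "    return trainer.global_forward_batches % 100 == 0\n",
+    "\n",
+    "# trainer = Trainer(..., evaluate_every=evaluate_every_100_batches, ...)"
+   ]
+  },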
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1bc7cb4a",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.13"
+  },
+  "pycharm": {
+   "stem_cell": {
+    "cell_type": "raw",
+    "metadata": {
+     "collapsed": false
+    },
+    "source": []
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/source/tutorials/fastnlp_tutorial_1.ipynb b/docs/source/tutorials/fastnlp_tutorial_1.ipynb
new file mode 100644
index 00000000..cff81a21
--- /dev/null
+++ b/docs/source/tutorials/fastnlp_tutorial_1.ipynb
@@ -0,0 +1,1333 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "cdc25fcd",
+   "metadata": {},
+   "source": [
+    "# T1. dataset 和 vocabulary 的基本使用\n",
+    "\n",
+    "  1   dataset 的使用与结构\n",
+    " \n",
+    "    1.1   dataset 的结构与创建\n",
+    "\n",
+    "    1.2   dataset 的数据预处理\n",
+    "\n",
+    "    1.3   延伸:instance 和 field\n",
+    "\n",
+    "  2   vocabulary 的结构与使用\n",
+    "\n",
+    "    2.1   vocabulary 的创建与修改\n",
+    "\n",
+    "    2.2   vocabulary 与 OOV 问题\n",
+    "\n",
+    "  3   dataset 和 vocabulary 的组合使用\n",
+    " \n",
+    "    3.1   从 dataframe 中加载 dataset\n",
+    "\n",
+    "    3.2   从 dataset 中获取 vocabulary"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0eb18a22",
+   "metadata": {},
+   "source": [
+    "## 1. dataset 的使用与结构\n",
+    "\n",
+    "### 1.1  dataset 的结构与创建\n",
+    "\n",
+    "在`fastNLP 1.0`中,使用`DataSet`模块表示数据集,**dataset 类似于关系型数据库中的数据表**(下文统一为小写 `dataset`)\n",
+    "\n",
+    "  **主要包含 field 字段和 instance 实例两个元素**,分别对应数据表中的 field 字段和 record 记录\n",
+    "\n",
+    "在`fastNLP 1.0`中,`DataSet`模块被定义在`fastNLP.core.dataset`路径下,导入该模块后,最简单的\n",
+    "\n",
+    "  初始化方法,即将字典形式的表格 **{'field1': column1, 'field2': column2, ...}** 传入构造函数"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "a1d69ad2",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 1 | I like apples . | ['I', 'like', 'appl... | 4 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n", + "+-----+------------------------+------------------------+-----+\n" + ] + } + ], + "source": [ + "from fastNLP import DataSet\n", + "\n", + "data = {'idx': [0, 1, 2], \n", + " 'sentence':[\"This is an apple .\", \"I like apples .\", \"Apples are good for our health .\"],\n", + " 'words': [['This', 'is', 'an', 'apple', '.'], \n", + " ['I', 'like', 'apples', '.'], \n", + " ['Apples', 'are', 'good', 'for', 'our', 'health', '.']],\n", + " 'num': [5, 4, 7]}\n", + "\n", + "dataset = DataSet(data)\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "9260fdc6", + "metadata": {}, + "source": [ + "  在`dataset`的实例中,字段`field`的名称和实例`instance`中的字符串也可以中文" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "3d72ef00", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+------+--------------------+------------------------+------+\n", + "| 序号 | 句子 | 字符 | 长度 |\n", + "+------+--------------------+------------------------+------+\n", + "| 0 | 生活就像海洋, | ['生', '活', '就', ... | 7 |\n", + "| 1 | 只有意志坚强的人, | ['只', '有', '意', ... | 9 |\n", + "| 2 | 才能到达彼岸。 | ['才', '能', '到', ... | 7 |\n", + "+------+--------------------+------------------------+------+\n" + ] + } + ], + "source": [ + "temp = {'序号': [0, 1, 2], \n", + " '句子':[\"生活就像海洋,\", \"只有意志坚强的人,\", \"才能到达彼岸。\"],\n", + " '字符': [['生', '活', '就', '像', '海', '洋', ','], \n", + " ['只', '有', '意', '志', '坚', '强', '的', '人', ','], \n", + " ['才', '能', '到', '达', '彼', '岸', '。']],\n", + " '长度': [7, 9, 7]}\n", + "\n", + "chinese = DataSet(temp)\n", + "print(chinese)" + ] + }, + { + "cell_type": "markdown", + "id": "202e5490", + "metadata": {}, + "source": [ + "在`dataset`中,使用`drop`方法可以删除满足条件的实例,这里使用了python中的`lambda`表达式\n", + "\n", + "  注一:在`drop`方法中,通过设置`inplace`参数将删除对应实例后的`dataset`作为一个新的实例生成" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "09b478f8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2492313174344 2491986424200\n", + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n", + "+-----+------------------------+------------------------+-----+\n", + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 1 | I like apples . | ['I', 'like', 'appl... | 4 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... 
| 7 |\n", + "+-----+------------------------+------------------------+-----+\n" + ] + } + ], + "source": [ + "dropped = dataset\n", + "dropped = dropped.drop(lambda ins:ins['num'] < 5, inplace=False)\n", + "print(id(dropped), id(dataset))\n", + "print(dropped)\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "aa277674", + "metadata": {}, + "source": [ + "  注二:**对对象使用等号一般表示传引用**,所以对`dataset`使用等号,是传引用而不是赋值\n", + "\n", + "    如下所示,**dropped 和 dataset 具有相同 id**,**对 dropped 执行删除操作 dataset 同时会被修改**" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "77c8583a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2491986424200 2491986424200\n", + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n", + "+-----+------------------------+------------------------+-----+\n", + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n", + "+-----+------------------------+------------------------+-----+\n" + ] + } + ], + "source": [ + "dropped = dataset\n", + "dropped.drop(lambda ins:ins['num'] < 5)\n", + "print(id(dropped), id(dataset))\n", + "print(dropped)\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "a76199dc", + "metadata": {}, + "source": [ + "在`dataset`中,使用`delet_instance`方法可以删除对应序号的`instance`实例,序号从0开始" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "d8824b40", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+--------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+--------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 1 | I like apples . | ['I', 'like', 'appl... | 4 |\n", + "+-----+--------------------+------------------------+-----+\n" + ] + } + ], + "source": [ + "dataset = DataSet(data)\n", + "dataset.delete_instance(2)\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "f4fa9f33", + "metadata": {}, + "source": [ + "在`dataset`中,使用`delet_field`方法可以删除对应名称的`field`字段" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "f68ddb40", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+--------------------+------------------------------+\n", + "| idx | sentence | words |\n", + "+-----+--------------------+------------------------------+\n", + "| 0 | This is an apple . | ['This', 'is', 'an', 'app... |\n", + "| 1 | I like apples . | ['I', 'like', 'apples', '... 
|\n", + "+-----+--------------------+------------------------------+\n" + ] + } + ], + "source": [ + "dataset.delete_field('num')\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "b1e9d42c", + "metadata": {}, + "source": [ + "### 1.2 dataset 的数据预处理\n", + "\n", + "在`dataset`模块中,`apply`、`apply_field`、`apply_more`和`apply_field_more`函数可以进行简单的数据预处理\n", + "\n", + "  **apply 和 apply_more 输入整条实例**,**apply_field 和 apply_field_more 仅输入实例的部分字段**\n", + "\n", + "  **apply 和 apply_field 仅输出单个字段**,**apply_more 和 apply_field_more 则是输出多个字段**\n", + "\n", + "  **apply 和 apply_field 返回的是个列表**,**apply_more 和 apply_field_more 返回的是个字典**\n", + "\n", + "    预处理过程中,通过`progress_bar`参数设置显示进度条类型,通过`num_proc`设置多进程\n", + "***\n", + "\n", + "`apply`的参数包括一个函数`func`和一个新字段名`new_field_name`,函数`func`的处理对象是`dataset`模块中\n", + "\n", + "  的每个`instance`实例,函数`func`的处理结果存放在`new_field_name`对应的新建字段内" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "72a0b5f9", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/3 [00:00,\n", + " 'words': ,\n", + " 'num': }" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset.get_all_fields()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "5433815c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['num', 'sentence', 'words']" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset.get_field_names()" + ] + }, + { + "cell_type": "markdown", + "id": "4964eeed", + "metadata": {}, + "source": [ + "其他`dataset`的基本使用:通过`in`或者`has_field`方法可以判断`dataset`的是否包含某种字段\n", + "\n", + "  通过`rename_field`方法可以更改`dataset`中的字段名称;通过`concat`方法可以实现两个`dataset`中的拼接\n", + "\n", + "  通过`len`可以统计`dataset`中的实例数目;`dataset`的全部变量与函数可以通过`dir(dataset)`查询" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "25ce5488", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3 False\n", + "6 True\n", + "+------------------------------+------------------------------+--------+\n", + "| sentence | words | length |\n", + "+------------------------------+------------------------------+--------+\n", + "| This is an apple . | ['This', 'is', 'an', 'app... | 5 |\n", + "| I like apples . | ['I', 'like', 'apples', '... | 4 |\n", + "| Apples are good for our h... | ['Apples', 'are', 'good',... | 7 |\n", + "| This is an apple . | ['This', 'is', 'an', 'app... | 5 |\n", + "| I like apples . | ['I', 'like', 'apples', '... | 4 |\n", + "| Apples are good for our h... | ['Apples', 'are', 'good',... | 7 |\n", + "+------------------------------+------------------------------+--------+\n" + ] + } + ], + "source": [ + "print(len(dataset), dataset.has_field('length')) \n", + "if 'num' in dataset:\n", + " dataset.rename_field('num', 'length')\n", + "elif 'length' in dataset:\n", + " dataset.rename_field('length', 'num')\n", + "dataset.concat(dataset)\n", + "print(len(dataset), dataset.has_field('length')) \n", + "print(dataset) " + ] + }, + { + "cell_type": "markdown", + "id": "e30a6cd7", + "metadata": {}, + "source": [ + "## 2. 
vocabulary 的结构与使用\n",
+    "\n",
+    "### 2.1 vocabulary 的创建与修改\n",
+    "\n",
+    "在`fastNLP 1.0`中,使用`Vocabulary`模块表示词汇表,**vocabulary 的核心是从单词到序号的映射**\n",
+    "\n",
+    "  可以直接通过构造函数实例化,通过查找`word2idx`属性,可以找到`vocabulary`映射对应的字典实现\n",
+    "\n",
+    "  **默认补零 padding 用 \\<pad\\> 表示**,**对应序号为0**;**未知单词 unknown 用 \\<unk\\> 表示**,**对应序号1**\n",
+    "\n",
+    "  通过打印`vocabulary`可以看到词汇表中的单词列表,其中,`padding`和`unknown`不会显示"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "id": "3515e096",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Vocabulary([]...)\n",
+      "{'<pad>': 0, '<unk>': 1}\n",
+      "<pad> 0\n",
+      "<unk> 1\n"
+     ]
+    }
+   ],
+   "source": [
+    "from fastNLP import Vocabulary\n",
+    "\n",
+    "vocab = Vocabulary()\n",
+    "print(vocab)\n",
+    "print(vocab.word2idx)\n",
+    "print(vocab.padding, vocab.padding_idx)\n",
+    "print(vocab.unknown, vocab.unknown_idx)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "640be126",
+   "metadata": {},
+   "source": [
+    "在`vocabulary`中,通过`add_word`方法或`add_word_lst`方法,可以单独或批量添加单词\n",
+    "\n",
+    "  通过`len`或`word_count`属性,可以显示`vocabulary`的单词量和每个单词添加的次数"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "id": "88c7472a",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "5 Counter({'生活': 1, '就像': 1, '海洋': 1})\n",
+      "6 Counter({'生活': 1, '就像': 1, '海洋': 1, '只有': 1})\n",
+      "6 {'<pad>': 0, '<unk>': 1, '生活': 2, '就像': 3, '海洋': 4, '只有': 5}\n"
+     ]
+    }
+   ],
+   "source": [
+    "vocab.add_word_lst(['生活', '就像', '海洋'])\n",
+    "print(len(vocab), vocab.word_count)\n",
+    "vocab.add_word('只有')\n",
+    "print(len(vocab), vocab.word_count)\n",
+    "print(len(vocab), vocab.word2idx)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f9ec8b28",
+   "metadata": {},
+   "source": [
+    "  **通过 to_word 方法可以找到序号对应的单词**,**通过 to_index 方法可以找到单词对应的序号**\n",
+    "\n",
+    "    由于序号0和序号1已经被占用,所以**新加入的词的序号从2开始计数**,如`'生活'`对应2\n",
+    "\n",
+    "    通过`has_word`方法可以判断单词是否在词汇表中,没有的单词被判做`<unk>`"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "id": "3447acde",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "<pad> 0\n",
+      "<unk> 1\n",
+      "生活 2\n",
+      "彼岸 1 False\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(vocab.to_word(0), vocab.to_index('<pad>'))\n",
+    "print(vocab.to_word(1), vocab.to_index('<unk>'))\n",
+    "print(vocab.to_word(2), vocab.to_index('生活'))\n",
+    "print('彼岸', vocab.to_index('彼岸'), vocab.has_word('彼岸'))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b4e36850",
+   "metadata": {},
+   "source": [
+    "**vocabulary 允许反复添加相同单词**,**可以通过 word_count 方法看到相应单词被添加的次数**\n",
+    "\n",
+    "  但其中没有`<pad>`和`<unk>`,`vocabulary`的全部变量与函数可以通过`dir(vocabulary)`查询\n",
+    "\n",
+    "  注:**使用 add_word_lst 添加单词**,**单词对应序号不会动态调整**,**使用 dataset 添加单词的情况不同**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "id": "490b101c",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "生活 2\n",
+      "彼岸 12 True\n",
+      "13 Counter({'人': 4, '生活': 2, '就像': 2, '海洋': 2, '只有': 2, '意志': 1, '坚强的': 1, '才': 1, '能': 1, '到达': 1, '彼岸': 1})\n",
+      "13 {'<pad>': 0, '<unk>': 1, '生活': 2, '就像': 3, '海洋': 4, '只有': 5, '人': 6, '意志': 7, '坚强的': 8, '才': 9, '能': 10, '到达': 11, '彼岸': 12}\n"
+     ]
+    }
+   ],
+   "source": [
+    "vocab.add_word_lst(['生活', '就像', '海洋', '只有', '意志', '坚强的', '人', '人', '人', '人', '才', '能', '到达', '彼岸'])\n",
+    "print(vocab.to_word(2), vocab.to_index('生活'))\n",
+    "print('彼岸', vocab.to_index('彼岸'), vocab.has_word('彼岸'))\n",
+    "print(len(vocab), vocab.word_count)\n",
+    "print(len(vocab), vocab.word2idx)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "23e32a63",
+   "metadata": {},
+   "source": [
+    "### 2.2 vocabulary 与 OOV 问题\n",
+    "\n",
+    "在`vocabulary`模块初始化的时候,可以通过指定`unknown`和`padding`为`None`,限制其存在\n",
+    "\n",
+    "  此时添加单词直接从0开始标号,如果遇到未知单词会直接报错,即 out of vocabulary"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "id": "a99ff909",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'positive': 0, 'negative': 1}\n",
+      "ValueError: word `neutral` not in vocabulary\n"
+     ]
+    }
+   ],
+   "source": [
+    "vocab = Vocabulary(unknown=None, padding=None)\n",
+    "\n",
+    "vocab.add_word_lst(['positive', 'negative'])\n",
+    "print(vocab.word2idx)\n",
+    "\n",
+    "try:\n",
+    "    print(vocab.to_index('neutral'))\n",
+    "except ValueError:\n",
+    "    print(\"ValueError: word `neutral` not in vocabulary\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "618da6bd",
+   "metadata": {},
+   "source": [
+    "  相应的,如果只指定其中的`unknown`,则编号会后移一个,同时遇到未知单词全部当做`<unk>`"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "id": "432f74c1",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'<unk>': 0, 'positive': 1, 'negative': 2}\n",
+      "0 <unk>\n"
+     ]
+    }
+   ],
+   "source": [
+    "vocab = Vocabulary(unknown='<unk>', padding=None)\n",
+    "\n",
+    "vocab.add_word_lst(['positive', 'negative'])\n",
+    "print(vocab.word2idx)\n",
+    "\n",
+    "print(vocab.to_index('neutral'), vocab.to_word(vocab.to_index('neutral')))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b6263f73",
+   "metadata": {},
+   "source": [
+    "## 3 dataset 和 vocabulary 的组合使用\n",
+    " \n",
+    "### 3.1 从 dataframe 中加载 dataset\n",
+    "\n",
+    "以下通过 [NLP-beginner](https://github.com/FudanNLP/nlp-beginner) 实践一中 [Rotten Tomatoes 影评数据集](https://www.kaggle.com/c/sentiment-analysis-on-movie-reviews) 的部分训练数据组成`test4dataset.tsv`文件\n",
+    "\n",
+    "  介绍如何使用`dataset`、`vocabulary`简单加载并处理数据集,首先使用`pandas`模块,读取原始数据的`dataframe`"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "id": "3dbd985d",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>SentenceId</th>\n",
+       "      <th>Sentence</th>\n",
+       "      <th>Sentiment</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>1</td>\n",
+       "      <td>A series of escapades demonstrating the adage ...</td>\n",
+       "      <td>negative</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>2</td>\n",
+       "      <td>This quiet , introspective and entertaining in...</td>\n",
+       "      <td>positive</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>3</td>\n",
+       "      <td>Even fans of Ismail Merchant 's work , I suspe...</td>\n",
+       "      <td>negative</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>4</td>\n",
+       "      <td>A positively thrilling combination of ethnogra...</td>\n",
+       "      <td>neutral</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>5</td>\n",
+       "      <td>A comedy-drama of nearly epic proportions root...</td>\n",
+       "      <td>positive</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>5</th>\n",
+       "      <td>6</td>\n",
+       "      <td>The Importance of Being Earnest , so thick wit...</td>\n",
+       "      <td>neutral</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "</div>
" + ], + "text/plain": [ + " SentenceId Sentence Sentiment\n", + "0 1 A series of escapades demonstrating the adage ... negative\n", + "1 2 This quiet , introspective and entertaining in... positive\n", + "2 3 Even fans of Ismail Merchant 's work , I suspe... negative\n", + "3 4 A positively thrilling combination of ethnogra... neutral\n", + "4 5 A comedy-drama of nearly epic proportions root... positive\n", + "5 6 The Importance of Being Earnest , so thick wit... neutral" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "df = pd.read_csv('./data/test4dataset.tsv', sep='\\t')\n", + "df" + ] + }, + { + "cell_type": "markdown", + "id": "919ab350", + "metadata": {}, + "source": [ + "接着,通过`dataset`中的`from_pandas`方法填充数据集,并使用`apply_more`方法对文本进行分词操作" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "4f634586", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/6 [00:00': 0, '': 1, 'a': 2, 'of': 3, ',': 4, 'the': 5, '.': 6, 'is': 7, 'and': 8, 'good': 9, 'for': 10, 'which': 11, 'this': 12, \"'s\": 13, 'series': 14, 'escapades': 15, 'demonstrating': 16, 'adage': 17, 'that': 18, 'what': 19, 'goose': 20, 'also': 21, 'gander': 22, 'some': 23, 'occasionally': 24, 'amuses': 25, 'but': 26, 'none': 27, 'amounts': 28, 'to': 29, 'much': 30, 'story': 31, 'quiet': 32, 'introspective': 33, 'entertaining': 34, 'independent': 35, 'worth': 36, 'seeking': 37, 'even': 38, 'fans': 39, 'ismail': 40, 'merchant': 41, 'work': 42, 'i': 43, 'suspect': 44, 'would': 45, 'have': 46, 'hard': 47, 'time': 48, 'sitting': 49, 'through': 50, 'one': 51, 'positively': 52, 'thrilling': 53, 'combination': 54, 'ethnography': 55, 'all': 56, 'intrigue': 57, 'betrayal': 58, 'deceit': 59, 'murder': 60, 'shakespearean': 61, 'tragedy': 62, 'or': 63, 'juicy': 64, 'soap': 65, 'opera': 66, 'comedy-drama': 67, 'nearly': 68, 'epic': 69, 'proportions': 70, 'rooted': 71, 'in': 72, 'sincere': 73, 'performance': 74, 'by': 75, 'title': 76, 'character': 77, 'undergoing': 78, 'midlife': 79, 'crisis': 80, 'importance': 81, 'being': 82, 'earnest': 83, 'so': 84, 'thick': 85, 'with': 86, 'wit': 87, 'it': 88, 'plays': 89, 'like': 90, 'reading': 91, 'from': 92, 'bartlett': 93, 'familiar': 94, 'quotations': 95} \n", + "\n", + "Vocabulary(['a', 'series', 'of', 'escapades', 'demonstrating']...)\n" + ] + } + ], + "source": [ + "from fastNLP import Vocabulary\n", + "\n", + "vocab = Vocabulary()\n", + "vocab = vocab.from_dataset(dataset, field_name='Sentence')\n", + "print(vocab.word_count, '\\n')\n", + "print(vocab.word2idx, '\\n')\n", + "print(vocab)" + ] + }, + { + "cell_type": "markdown", + "id": "f0857ccb", + "metadata": {}, + "source": [ + "之后,**通过 vocabulary 的 index_dataset 方法**,**调整 dataset 中指定字段的元素**,**使用编号将之代替**\n", + "\n", + "  使用上述方法,可以将影评数据集中的单词序列转化为词编号序列,为接下来转化为词嵌入序列做准备" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "2f9a04b2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+------------+------------------------------+-----------+\n", + "| SentenceId | Sentence | Sentiment |\n", + "+------------+------------------------------+-----------+\n", + "| 1 | [2, 14, 3, 15, 16, 5, 17,... | negative |\n", + "| 2 | [12, 32, 4, 33, 8, 34, 35... | positive |\n", + "| 3 | [38, 39, 3, 40, 41, 13, 4... 
| negative |\n", + "| 4 | [2, 52, 53, 54, 3, 55, 8,... | neutral |\n", + "| 5 | [2, 67, 3, 68, 69, 70, 71... | positive |\n", + "| 6 | [5, 81, 3, 82, 83, 4, 84,... | neutral |\n", + "+------------+------------------------------+-----------+\n" + ] + } + ], + "source": [ + "vocab.index_dataset(dataset, field_name='Sentence')\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "6b26b707", + "metadata": {}, + "source": [ + "最后,使用相同方法,再将`dataset`中`Sentiment`字段中的`negative`、`neutral`、`positive`转化为数字编号" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "5f5eed18", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'negative': 0, 'positive': 1, 'neutral': 2}\n", + "+------------+------------------------------+-----------+\n", + "| SentenceId | Sentence | Sentiment |\n", + "+------------+------------------------------+-----------+\n", + "| 1 | [2, 14, 3, 15, 16, 5, 17,... | 0 |\n", + "| 2 | [12, 32, 4, 33, 8, 34, 35... | 1 |\n", + "| 3 | [38, 39, 3, 40, 41, 13, 4... | 0 |\n", + "| 4 | [2, 52, 53, 54, 3, 55, 8,... | 2 |\n", + "| 5 | [2, 67, 3, 68, 69, 70, 71... | 1 |\n", + "| 6 | [5, 81, 3, 82, 83, 4, 84,... | 2 |\n", + "+------------+------------------------------+-----------+\n" + ] + } + ], + "source": [ + "target_vocab = Vocabulary(padding=None, unknown=None)\n", + "\n", + "target_vocab.from_dataset(dataset, field_name='Sentiment')\n", + "print(target_vocab.word2idx)\n", + "target_vocab.index_dataset(dataset, field_name='Sentiment')\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "eed7ea64", + "metadata": {}, + "source": [ + "在最后的最后,通过以下的一张图,来总结本章关于`dataset`和`vocabulary`主要知识点的讲解,以及两者的联系\n", + "\n", + "" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35b4f0f7", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/tutorials/fastnlp_tutorial_2.ipynb b/docs/source/tutorials/fastnlp_tutorial_2.ipynb new file mode 100644 index 00000000..546e471d --- /dev/null +++ b/docs/source/tutorials/fastnlp_tutorial_2.ipynb @@ -0,0 +1,884 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# T2. databundle 和 tokenizer 的基本使用\n", + "\n", + "  1   fastNLP 中 dataset 的延伸\n", + "\n", + "    1.1   databundle 的概念与使用\n", + "\n", + "  2   fastNLP 中的 tokenizer\n", + " \n", + "    2.1   PreTrainedTokenizer 的概念\n", + "\n", + "    2.2   BertTokenizer 的基本使用\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. 
fastNLP 中 dataset 的延伸\n", + "\n", + "### 1.1 databundle 的概念与使用\n", + "\n", + "在`fastNLP 1.0`中,在常用的数据加载模块`DataLoader`和数据集`DataSet`模块之间,还存在\n", + "\n", + "  一个中间模块,即 **数据包 DataBundle 模块**,可以从`fastNLP.io`路径中导入该模块\n", + "\n", + "在`fastNLP 1.0`中,**一个 databundle 数据包包含若干 dataset 数据集和 vocabulary 词汇表**\n", + "\n", + "  分别存储在`datasets`和`vocabs`两个变量中,所以了解`databundle`数据包之前\n", + "\n", + "需要首先**复习 dataset 数据集和 vocabulary 词汇表**,**下面的一串代码**,**你知道其大概含义吗?**\n", + "\n", + "" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/6 [00:00': 0, '': 1, 'negative': 2, 'positive': 3, 'neutral': 4}\n" + ] + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "from fastNLP import DataSet\n", + "from fastNLP import Vocabulary\n", + "from fastNLP.io import DataBundle\n", + "\n", + "datasets = DataSet.from_pandas(pd.read_csv('./data/test4dataset.tsv', sep='\\t'))\n", + "datasets.rename_field('Sentence', 'text')\n", + "datasets.rename_field('Sentiment', 'label')\n", + "datasets.apply_more(lambda ins:{'label': ins['label'].lower(), \n", + " 'text': ins['text'].lower().split()},\n", + " progress_bar='tqdm')\n", + "datasets.delete_field('SentenceId')\n", + "train_ds, test_ds = datasets.split(ratio=0.7)\n", + "datasets = {'train': train_ds, 'test': test_ds}\n", + "print(datasets['train'])\n", + "print(datasets['test'])\n", + "\n", + "vocabs = {}\n", + "vocabs['label'] = Vocabulary().from_dataset(datasets['train'].concat(datasets['test'], inplace=False), field_name='label')\n", + "vocabs['text'] = Vocabulary().from_dataset(datasets['train'].concat(datasets['test'], inplace=False), field_name='text')\n", + "print(vocabs['label'].word2idx)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "上述代码的含义是:从`test4dataset`的 6 条数据中,划分 4 条训练集(`int(6*0.7) = 4`),2 条测试集\n", + "\n", + "    修改相关字段名称,删除序号字段,同时将标签都设为小写,对文本进行分词\n", + "\n", + "  接着通过`concat`方法拼接测试集训练集,注意设置`inplace=False`,生成临时的新数据集\n", + "\n", + "  使用`from_dataset`方法从拼接的数据集中抽取词汇表,为将数据集中的单词替换为序号做准备\n", + "\n", + "由此就可以得到**数据集字典 datasets**(**对应训练集、测试集**)和**词汇表字典 vocabs**(**对应数据集各字段**)\n", + "\n", + "  然后就可以初始化`databundle`了,通过`print`可以观察其大致结构,效果如下" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "In total 2 datasets:\n", + "\ttrain has 4 instances.\n", + "\ttest has 2 instances.\n", + "In total 2 vocabs:\n", + "\tlabel has 5 entries.\n", + "\ttext has 96 entries.\n", + "\n", + "['train', 'test']\n", + "['label', 'text']\n" + ] + } + ], + "source": [ + "data_bundle = DataBundle(datasets=datasets, vocabs=vocabs)\n", + "print(data_bundle)\n", + "print(data_bundle.get_dataset_names())\n", + "print(data_bundle.get_vocab_names())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "此外,也可以通过`data_bundle`的`num_dataset`和`num_vocab`返回数据表和词汇表个数\n", + "\n", + "  通过`data_bundle`的`iter_datasets`和`iter_vocabs`遍历数据表和词汇表" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "In total 2 datasets:\n", + "\ttrain has 4 instances.\n", + "\ttest has 2 instances.\n", + "In total 2 datasets:\n", + "\tlabel has 5 entries.\n", + "\ttext has 96 entries.\n" + ] + } + ], + "source": [ + "print(\"In total %d datasets:\" % data_bundle.num_dataset)\n", + "for name, dataset in data_bundle.iter_datasets():\n", + " print(\"\\t%s has %d instances.\" % (name, len(dataset)))\n", + "print(\"In total %d datasets:\" % data_bundle.num_dataset)\n", + "for name, vocab in data_bundle.iter_vocabs():\n", + " print(\"\\t%s has %d entries.\" % (name, len(vocab)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "在数据包`databundle`中,也有和数据集`dataset`类似的四个`apply`函数,即\n", + "\n", + 
"  `apply`函数、`apply_field`函数、`apply_field_more`函数和`apply_more`函数\n", + "\n", + "  负责对数据集进行预处理,如下所示是`apply_more`函数的示例,其他函数类似\n", + "\n", + "此外,通过`get_dataset`函数,可以通过数据表名`name`称找到对应数据表\n", + "\n", + "  通过`get_vocab`函数,可以通过词汇表名`field_name`称找到对应词汇表" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/4 [00:00\n", + "在`fastNLP 1.0`中,**使用 PreTrainedTokenizer 模块来为数据集中的词语进行词向量的标注**\n", + "\n", + "  需要注意的是,`PreTrainedTokenizer`模块的下载和导入**需要确保环境安装了 transformers 模块**\n", + "\n", + "  这是因为 `fastNLP 1.0`中`PreTrainedTokenizer`模块的实现基于`Huggingface Transformers`库\n", + "\n", + "**Huggingface Transformers 是一个开源的**,**基于 transformer 模型结构提供的预训练语言库**\n", + "\n", + "  包含了多种经典的基于`transformer`的预训练模型,如`BERT`、`BART`、`RoBERTa`、`GPT2`、`CPT`\n", + "\n", + "  更多相关内容可以参考`Huggingface Transformers`的[相关论文](https://arxiv.org/pdf/1910.03771.pdf)、[官方文档](https://huggingface.co/transformers/)以及[的代码仓库](https://github.com/huggingface/transformers)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.2 BertTokenizer 的基本使用\n", + "\n", + "在`fastNLP 1.0`中,以`PreTrainedTokenizer`为基类,泛化出多个子类,实现基于`BERT`等模型的标注\n", + "\n", + "  本节以`BertTokenizer`模块为例,展示`PreTrainedTokenizer`模块的使用方法与应用实例\n", + "\n", + "**BertTokenizer 的初始化包括 导入模块和导入数据 两步**,先通过从`fastNLP.transformers.torch`中\n", + "\n", + "  导入`BertTokenizer`模块,再**通过 from_pretrained 方法指定 tokenizer 参数类型下载**\n", + "\n", + "  其中,**'bert-base-uncased' 指定 tokenizer 使用的预训练 BERT 类型**:单词不区分大小写\n", + "\n", + "    **模块层数 L=12**,**隐藏层维度 H=768**,**自注意力头数 A=12**,**总参数量 110M**\n", + "\n", + "  另外,模型参数自动下载至 home 目录下的`~\\.cache\\huggingface\\transformers`文件夹中" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "from fastNLP.transformers.torch import BertTokenizer\n", + "\n", + "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "通过变量`vocab_size`和`vocab_files_names`可以查看`BertTokenizer`的词汇表的大小和对应文件\n", + "\n", + "  通过变量`vocab`可以访问`BertTokenizer`预训练的词汇表(由于内容过大就不演示了" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "30522 {'vocab_file': 'vocab.txt'}\n" + ] + } + ], + "source": [ + "print(tokenizer.vocab_size, tokenizer.vocab_files_names)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "通过变量`all_special_tokens`或通过变量`special_tokens_map`可以**查看 BertTokenizer 内置的特殊词素**\n", + "\n", + "  包括**未知符 '[UNK]'**, **断句符 '[SEP]'**, **补零符 '[PAD]'**, **分类符 '[CLS]'**, **掩码 '[MASK]'**\n", + "\n", + "通过变量`all_special_ids`可以**查看 BertTokenizer 内置的特殊词素对应的词汇表编号**,相同功能\n", + "\n", + "  也可以直接通过查看`pad_token`,值为`'[UNK]'`,和`pad_token_id`,值为`0`,等变量来实现" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "pad_token [PAD] 0\n", + "unk_token [UNK] 100\n", + "cls_token [CLS] 101\n", + "sep_token [SEP] 102\n", + "msk_token [MASK] 103\n", + "all_tokens ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'] [100, 102, 0, 101, 103]\n", + "{'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}\n" + ] + } + ], + "source": [ + "print('pad_token', tokenizer.pad_token, 
tokenizer.pad_token_id) \n", + "print('unk_token', tokenizer.unk_token, tokenizer.unk_token_id) \n", + "print('cls_token', tokenizer.cls_token, tokenizer.cls_token_id) \n", + "print('sep_token', tokenizer.sep_token, tokenizer.sep_token_id)\n", + "print('msk_token', tokenizer.mask_token, tokenizer.mask_token_id)\n", + "print('all_tokens', tokenizer.all_special_tokens, tokenizer.all_special_ids)\n", + "print(tokenizer.special_tokens_map)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "此外,还可以添加其他特殊字符,例如起始符`[BOS]`、终止符`[EOS]`,添加后词汇表编号也会相应改变\n", + "\n", + "  *但是如何添加这两个之外的字符,并且如何将这两个的编号设置为 [UNK] 之外的编号???*" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "bos_token [BOS] 100\n", + "eos_token [EOS] 100\n", + "all_tokens ['[BOS]', '[EOS]', '[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'] [100, 100, 100, 102, 0, 101, 103]\n", + "{'bos_token': '[BOS]', 'eos_token': '[EOS]', 'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}\n" + ] + } + ], + "source": [ + "tokenizer.bos_token = '[BOS]'\n", + "tokenizer.eos_token = '[EOS]'\n", + "# tokenizer.bos_token_id = 104\n", + "# tokenizer.eos_token_id = 105\n", + "print('bos_token', tokenizer.bos_token, tokenizer.bos_token_id)\n", + "print('eos_token', tokenizer.eos_token, tokenizer.eos_token_id)\n", + "print('all_tokens', tokenizer.all_special_tokens, tokenizer.all_special_ids)\n", + "print(tokenizer.special_tokens_map)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "在`BertTokenizer`中,**使用 tokenize 函数和 convert_tokens_to_string 函数可以实现文本和词素列表的互转**\n", + "\n", + "  此外,**使用 convert_tokens_to_ids 函数和 convert_ids_to_tokens 函数则可以实现词素和词素编号的互转**\n", + "\n", + "  上述四个函数的使用效果如下所示,此处可以明显看出,`tokenizer`分词和传统分词的不同效果,例如`'##cap'`" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204, 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681, 2572, 25581, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997, 1037, 2466, 1012]\n", + "['a', 'series', 'of', 'es', '##cap', '##ades', 'demonstrating', 'the', 'ada', '##ge', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gan', '##der', ',', 'some', 'of', 'which', 'occasionally', 'am', '##uses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.']\n", + "a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .\n" + ] + } + ], + "source": [ + "text = \"a series of escapades demonstrating the adage that what is \" \\\n", + " \"good for the goose is also good for the gander , some of which \" \\\n", + " \"occasionally amuses but none of which amounts to much of a story .\" \n", + "tks = ['a', 'series', 'of', 'es', '##cap', '##ades', 'demonstrating', 'the', \n", + " 'ada', '##ge', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', \n", + " 'is', 'also', 'good', 'for', 'the', 'gan', '##der', ',', 'some', 'of', \n", + " 'which', 'occasionally', 'am', '##uses', 'but', 'none', 'of', 'which', \n", + " 'amounts', 'to', 'much', 'of', 'a', 'story', '.']\n", + "ids = [ 1037, 2186, 1997, 9686, 17695, 18673, 14313, 
1996, 15262,  3351, \n",
+    "        2008,  2054,  2003,  2204,  2005,  1996, 13020,  2003,  2036,  2204,\n",
+    "        2005,  1996, 25957,  4063,  1010,  2070,  1997,  2029,  5681,  2572,\n",
+    "       25581,  2021,  3904,  1997,  2029,  8310,  2000,  2172,  1997,  1037,\n",
+    "        2466,  1012]\n",
+    "\n",
+    "tokens = tokenizer.tokenize(text)\n",
+    "print(tokenizer.convert_tokens_to_ids(tokens))\n",
+    "\n",
+    "ids = tokenizer.convert_tokens_to_ids(tokens)\n",
+    "print(tokenizer.convert_ids_to_tokens(ids))\n",
+    "\n",
+    "print(tokenizer.convert_tokens_to_string(tokens))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "在`BertTokenizer`中,还有另外两个函数可以实现分词标注,分别是 **encode 和 decode 函数**,**可以直接实现**\n",
+    "\n",
+    "  **文本字符串和词素编号列表的互转**,但是编码过程中会按照`BERT`的规则,**在句子首末加入 [CLS] 和 [SEP]**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204, 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681, 2572, 25581, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997, 1037, 2466, 1012, 102]\n",
+      "[CLS] a series of escapades demonstrating the adage that what is good for the goose is also good for the gander, some of which occasionally amuses but none of which amounts to much of a story. [SEP]\n"
+     ]
+    }
+   ],
+   "source": [
+    "enc = tokenizer.encode(text)\n",
+    "print(enc)\n",
+    "dec = tokenizer.decode(enc)\n",
+    "print(dec)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "在`encode`函数之上,还有`encode_plus`函数,这也是在数据预处理中,`BertTokenizer`模块最常用到的函数\n",
+    "\n",
+    "  **encode 函数的参数**,**encode_plus 函数都有**;**encode 函数返回词素编号列表**,**encode_plus 函数返回字典**\n",
+    "\n",
+    "在`encode_plus`函数的返回值中,字段`input_ids`表示词素编号,其余两个字段后文有详细解释\n",
+    "\n",
+    "  **字段 token_type_ids 详见 text_pairs 的示例**,**字段 attention_mask 详见 batch_text 的示例**\n",
+    "\n",
+    "在`encode_plus`函数的参数中,参数`add_special_tokens`表示是否按照`BERT`的规则,加入相关特殊字符\n",
+    "\n",
+    "  参数`max_length`表示句子截取最大长度(含特殊字符),在参数`truncation=True`时会自动截取\n",
+    "\n",
+    "  参数`return_attention_mask`约定返回的字典中是否包括`attention_mask`字段,以上参数的示例如下"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'input_ids': [101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204, 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681, 2572, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n"
+     ]
+    }
+   ],
+   "source": [
+    "text = \"a series of escapades demonstrating the adage that what is good for the goose is also good for \"\\\n",
+    "       \"the gander , some of which occasionally amuses but none of which amounts to much of a story .\" \n",
+    "\n",
+    "encoded = tokenizer.encode_plus(text=text, add_special_tokens=True, max_length=32, \n",
+    "                                truncation=True, return_attention_mask=True)\n",
+    "print(encoded)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "在`encode_plus`函数之上,还有`batch_encode_plus`函数(类似地,在`decode`之上,还有`batch_decode`)\n",
+    "\n",
+    "  两者参数类似,**batch_encode_plus 函数针对批量文本 batch_text**,**或者批量句对 text_pairs**\n",
+    "\n",
"在针对批量文本`batch_text`的示例中,注意`batch_encode_plus`函数返回字典中的`attention_mask`字段\n", + "\n", + "  可以发现,**attention_mask 字段通过 01 标注出词素序列中该位置是否为补零**,可以用做自注意力的掩模" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'input_ids': [[101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, 2008, 102, 0, 0], [101, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204, 2005, 1996, 25957, 4063, 102], [101, 2070, 1997, 2029, 5681, 2572, 25581, 102, 0, 0, 0, 0, 0, 0, 0], [101, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997, 1037, 2466, 102, 0, 0, 0]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]]}\n" + ] + } + ], + "source": [ + "batch_text = [\"a series of escapades demonstrating the adage that\",\n", + " \"what is good for the goose is also good for the gander\",\n", + " \"some of which occasionally amuses\",\n", + " \"but none of which amounts to much of a story\" ]\n", + "\n", + "encoded = tokenizer.batch_encode_plus(batch_text_or_text_pairs=batch_text, padding=True,\n", + " add_special_tokens=True, max_length=16, truncation=True, \n", + " return_attention_mask=True)\n", + "print(encoded)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "而在针对批量句对`text_pairs`的示例中,注意`batch_encode_plus`函数返回字典中的`attention_mask`字段\n", + "\n", + "  可以发现,**token_type_ids 字段通过 01 标注出词素序列中该位置为句对中的第几句**,句对用 [SEP] 分割" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'input_ids': [[101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, 2008, 102, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204, 2005, 1996, 25957, 4063, 102], [101, 2070, 1997, 2029, 5681, 2572, 25581, 102, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997, 1037, 2466, 102, 0, 0, 0, 0, 0, 0, 0, 0]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]]}\n" + ] + } + ], + "source": [ + "text_pairs = [(\"a series of escapades demonstrating the adage that\",\n", + " \"what is good for the goose is also good for the gander\"),\n", + " (\"some of which occasionally amuses\",\n", + " \"but none of which amounts to much of a story\")]\n", + "\n", + "encoded = tokenizer.batch_encode_plus(batch_text_or_text_pairs=text_pairs, padding=True,\n", + " add_special_tokens=True, max_length=32, truncation=True, \n", + " return_attention_mask=True)\n", + "print(encoded)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "回到`encode_plus`上,在接下来的示例中,**使用内置的 functools.partial 模块构造 encode 函数**\n", + "\n", + "  接着**使用该函数对 databundle 进行数据预处理**,由于`tokenizer.encode_plus`返回的是一个字典\n", + "\n", + "  读入的是一个字段,所以此处使用`apply_field_more`方法,得到结果自动并入`databundle`中如下" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": 
{}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "functools.partial(, max_length=32, truncation=True, return_attention_mask=True)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/4 [00:00\n", + "\n", + "在接下来的`tutorial 3.`中,将会介绍`fastNLP v1.0`中的`dataloader`模块,会涉及本章中\n", + "\n", + "  提到的`collator`模块,`fastNLP`的多框架适应以及完整的数据加载过程,敬请期待" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/docs/source/tutorials/fastnlp_tutorial_3.ipynb b/docs/source/tutorials/fastnlp_tutorial_3.ipynb new file mode 100644 index 00000000..4100105a --- /dev/null +++ b/docs/source/tutorials/fastnlp_tutorial_3.ipynb @@ -0,0 +1,621 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "213d538c", + "metadata": {}, + "source": [ + "# T3. dataloader 的内部结构和基本使用\n", + "\n", + "  1   fastNLP 中的 dataloader\n", + " \n", + "    1.1   dataloader 的基本介绍\n", + "\n", + "    1.2   dataloader 的函数创建\n", + "\n", + "  2   fastNLP 中 dataloader 的延伸\n", + "\n", + "    2.1   collator 的概念与使用\n", + "\n", + "    2.2   结合 datasets 框架" + ] + }, + { + "cell_type": "markdown", + "id": "85857115", + "metadata": {}, + "source": [ + "## 1. 
fastNLP 中的 dataloader\n",
+    "\n",
+    "### 1.1  dataloader 的基本介绍\n",
+    "\n",
+    "在`fastNLP 1.0`的开发中,最关键的开发目标就是**实现 fastNLP 对当前主流机器学习框架**,例如\n",
+    "\n",
+    "  **当下流行的 pytorch**,以及**国产的 paddle 、jittor 和 oneflow 的兼容**,扩大受众的同时,也助力国产框架的发展\n",
+    "\n",
+    "本着分而治之的思想,我们可以将`fastNLP 1.0`对`pytorch`、`paddle`、`jittor`、`oneflow`框架的兼容,划分为\n",
+    "\n",
+    "    **对数据预处理**、**批量 batch 的划分与补齐**、**模型训练**、**模型评测**,**四个部分的兼容**\n",
+    "\n",
+    "  针对数据预处理,我们已经在`tutorial-1`中介绍了`dataset`和`vocabulary`的使用\n",
+    "\n",
+    "    而结合`tutorial-0`,我们可以发现**数据预处理环节本质上是框架无关的**\n",
+    "\n",
+    "    因为在不同框架下,读取的原始数据格式都差异不大,彼此也很容易转换\n",
+    "\n",
+    "只有涉及到张量、模型,不同框架才展现出其各自的特色:**pytorch 和 oneflow 中的 tensor 和 nn.Module**\n",
+    "\n",
+    "    **在 paddle 中称为 tensor 和 nn.Layer**,**在 jittor 中则称为 Var 和 Module**\n",
+    "\n",
+    "    因此,**模型训练、模型评测**,**是兼容的重难点**,我们将会在`tutorial-5`中详细介绍\n",
+    "\n",
+    "  针对批量`batch`的处理,作为`fastNLP 1.0`中框架无关部分向框架相关部分的过渡\n",
+    "\n",
+    "    就是`dataloader`模块的职责,这也是本篇教程`tutorial-3`讲解的重点\n",
+    "\n",
+    "**dataloader 模块的职责**,详细划分可以包含以下三部分,**采样划分、补零对齐、框架匹配**\n",
+    "\n",
+    "    第一,确定`batch`大小,确定采样方式,划分后通过迭代器即可得到`batch`序列\n",
+    "\n",
+    "    第二,对序列进行补零对齐,这也是`fastNLP`主要针对的,即将同个`batch`内的数据对齐\n",
+    "\n",
+    "    第三,**batch 内数据格式要匹配框架**,**但 batch 结构需保持一致**,即**参数匹配机制**\n",
+    "\n",
+    "  对此,`fastNLP 1.0`给出了 **TorchDataLoader 、 PaddleDataLoader 、 JittorDataLoader 和 OneflowDataLoader**\n",
+    "\n",
+    "    分别针对并匹配不同框架,但彼此之间参数名、属性、方法仍然类似,前两者大致如下表所示\n",
+    "\n",
+    "名称|参数|属性|功能|内容\n",
+    " ----|----|----|----|----|\n",
+    " `dataset` | √ | √ | 指定`dataloader`的数据内容 |  |\n",
+    " `batch_size` | √ | √ | 指定`dataloader`的`batch`大小 | 默认`16` |\n",
+    " `shuffle` | √ | √ | 指定`dataloader`的数据是否打乱 | 默认`False` |\n",
+    " `collate_fn` | √ | √ | 指定`dataloader`的`batch`打包方法 | 视框架而定 |\n",
+    " `sampler` | √ | √ | 指定`dataloader`逐个样本的采样方式,即其`__len__`和`__iter__`的实现 | 默认`None` |\n",
+    " `batch_sampler` | √ | √ | 指定`dataloader`逐个批量的采样方式,即其`__len__`和`__iter__`的实现 | 默认`None` |\n",
+    " `drop_last` | √ | √ | 指定`dataloader`划分`batch`时是否丢弃剩余的 | 默认`False` |\n",
+    " `cur_batch_indices` |  | √ | 记录`dataloader`当前遍历批量序号 |  |\n",
+    " `num_workers` | √ | √ | 指定`dataloader`开启子进程数量 | 默认`0` |\n",
+    " `worker_init_fn` | √ | √ | 指定`dataloader`子进程初始方法 | 默认`None` |\n",
+    " `generator` | √ | √ | 指定`dataloader`子进程随机种子 | 默认`None` |\n",
+    " `prefetch_factor` |  | √ | 指定每个`worker`预先装载的`batch`数量 | 默认`2` |"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "60a8a224",
+   "metadata": {},
+   "source": [
+    "  论及`dataloader`的函数,其中,`get_batch_indices`用来获取当前遍历到的`batch`序号,其他函数\n",
+    "\n",
+    "    包括`set_ignore`、`set_pad`,与`databundle`中的同名方法类似,请参考`tutorial-2`,此处不做更多介绍\n",
+    "\n",
+    "    以下是`tutorial-2`中已经介绍过的数据预处理流程,接下来是对相关数据进行`dataloader`处理"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "aca72b49",
+   "metadata": {
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[38;5;2m[i 0604 15:44:29.773860 92 log.cc:351] Load log_sync: 1\u001b[m\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/4 [00:00\n", + " ['input_ids', 'token_type_ids', 'attention_mask', 'target']\n", + "{'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),\n", + " 'input_ids': tensor([[ 101, 1037, 4038, 1011, 3689, 1997, 3053, 8680, 19173, 15685,\n", + " 1999, 1037, 18006, 2836, 2011, 1996, 2516, 2839, 14996, 3054,\n", + " 15509, 5325, 1012, 102, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0],\n", + " [ 101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262,\n", + " 3351, 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036,\n", + " 2204, 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681,\n", + " 2572, 25581, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997,\n", + " 1037, 2466, 1012, 102],\n", + " [ 101, 2130, 4599, 1997, 19214, 6432, 1005, 1055, 2147, 1010,\n", + " 1045, 8343, 1010, 2052, 2031, 1037, 2524, 2051, 3564, 2083,\n", + " 2023, 2028, 1012, 102, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0],\n", + " [ 101, 1037, 13567, 26162, 5257, 1997, 3802, 7295, 9888, 1998,\n", + " 2035, 1996, 20014, 27611, 1010, 14583, 1010, 11703, 20175, 1998,\n", + " 4028, 1997, 1037, 8101, 2319, 10576, 2030, 1037, 28900, 7815,\n", + " 3850, 1012, 102, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0]]),\n", + " 'target': tensor([0, 1, 1, 2]),\n", + " 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}\n" + ] + } + ], + "source": [ + "from fastNLP import prepare_torch_dataloader\n", + "\n", + "train_dataset = data_bundle.get_dataset('train')\n", + "evaluate_dataset = data_bundle.get_dataset('dev')\n", + "\n", + "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n", + "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)\n", + "\n", + "print(type(train_dataloader))\n", + "\n", + "import pprint\n", + "\n", + "for batch in train_dataloader:\n", + " print(type(batch), type(batch['input_ids']), list(batch))\n", + " pprint.pprint(batch, width=1)" + ] + }, + { + "cell_type": "markdown", + "id": "9f457a6e", + "metadata": {}, + "source": [ + "之所以说`prepare_xx_dataloader`函数更方便,是因为其**导入对象不仅可也是 DataSet 类型**,**还可以**\n", + "\n", + "  **是 
DataBundle 类型**,不过数据集名称需要是`'train'`、`'dev'`、`'test'`供`fastNLP`识别\n", + "\n", + "例如下方就是**直接通过 prepare_paddle_dataloader 函数生成基于 PaddleDataLoader 的字典**\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "7827557d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "from fastNLP import prepare_paddle_dataloader\n", + "\n", + "dl_bundle = prepare_paddle_dataloader(data_bundle, batch_size=16, shuffle=True)\n", + "\n", + "print(type(dl_bundle['train']))" + ] + }, + { + "cell_type": "markdown", + "id": "d898cf40", + "metadata": {}, + "source": [ + "  而在接下来`trainer`的初始化过程中,按如下方式使用即可,除了初始化时`driver='paddle'`外\n", + "\n", + "  这里也可以看出`trainer`模块中,**evaluate_dataloaders 的设计允许评测可以针对多个数据集**\n", + "\n", + "```python\n", + "trainer = Trainer(\n", + " model=model,\n", + " train_dataloader=dl_bundle['train'],\n", + " optimizers=optimizer,\n", + "\t...\n", + "\tdriver='paddle',\n", + "\tdevice='gpu',\n", + "\t...\n", + " evaluate_dataloaders={'dev': dl_bundle['dev'], 'test': dl_bundle['test']}, \n", + " metrics={'acc': Accuracy()},\n", + "\t...\n", + ")\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "d74d0523", + "metadata": {}, + "source": [ + "## 2. fastNLP 中 dataloader 的延伸\n", + "\n", + "### 2.1 collator 的概念与使用\n", + "\n", + "在`fastNLP 1.0`中,在数据加载模块`dataloader`内部,如之前表格所列举的,还存在其他的一些模块\n", + "\n", + "  例如,**实现序列的补零对齐的核对器 collator 模块**;注:`collate vt. 整理(文件或书等);核对,校勘`\n", + "\n", + "在`fastNLP 1.0`中,虽然`dataloader`随框架不同,但`collator`模块却是统一的,主要属性、方法如下表所示\n", + "\n", + "名称|属性|方法|功能|内容\n", + " ----|----|----|----|----|\n", + " `backend` | √ | | 记录`collator`对应框架 | 字符串型,如`'torch'` |\n", + " `padders` | √ | | 记录各字段对应的`padder`,每个负责具体补零对齐  | 字典类型 |\n", + " `ignore_fields` | √ | | 记录`dataloader`采样`batch`时不予考虑的字段 | 集合类型 |\n", + " `input_fields` | √ | | 记录`collator`每个字段的补零值、数据类型等 | 字典类型 |\n", + " `set_backend` | | √ | 设置`collator`对应框架 | 字符串型,如`'torch'` |\n", + " `set_ignore` | | √ | 设置`dataloader`采样`batch`时不予考虑的字段 | 字符串型,表示`field_name`  |\n", + " `set_pad` | | √ | 设置`collator`每个字段的补零值、数据类型等 | |" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d0795b3e", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "train_dataloader.collate_fn\n", + "\n", + "print(type(train_dataloader.collate_fn))" + ] + }, + { + "cell_type": "markdown", + "id": "5f816ef5", + "metadata": {}, + "source": [ + "此外,还可以 **手动定义 dataloader 中的 collate_fn**,而不是使用`fastNLP 1.0`中自带的`collator`模块\n", + "\n", + "  该函数的定义可以大致如下,需要注意的是,**定义 collate_fn 之前需要了解 batch 作为字典的格式**\n", + "\n", + "  该函数通过`collate_fn`参数传入`dataloader`,**在 batch 分发**(**而不是 batch 划分**)**时调用**" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "ff8e405e", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "def collate_fn(batch):\n", + " input_ids, atten_mask, labels = [], [], []\n", + " max_length = [0] * 3\n", + " for each_item in batch:\n", + " input_ids.append(each_item['input_ids'])\n", + " max_length[0] = max(len(each_item['input_ids']), max_length[0])\n", + " atten_mask.append(each_item['token_type_ids'])\n", + " max_length[1] = max(len(each_item['token_type_ids']), max_length[1])\n", + " labels.append(each_item['attention_mask'])\n", + " max_length[2] = max(len(each_item['attention_mask']), max_length[2])\n", + "\n", + " for i in range(3):\n", + " each = (input_ids, atten_mask, labels)[i]\n", + 
" for item in each:\n", + " item.extend([0] * (max_length[i] - len(item)))\n", + " return {'input_ids': torch.cat([torch.tensor([item]) for item in input_ids], dim=0),\n", + " 'token_type_ids': torch.cat([torch.tensor([item]) for item in atten_mask], dim=0),\n", + " 'attention_mask': torch.cat([torch.tensor(item) for item in labels], dim=0)}" + ] + }, + { + "cell_type": "markdown", + "id": "487b75fb", + "metadata": {}, + "source": [ + "注意:使用自定义的`collate_fn`函数,`trainer`的`collate_fn`变量也会自动调整为`function`类型" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "e916d1ac", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "{'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0]),\n", + " 'input_ids': tensor([[ 101, 1037, 4038, 1011, 3689, 1997, 3053, 8680, 19173, 15685,\n", + " 1999, 1037, 18006, 2836, 2011, 1996, 2516, 2839, 14996, 3054,\n", + " 15509, 5325, 1012, 102, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0],\n", + " [ 101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262,\n", + " 3351, 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036,\n", + " 2204, 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681,\n", + " 2572, 25581, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997,\n", + " 1037, 2466, 1012, 102],\n", + " [ 101, 2130, 4599, 1997, 19214, 6432, 1005, 1055, 2147, 1010,\n", + " 1045, 8343, 1010, 2052, 2031, 1037, 2524, 2051, 3564, 2083,\n", + " 2023, 2028, 1012, 102, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0],\n", + " [ 101, 1037, 13567, 26162, 5257, 1997, 3802, 7295, 9888, 1998,\n", + " 2035, 1996, 20014, 27611, 1010, 14583, 1010, 11703, 20175, 1998,\n", + " 4028, 1997, 1037, 8101, 2319, 10576, 2030, 1037, 28900, 7815,\n", + " 3850, 1012, 102, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0]]),\n", + " 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}\n" + ] + } + ], + "source": [ + "train_dataloader = prepare_torch_dataloader(train_dataset, collate_fn=collate_fn, shuffle=True)\n", + "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, collate_fn=collate_fn, shuffle=True)\n", + "\n", + "print(type(train_dataloader))\n", + "print(type(train_dataloader.collate_fn))\n", + "\n", + "for batch in train_dataloader:\n", + " pprint.pprint(batch, width=1)" + ] + }, + { + "cell_type": "markdown", + "id": "0bd98365", + "metadata": {}, + 
"source": [ + "### 2.2 fastNLP 与 datasets 的结合\n", + "\n", + "从`tutorial-1`至`tutorial-3`,我们已经完成了对`fastNLP v1.0`数据读取、预处理、加载,整个流程的介绍\n", + "\n", + "  不过在实际使用中,我们往往也会采取更为简便的方法读取数据,例如使用`huggingface`的`datasets`模块\n", + "\n", + "**使用 datasets 模块中的 load_dataset 函数**,通过指定数据集两级的名称,示例中即是**GLUE 标准中的 SST-2 数据集**\n", + "\n", + "  即可以快速从网上下载好`SST-2`数据集读入,之后以`pandas.DataFrame`作为中介,再转化成`fastNLP.DataSet`\n", + "\n", + "  之后的步骤就和其他关于`dataset`、`databundle`、`vocabulary`、`dataloader`中介绍的相关使用相同了" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "91879c30", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "639a0ad3c63944c6abef4e8ee1f7bf7c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/6000 [00:00[16:20:10] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[16:20:10]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=908530;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=864197;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+       "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+       ".get_parent()\n",
+       "  if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n",
+       "
\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+       "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+       ".get_parent()\n",
+       "  self.msg_id = ip.kernel._parent_header['header']['msg_id']\n",
+       "
\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.525,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 84.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.525\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m84.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.54375,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 87.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.54375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m87.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.55,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 88.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.55\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m88.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.625,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 100.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.65,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 104.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.65\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m104.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.69375,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 111.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.69375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m111.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.675,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 108.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.675\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m108.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.66875,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 107.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.66875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m107.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.675,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 108.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.675\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m108.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.68125,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 109.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.68125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m109.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "8bc4bfb2", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'acc#acc': 0.712222, 'total#acc': 900.0, 'correct#acc': 641.0}"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "trainer.evaluator.run()"
+   ]
+  },
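+  {
+   "cell_type": "markdown",
+   "id": "3e8b1f2a",
+   "metadata": {},
+   "source": [
+    "  补充说明:上面返回字典的键形如 `结果项#指标名`,后缀`acc`来自初始化`Trainer`时传入的`metrics={'acc': Accuracy()}`中的名称\n",
+    "\n",
+    "  前缀`acc`、`total`、`correct`则是`Accuracy`统计出的各个结果项;下面是取值方式的简单示意(沿用上文的 trainer)\n",
+    "\n",
+    "```python\n",
+    "result = trainer.evaluator.run()\n",
+    "print(result['acc#acc'], result['correct#acc'], result['total#acc'])  # 准确率、正确数、总数\n",
+    "```"
+   ]
+  },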
+  {
+   "cell_type": "markdown",
+   "id": "07538876",
+   "metadata": {},
+   "source": [
+    "  注:此处使用`gc`模块删除相关变量,释放内存,为接下来新的模型训练预留存储空间,下同"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "1b52eafd",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "383"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import gc\n",
+    "\n",
+    "del model\n",
+    "del trainer\n",
+    "del dataset\n",
+    "del sst2data\n",
+    "\n",
+    "gc.collect()"
+   ]
+  },
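+  {
+   "cell_type": "markdown",
+   "id": "5c2d9e41",
+   "metadata": {},
+   "source": [
+    "  补充示意:如果希望保留微调好的参数以便之后复用,可以在执行上面的`del model`之前先保存模型权重\n",
+    "\n",
+    "  这里假设`model`是`torch.nn.Module`的子类(使用`driver='torch'`时如此),文件名仅为示意\n",
+    "\n",
+    "```python\n",
+    "import torch\n",
+    "\n",
+    "torch.save(model.state_dict(), 'finetuned_sst2.pt')  # 'finetuned_sst2.pt' 为假设的文件名\n",
+    "```"
+   ]
+  },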
+  {
+   "cell_type": "markdown",
+   "id": "d9443213",
+   "metadata": {},
+   "source": [
+    "## 2. fastNLP 中 models 模块的介绍\n",
+    "\n",
+    "### 2.1  示例一:models 实现 CNN 分类\n",
+    "\n",
+    "  本示例使用`fastNLP 1.0`中预定义模型`models`中的`CNNText`模型,实现`SST-2`文本二分类任务\n",
+    "\n",
+    "数据使用方面,此处沿用在上个示例中展示的`SST-2`数据集,数据加载过程相同且已经执行过了,因此简略\n",
+    "\n",
+    "模型使用方面,如上所述,这里使用**基于卷积神经网络 CNN 的预定义文本分类模型 CNNText**,结构如下所示\n",
+    "\n",
+    "  首先是内置的`100`维嵌入层、`dropout`层、紧接着是三个一维卷积,将`100`维嵌入特征,分别通过\n",
+    "\n",
+    "    **感受野为 1 、 3 、 5 的卷积算子变换至 30 维、 40 维、 50 维的卷积特征**,再将三者拼接\n",
+    "\n",
+    "  最终再次通过`dropout`层、线性变换层,映射至二元的输出值,对应两个分类结果上的几率`logits`\n",
+    "\n",
+    "```\n",
+    "CNNText(\n",
+    "  (embed): Embedding(\n",
+    "    (embed): Embedding(5194, 100)\n",
+    "    (dropout): Dropout(p=0.0, inplace=False)\n",
+    "  )\n",
+    "  (conv_pool): ConvMaxpool(\n",
+    "    (convs): ModuleList(\n",
+    "      (0): Conv1d(100, 30, kernel_size=(1,), stride=(1,), bias=False)\n",
+    "      (1): Conv1d(100, 40, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
+    "      (2): Conv1d(100, 50, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
+    "    )\n",
+    "  )\n",
+    "  (dropout): Dropout(p=0.1, inplace=False)\n",
+    "  (fc): Linear(in_features=120, out_features=2, bias=True)\n",
+    ")\n",
+    "```\n",
+    "\n",
+    "对应到代码上,**从 fastNLP.models.torch 路径下导入 CNNText**,初始化`CNNText`和`optimizer`实例\n",
+    "\n",
+    "  注意:初始化`CNNText`时,**二元组参数 embed 、分类数量 num_classes 是必须传入的**,其中\n",
+    "\n",
+    "    **embed 表示嵌入层的嵌入抽取矩阵大小**,因此第二个元素对应的是默认隐藏层维度 `100` 维"
+   ]
+  },
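+  {
+   "cell_type": "markdown",
+   "id": "9a4e77b3",
+   "metadata": {},
+   "source": [
+    "  顺带验证上面的结构:`fc`层的输入维度`120`正是三路卷积特征维度之和,简单算一下即可确认\n",
+    "\n",
+    "```python\n",
+    "kernel_nums = (30, 40, 50)  # 三个一维卷积的输出通道数,见上文结构\n",
+    "print(sum(kernel_nums))     # 120,与 fc 层的 in_features 一致\n",
+    "```"
+   ]
+  },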
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "f6e76e2e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from fastNLP.models.torch import CNNText\n",
+    "\n",
+    "model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n",
+    "\n",
+    "from torch.optim import AdamW\n",
+    "\n",
+    "optimizers = AdamW(params=model.parameters(), lr=5e-4)"
+   ]
+  },
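+  {
+   "cell_type": "markdown",
+   "id": "7f31c0d8",
+   "metadata": {},
+   "source": [
+    "  上文展示的模型结构可以在初始化后直接打印得到:`CNNText`基于`torch.nn.Module`实现,打印模型即可看到层级结构(示意)\n",
+    "\n",
+    "```python\n",
+    "print(model)  # 输出 CNNText(...) 的层级结构,与上文所示一致\n",
+    "```"
+   ]
+  },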
+  {
+   "cell_type": "markdown",
+   "id": "0cc5ca10",
+   "metadata": {},
+   "source": [
+    "  最后,使用`trainer`模块,集成`model`、`optimizer`、`dataloader`、`metric`训练"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "50a13ee5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from fastNLP import Trainer, Accuracy\n",
+    "\n",
+    "trainer = Trainer(\n",
+    "    model=model,\n",
+    "    driver='torch',\n",
+    "    device=0,  # 'cuda'\n",
+    "    n_epochs=10,\n",
+    "    optimizers=optimizers,\n",
+    "    train_dataloader=train_dataloader,\n",
+    "    evaluate_dataloaders=evaluate_dataloader,\n",
+    "    metrics={'acc': Accuracy()}\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "28903a7d",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "
[16:21:57] INFO     Running evaluator sanity check for 2 batches.              trainer.py:596\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[16:21:57]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=813103;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=271516;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.654444,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 589.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.654444\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m589.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.767778,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 691.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.767778\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m691.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.797778,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 718.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.797778\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m718.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.803333,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 723.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.803333\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m723.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.807778,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 727.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.807778\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m727.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.812222,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 731.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.812222\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m731.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.804444,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 724.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.804444\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m724.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.811111,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 730.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.811111\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m730.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.811111,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 730.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.811111\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m730.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.806667,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 726.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.806667\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m726.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "f47a6a35", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'acc#acc': 0.806667, 'total#acc': 900.0, 'correct#acc': 726.0}"
+      ]
+     },
+     "execution_count": 13,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "trainer.evaluator.run()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5b5c0446",
+   "metadata": {},
+   "source": [
+    "  注:此处使用`gc`模块删除相关变量,释放内存,为接下来新的模型训练预留存储空间,下同"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "e9e70f88",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "344"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import gc\n",
+    "\n",
+    "del model\n",
+    "del trainer\n",
+    "\n",
+    "gc.collect()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6aec2a19",
+   "metadata": {},
+   "source": [
+    "### 2.2  示例二:models 实现 BiLSTM 标注\n",
+    "\n",
+    "  通过两个示例一的对比可以发现,得益于`models`对模型结构的封装,使用`models`明显更加便捷\n",
+    "\n",
+    "    针对更加复杂的模型时,编码更加轻松;本示例将使用`models`中的`BiLSTMCRF`模型\n",
+    "\n",
+    "  避免`CRF`和`Viterbi`算法代码书写的困难,轻松实现`CoNLL-2003`中的命名实体识别`NER`任务\n",
+    "\n",
+    "模型使用方面,如上所述,这里使用**基于双向 LSTM +条件随机场 CRF 的标注模型 BiLSTMCRF**,结构如下所示\n",
+    "\n",
+    "  其中,隐藏层维度默认`100`维,因此对应双向`LSTM`输出`200`维,`dropout`层退学概率、`LSTM`层数可调\n",
+    "\n",
+    "```\n",
+    "BiLSTMCRF(\n",
+    "  (embed): Embedding(7590, 100)\n",
+    "  (lstm): LSTM(\n",
+    "    (lstm): LSTM(100, 100, batch_first=True, bidirectional=True)\n",
+    "  )\n",
+    "  (dropout): Dropout(p=0.1, inplace=False)\n",
+    "  (fc): Linear(in_features=200, out_features=9, bias=True)\n",
+    "  (crf): ConditionalRandomField()\n",
+    ")\n",
+    "```\n",
+    "\n",
+    "数据使用方面,此处仍然**使用 datasets 模块中的 load_dataset 函数**,以如下形式,加载`CoNLL-2003`数据集\n",
+    "\n",
+    "  首次下载后会保存至`~.cache/huggingface/datasets/conll2003/conll2003/1.0.0/`目录下"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "03e66686",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Reusing dataset conll2003 (/remote-home/xrliu/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/63f4ebd1bcb7148b1644497336fd74643d4ce70123334431a3c053b7ee4e96ee)\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "593bc03ed5914953ab94268ff2f01710",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/3 [00:00[16:23:41] INFO     Running evaluator sanity check for 2 batches.              trainer.py:596\n",
+       "\n"
+      ],
+      "text/plain": [
+       "\u001b[2;36m[16:23:41]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO    \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches.              \u001b]8;id=565652;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=224849;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.169014,\n",
+       "  \"pre#F1\": 0.170732,\n",
+       "  \"rec#F1\": 0.167331\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.169014\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.170732\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.167331\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.361809,\n",
+       "  \"pre#F1\": 0.312139,\n",
+       "  \"rec#F1\": 0.430279\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.361809\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.312139\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.430279\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.525,\n",
+       "  \"pre#F1\": 0.475728,\n",
+       "  \"rec#F1\": 0.585657\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.525\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.475728\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.585657\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.627306,\n",
+       "  \"pre#F1\": 0.584192,\n",
+       "  \"rec#F1\": 0.677291\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.627306\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.584192\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.677291\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.710937,\n",
+       "  \"pre#F1\": 0.697318,\n",
+       "  \"rec#F1\": 0.7251\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.710937\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.697318\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.7251\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.739563,\n",
+       "  \"pre#F1\": 0.738095,\n",
+       "  \"rec#F1\": 0.741036\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.739563\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.738095\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.741036\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.748491,\n",
+       "  \"pre#F1\": 0.756098,\n",
+       "  \"rec#F1\": 0.741036\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.748491\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.756098\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.741036\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.716763,\n",
+       "  \"pre#F1\": 0.69403,\n",
+       "  \"rec#F1\": 0.741036\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.716763\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.69403\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.741036\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.768293,\n",
+       "  \"pre#F1\": 0.784232,\n",
+       "  \"rec#F1\": 0.752988\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.768293\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.784232\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.752988\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.757692,\n",
+       "  \"pre#F1\": 0.732342,\n",
+       "  \"rec#F1\": 0.784861\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.757692\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.732342\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.784861\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "37871d6b", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'f#F1': 0.766798, 'pre#F1': 0.741874, 'rec#F1': 0.793456}"
+      ]
+     },
+     "execution_count": 21,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "trainer.evaluator.run()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "96bae094",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/source/tutorials/fastnlp_tutorial_5.ipynb b/docs/source/tutorials/fastnlp_tutorial_5.ipynb
new file mode 100644
index 00000000..ab759feb
--- /dev/null
+++ b/docs/source/tutorials/fastnlp_tutorial_5.ipynb
@@ -0,0 +1,1242 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "fdd7ff16",
+   "metadata": {},
+   "source": [
+    "# T5. trainer 和 evaluator 的深入介绍\n",
+    "\n",
+    "  1   fastNLP 中 driver 的补充介绍\n",
+    " \n",
+    "    1.1   trainer 和 driver 的构想 \n",
+    "\n",
+    "    1.2   device 与 多卡训练\n",
+    "\n",
+    "  2   fastNLP 中的更多 metric 类型\n",
+    "\n",
+    "    2.1   预定义的 metric 类型\n",
+    "\n",
+    "    2.2   自定义的 metric 类型\n",
+    "\n",
+    "  3   fastNLP 中 trainer 的补充介绍\n",
+    "\n",
+    "    3.1   trainer 的内部结构"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "08752c5a",
+   "metadata": {
+    "pycharm": {
+     "name": "#%% md\n"
+    }
+   },
+   "source": [
+    "## 1. fastNLP 中 driver 的补充介绍\n",
+    "\n",
+    "### 1.1  trainer 和 driver 的构想\n",
+    "\n",
+    "在`fastNLP 1.0`中,模型训练最关键的模块便是**训练模块 trainer 、评测模块 evaluator 、驱动模块 driver**,\n",
+    "\n",
+    "  在`tutorial 0`中,已经简单介绍过上述三个模块:**driver 用来控制训练评测中的 model 的最终运行**\n",
+    "\n",
+    "    **evaluator 封装评测的 metric**,**trainer 封装训练的 optimizer**,**也可以包括 evaluator**\n",
+    "\n",
+    "之所以做出上述的划分,其根本目的在于要**达成对于多个 python 学习框架**,**例如 pytorch 、 paddle 、 jittor 的兼容**\n",
+    "\n",
+    "  对于训练环节,其伪代码如下方左边紫色一栏所示,由于**不同框架对模型、损失、张量的定义各有不同**,所以将训练环节\n",
+    "\n",
+    "    划分为**框架无关的循环控制、批量分发部分**,**由 trainer 模块负责**实现,对应的伪代码如下方中间一栏所示\n",
+    "\n",
+    "    以及**随框架不同的模型调用、数值优化部分**,**由 driver 模块负责**实现,对应的伪代码如下方右边一栏所示\n",
+    "\n",
+    "|训练过程|框架无关 对应`Trainer`|框架相关 对应`Driver`\n",
+    "|----|----|----|\n",
+    "| try: | try: |  |\n",
+    "| for epoch in 1:n_eoochs: | for epoch in 1:n_eoochs: |  |\n",
+    "| for step in 1:total_steps: | for step in 1:total_steps: |  |\n",
+    "| batch = fetch_batch() | batch = fetch_batch() |  |\n",
+    "| loss = model.forward(batch)  |  | loss = model.forward(batch)  |\n",
+    "| loss.backward() |  | loss.backward() |\n",
+    "| model.clear_grad() |  | model.clear_grad() |\n",
+    "| model.update() |  | model.update() |\n",
+    "| if need_save: | if need_save: |  |\n",
+    "| model.save() |  | model.save() |\n",
+    "| except: | except: |  |\n",
+    "| process_exception() | process_exception() |  |"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3e55f07b",
+   "metadata": {},
+   "source": [
+    "  对于评测环节,其伪代码如下方左边紫色一栏所示,同样由于不同框架对模型、损失、张量的定义各有不同,所以将评测环节\n",
+    "\n",
+    "    划分为**框架无关的循环控制、分发汇总部分**,**由 evaluator 模块负责**实现,对应的伪代码如下方中间一栏所示\n",
+    "\n",
+    "    以及**随框架不同的模型调用、评测计算部分**,同样**由 driver 模块负责**实现,对应的伪代码如下方右边一栏所示\n",
+    "\n",
+    "|评测过程|框架无关 对应`Evaluator`|框架相关 对应`Driver`\n",
+    "|----|----|----|\n",
+    "| try: | try: |  |\n",
+    "| model.set_eval() | model.set_eval() |  |\n",
+    "| for step in 1:total_steps: | for step in 1:total_steps: |  |\n",
+    "| batch = fetch_batch() | batch = fetch_batch() |  |\n",
+    "| outputs = model.evaluate(batch)  |  | outputs = model.evaluate(batch)  |\n",
+    "| metric.compute(batch, outputs) |  | metric.compute(batch, outputs) |\n",
+    "| results = metric.get_metric() | results = metric.get_metric() |  |\n",
+    "| except: | except: |  |\n",
+    "| process_exception() | process_exception() |  |"
+   ]
+  },
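+  {
+   "cell_type": "markdown",
+   "id": "3e55f07c",
+   "metadata": {},
+   "source": [
+    "To make this division of labor concrete, here is a schematic Python sketch (illustrative only; the class and method names below are not the actual fastNLP API):\n",
+    "\n",
+    "```python\n",
+    "class SketchDriver:\n",
+    "    # framework-specific part: everything that touches real tensors\n",
+    "    def train_call(self, model, batch):  # forward pass, returns the loss\n",
+    "        ...\n",
+    "\n",
+    "    def backward(self, loss):            # backward pass\n",
+    "        ...\n",
+    "\n",
+    "    def step(self):                      # optimizer update + gradient reset\n",
+    "        ...\n",
+    "\n",
+    "def sketch_train(driver, model, dataloader, n_epochs):\n",
+    "    # framework-agnostic part: loop control and batch dispatch (the trainer's job)\n",
+    "    for epoch in range(n_epochs):\n",
+    "        for batch in dataloader:\n",
+    "            loss = driver.train_call(model, batch)\n",
+    "            driver.backward(loss)\n",
+    "            driver.step()\n",
+    "```"
+   ]
+  },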
+  {
+   "cell_type": "markdown",
+   "id": "94ba11c6",
+   "metadata": {
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   },
+   "source": [
+    "由此,从程序员的角度,`fastNLP v1.0` **通过一个 driver 让基于 pytorch 、 paddle 、 jittor 、 oneflow 框架的模型**\n",
+    "\n",
+    "    **都能在相同的 trainer 和 evaluator 上运行**,这也**是 fastNLP v1.0 相比于之前版本的一大亮点**\n",
+    "\n",
+    "  而从`driver`的角度,`fastNLP v1.0`通过定义一个`driver`基类,**将所有张量转化为 numpy.tensor**\n",
+    "\n",
+    "    并由此泛化出`torch_driver`、`paddle_driver`、`jittor_driver`三个子类,从而实现了\n",
+    "\n",
+    "    对`pytorch`、`paddle`、`jittor`的兼容,有关后两者的实践请参考接下来的`tutorial-6`"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "ab1cea7d",
+   "metadata": {},
+   "source": [
+    "### 1.2  device 与 多卡训练\n",
+    "\n",
+    "**fastNLP v1.0 支持多卡训练**,实现方法则是**通过将 trainer 中的 device 设置为对应显卡的序号列表**\n",
+    "\n",
+    "  由单卡切换成多卡,无论是数据、模型还是评测都会面临一定的调整,`fastNLP v1.0`保证:\n",
+    "\n",
+    "    数据拆分时,不同卡之间相互协调,所有数据都可以被训练,且不会使用到相同的数据\n",
+    "\n",
+    "    模型训练时,模型之间需要交换梯度;评测计算时,每张卡先各自计算,再汇总结果\n",
+    "\n",
+    "  例如,在评测计算运行`get_metric`函数时,`fastNLP v1.0`将自动按照`self.right`和`self.total`\n",
+    "\n",
+    "    指定的 **aggregate_method 方法**,默认为`sum`,将每张卡上结果汇总起来,因此最终\n",
+    "\n",
+    "    在调用`get_metric`方法时,`Accuracy`类能够返回全部的统计结果,代码如下\n",
+    "    \n",
+    "```python\n",
+    "trainer = Trainer(\n",
+    "        model=model,                                # model 基于 pytorch 实现 \n",
+    "        train_dataloader=train_dataloader,\n",
+    "        optimizers=optimizer,\n",
+    "        ...\n",
+    "        driver='torch',                             # driver 使用 torch_driver \n",
+    "        device=[0, 1],                              # gpu 选择 cuda:0 + cuda:1\n",
+    "        ...\n",
+    "        evaluate_dataloaders=evaluate_dataloader,\n",
+    "        metrics={'acc': Accuracy()},\n",
+    "        ...\n",
+    "    )\n",
+    "\n",
+    "class Accuracy(Metric):\n",
+    "    def __init__(self):\n",
+    "        super().__init__()\n",
+    "        self.register_element(name='total', value=0, aggregate_method='sum')\n",
+    "        self.register_element(name='right', value=0, aggregate_method='sum')\n",
+    "```\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e2e0a210",
+   "metadata": {
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   },
+   "source": [
+    "注:`fastNLP v1.0`中要求`jupyter`不能多卡,仅能单卡,故在所有`tutorial`中均不作相关演示"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8d19220c",
+   "metadata": {},
+   "source": [
+    "## 2. fastNLP 中的更多 metric 类型\n",
+    "\n",
+    "### 2.1  预定义的 metric 类型\n",
+    "\n",
+    "在`fastNLP 1.0`中,除了前几篇`tutorial`中经常见到的**正确率 Accuracy**,还有其他**预定义的评测标准 metric**\n",
+    "\n",
+    "  包括**所有 metric 的基类 Metric**、适配`Transformers`中相关模型的正确率`TransformersAccuracy`\n",
+    "\n",
+    "    **适用于分类语境下的 F1 值 ClassifyFPreRecMetric**(其中也包括召回率`Pre`、精确率`Rec`\n",
+    "\n",
+    "    **适用于抽取语境下的 F1 值 SpanFPreRecMetric**;相关基本信息内容见下表,之后是详细分析\n",
+    "\n",
+    "代码名称|简要介绍|代码路径\n",
+    "----|----|----|\n",
+    " `Metric` | 定义`metrics`时继承的基类 | `/core/metrics/metric.py` |\n",
+    " `Accuracy` | 正确率,最为常用 | `/core/metrics/accuracy.py` |\n",
+    " `TransformersAccuracy` | 正确率,为了兼容`Transformers`中相关模型 | `/core/metrics/accuracy.py` |\n",
+    " `ClassifyFPreRecMetric` | 召回率、精确率、F1值,适用于**分类问题** | `/core/metrics/classify_f1_pre_rec_metric.py` |\n",
+    " `SpanFPreRecMetric` | 召回率、精确率、F1值,适用于**抽取问题** | `/core/metrics/span_f1_pre_rec_metric.py` |"
+   ]
+  },
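+  {
+   "cell_type": "markdown",
+   "id": "b2a6e8d4",
+   "metadata": {},
+   "source": [
+    "All of these plug into `Trainer` or `Evaluator` through the `metrics` dict; a minimal sketch (the constructor arguments of `ClassifyFPreRecMetric` are omitted here, see 2.1.2 below):\n",
+    "\n",
+    "```python\n",
+    "from fastNLP import Accuracy, ClassifyFPreRecMetric\n",
+    "\n",
+    "# the dict keys become the 'xx' suffix in the printed results, e.g. 'acc#acc'\n",
+    "metrics = {'acc': Accuracy(), 'F1': ClassifyFPreRecMetric()}\n",
+    "```"
+   ]
+  },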
+  {
+   "cell_type": "markdown",
+   "id": "fdc083a3",
+   "metadata": {
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   },
+   "source": [
+    "  如`tutorial-0`中所述,所有的`metric`都包含`get_metric`和`update`函数,其中\n",
+    "\n",
+    "    **update 函数更新单个 batch 的统计量**,**get_metric 函数返回最终结果**,并打印显示\n",
+    "\n",
+    "\n",
+    "### 2.1.1  Accuracy 与 TransformersAccuracy\n",
+    "\n",
+    "`Accuracy`,正确率,预测正确的数据`right_num`在总数据`total_num`,中的占比(公式就不用列了\n",
+    "\n",
+    "  `get_metric`函数打印格式为 **{\"acc#xx\": float, 'total#xx': float, 'correct#xx': float}**\n",
+    "\n",
+    "  一般在初始化时不需要传参,`fastNLP`会根据`update`函数的传入参数确定对应后台框架`backend`\n",
+    "\n",
+    "  **update 函数的参数包括 pred 、 target 、 seq_len**,**后者用来标记批次中每笔数据的长度**\n",
+    "\n",
+    "`TransformersAccuracy`,继承自`Accuracy`,只是为了兼容`Transformers`框架中相关模型\n",
+    "\n",
+    "  在`update`函数中,将`Transformers`框架输出的`attention_mask`参数转化为`seq_len`参数\n",
+    "\n",
+    "\n",
+    "### 2.1.2  ClassifyFPreRecMetric 与 SpanFPreRecMetric\n",
+    "\n",
+    "`ClassifyFPreRecMetric`,分类评价,`SpanFPreRecMetric`,抽取评价,后者在`tutorial-4`中已出现\n",
+    "\n",
+    "  两者的相同之处在于:**第一**,**都包括召回率/查全率 ec**、**精确率/查准率 Pre**、**F1 值**这三个指标\n",
+    "\n",
+    "    `get_metric`函数打印格式为 **{\"f#xx\": float, 'pre#xx': float, 'rec#xx': float}**\n",
+    "\n",
+    "    三者的计算公式如下,其中`beta`默认为`1`,即`F1`值是召回率`Rec`和精确率`Pre`的调和平均数\n",
+    "\n",
+    "$$\\text{召回率}\\ Rec=\\dfrac{\\text{正确预测为正例的数量}}{\\text{所有本来是正例的数量}}\\qquad \\text{精确率}\\ Pre=\\dfrac{\\text{正确预测为正例的数量}}{\\text{所有预测为正例的数量}}$$\n",
+    "\n",
+    "$$F_{beta} = \\frac{(1 + {beta}^{2})*(Pre*Rec)}{({beta}^{2}*Pre + Rec)}$$\n",
+    "\n",
+    "  **第二**,可以通过参数`only_gross`为`False`,要求返回所有类别的`Rec-Pre-F1`,同时`F1`值又根据参数`f_type`又分为\n",
+    "\n",
+    "    **micro F1**(**直接统计所有类别的 Rec-Pre-F1**)、**macro F1**(**统计各类别的 Rec-Pre-F1 再算术平均**)\n",
+    "\n",
+    "  **第三**,两者在初始化时还可以**传入基于 fastNLP.Vocabulary 的 tag_vocab 参数记录数据集中的标签序号**\n",
+    "\n",
+    "    **与标签名称之间的映射**,通过字符串列表`ignore_labels`参数,指定若干标签不用于`Rec-Pre-F1`的计算\n",
+    "\n",
+    "两者的不同之处在于:`ClassifyFPreRecMetric`针对简单的分类问题,每个分类标签之间彼此独立,不构成标签对\n",
+    "\n",
+    "    **SpanFPreRecMetric 针对更复杂的抽取问题**,**规定标签 B-xx 和 I-xx 或 B-xx 和 E-xx 构成标签对**\n",
+    "\n",
+    "  在计算`Rec-Pre-F1`时,`ClassifyFPreRecMetric`只需要考虑标签本身是否正确这就足够了,但是\n",
+    "\n",
+    "    对于`SpanFPreRecMetric`,需要保证**标签符合规则且覆盖的区间与正确结果重合才算正确**\n",
+    "\n",
+    "    因此回到`tutorial-4`中`CoNLL-2003`的`NER`任务,如果评测方法选择`ClassifyFPreRecMetric`\n",
+    "\n",
+    "      或者`Accuracy`,会发现虽然评测结果显示很高,这是因为选择的评测方法要求太低\n",
+    "\n",
+    "    最后通过`CoNLL-2003`的词性标注`POS`任务简单演示下`ClassifyFPreRecMetric`相关的使用\n",
+    "\n",
+    "```python\n",
+    "from fastNLP import Vocabulary\n",
+    "from fastNLP import ClassifyFPreRecMetric\n",
+    "\n",
+    "tag_vocab = Vocabulary(padding=None, unknown=None)            # 记录序号与标签之间的映射\n",
+    "tag_vocab.add_word_lst(['\"', \"''\", '#', '$', '(', ')', ',', '.', ':', '``', \n",
+    "                        'CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', \n",
+    "                        'MD', 'NN', 'NNP', 'NNPS', 'NNS', 'NN|SYM', 'PDT', 'POS', 'PRP', 'PRP$', \n",
+    "                        'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', \n",
+    "                        'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP+', 'WRB', ])  # CoNLL-2003 中的 pos_tags\n",
+    "ignore_labels = ['\"', \"''\", '#', '$', '(', ')', ',', '.', ':', '``', ]\n",
+    "\n",
+    "FPreRec = ClassifyFPreRecMetric(tag_vocab=tag_vocab,          \n",
+    "                                ignore_labels=ignore_labels,  # 表示评测/优化中不考虑上述标签的正误/损失\n",
+    "                                only_gross=True,              # 默认为 True 表示输出所有类别的综合统计结果\n",
+    "                                f_type='micro')               # 默认为 'micro' 表示统计所有类别的 Rec-Pre-F1\n",
+    "metrics = {'F1': FPreRec}\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8a22f522",
+   "metadata": {},
+   "source": [
+    "### 2.2  自定义的 metric 类型\n",
+    "\n",
+    "如上文所述,`Metric`作为所有`metric`的基类,`Accuracy`等都是其子类,同样地,对于**自定义的 metric 类型**\n",
+    "\n",
+    "    也**需要继承自 Metric 类**,同时**内部自定义好 __init__ 、 update 和 get_metric 函数**\n",
+    "\n",
+    "  在`__init__`函数中,根据需求定义评测时需要用到的变量,此处沿用`Accuracy`中的`total_num`和`right_num`\n",
+    "\n",
+    "  在`update`函数中,根据需求定义评测变量的更新方式,需要注意的是如`tutorial-0`中所述,**update`的参数名**\n",
+    "\n",
+    "    **需要待评估模型在 evaluate_step 中的输出名称一致**,由此**和数据集中对应字段名称一致**,即**参数匹配**\n",
+    "\n",
+    "    在`fastNLP v1.0`中,`update`函数的默认输入参数:`pred`,对应预测值;`target`,对应真实值\n",
+    "\n",
+    "    此处仍然沿用,因为接下来会需要使用`fastNLP`函数的与定义模型,其输入参数格式即使如此\n",
+    "\n",
+    "  在`get_metric`函数中,根据需求定义评测指标最终的计算,此处直接计算准确率,该函数必须返回一个字典\n",
+    "\n",
+    "    其中,字串`'prefix'`表示该`metric`的名称,会对应显示到`trainer`的`progress bar`中\n",
+    "\n",
+    "根据上述要求,这里简单定义了一个名为`MyMetric`的评测模块,用于分类问题的评测,以此展开一个实例展示"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "08a872e9",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import sys\n", + "sys.path.append('..')\n", + "\n", + "from fastNLP import Metric\n", + "\n", + "class MyMetric(Metric):\n", + "\n", + " def __init__(self):\n", + " Metric.__init__(self)\n", + " self.total_num = 0\n", + " self.right_num = 0\n", + "\n", + " def update(self, pred, target):\n", + " self.total_num += target.size(0)\n", + " self.right_num += target.eq(pred).sum().item()\n", + "\n", + " def get_metric(self, reset=True):\n", + " acc = self.right_num / self.total_num\n", + " if reset:\n", + " self.total_num = 0\n", + " self.right_num = 0\n", + " return {'prefix': acc}" + ] + }, + { + "cell_type": "markdown", + "id": "0155f447", + "metadata": {}, + "source": [ + "  数据使用方面,此处仍然使用`datasets`模块中的`load_dataset`函数,加载`SST-2`二分类数据集" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "5ad81ac7", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ef923b90b19847f4916cccda5d33fc36", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00 0: # 如果设置了 num_eval_sanity_batch\n", + "\t\ton_sanity_check_begin(trainer)\n", + "\t\ton_sanity_check_end(trainer, sanity_check_res)\n", + "\ttry:\n", + "\t\ton_train_begin(trainer)\n", + "\t\twhile cur_epoch_idx < n_epochs:\n", + "\t\t\ton_train_epoch_begin(trainer)\n", + "\t\t\twhile batch_idx_in_epoch<=num_batches_per_epoch:\n", + "\t\t\t\ton_fetch_data_begin(trainer)\n", + "\t\t\t\tbatch = next(dataloader)\n", + "\t\t\t\ton_fetch_data_end(trainer)\n", + "\t\t\t\ton_train_batch_begin(trainer, batch, indices)\n", + "\t\t\t\ton_before_backward(trainer, outputs) # 其中 outputs 是经过 output_mapping 后的\n", + "\t\t\t\ton_after_backward(trainer)\n", + "\t\t\t\ton_before_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响\n", + "\t\t\t\ton_after_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响\n", + "\t\t\t\ton_before_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响\n", + "\t\t\t\ton_after_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响\n", + "\t\t\t\ton_train_batch_end(trainer)\n", + "\t\t\ton_train_epoch_end(trainer)\n", + "\texcept BaseException:\n", + "\t\tself.on_exception(trainer, exception)\n", + "\tfinally:\n", + "\t\ton_train_end(trainer)\n", + "``` -->" + ] + }, + { + "cell_type": "markdown", + "id": "1e21df35", + "metadata": {}, + "source": [ + "紧接着,初始化`trainer`实例,继续完成`SST-2`分类,其中`metrics`输入的键值对,字串`'suffix'`和之前定义的\n", + "\n", + "  字串`'prefix'`将拼接在一起显示到`progress bar`中,故完整的输出形式为`{'prefix#suffix': float}`" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "926a9c50", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import Trainer\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver='torch',\n", + " device=0, # 'cuda'\n", + " n_epochs=10,\n", + " optimizers=optimizers,\n", + " input_mapping=input_mapping,\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=evaluate_dataloader,\n", + " metrics={'suffix': MyMetric()}\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "b1b2e8b7", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + 
"source": [ + "最后就是`run`函数的使用,关于其参数,这里也以表格形式列出,由此就解答了`num_eval_batch_per_dl=10`的含义\n", + "\n", + "|名称|功能|默认值|\n", + "|----|----|----|\n", + "| `num_train_batch_per_epoch` | 指定`trainer`训练时,每个循环计算批量数目 | 整数类型,默认`-1`,表示训练时,每个循环计算所有批量 |\n", + "| `num_eval_batch_per_dl` | 指定`trainer`评测时,每个循环计算批量数目 | 整数类型,默认`-1`,表示评测时,每个循环计算所有批量 |\n", + "| `num_eval_sanity_batch` | 指定`trainer`训练开始前,试探性评测批量数目 | 整数类型,默认`2`,表示训练开始前评估两个批量 |\n", + "| `resume_from` | 指定`trainer`恢复状态的路径,需要是文件夹 | 字符串型,默认`None`,使用可参考`CheckpointCallback` |\n", + "| `resume_training` | 指定`trainer`恢复状态的程度 | 布尔类型,默认`True`恢复所有状态,`False`仅恢复`model`和`optimizers`状态 |" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "43be274f", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
[09:30:35] INFO     Running evaluator sanity check for 2 batches.              trainer.py:596\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[09:30:35]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=954293;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=366534;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+       "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+       ".get_parent()\n",
+       "  if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n",
+       "
\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+       "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+       ".get_parent()\n",
+       "  self.msg_id = ip.kernel._parent_header['header']['msg_id']\n",
+       "
\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.6875\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.6875\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.8125\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.80625\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.825\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.825\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.8125\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.80625\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.80625\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.8\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.80625\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.80625\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1abfa0a", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/tutorials/fastnlp_tutorial_6.ipynb b/docs/source/tutorials/fastnlp_tutorial_6.ipynb new file mode 100644 index 00000000..63f7481e --- /dev/null +++ b/docs/source/tutorials/fastnlp_tutorial_6.ipynb @@ -0,0 +1,1646 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fdd7ff16", + "metadata": {}, + "source": [ + "# T6. fastNLP 与 paddle 或 jittor 的结合\n", + "\n", + "  1   fastNLP 结合 paddle 训练模型\n", + " \n", + "    1.1   关于 paddle 的简单介绍\n", + "\n", + "    1.2   使用 paddle 搭建并训练模型\n", + "\n", + "  2   fastNLP 结合 jittor 训练模型\n", + "\n", + "    2.1   关于 jittor 的简单介绍\n", + "\n", + "    2.2   使用 jittor 搭建并训练模型\n", + "\n", + "" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "08752c5a", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6b13d42c39ba455eb370bf2caaa3a264", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/6000 [00:00 True\n" + ] + } + ], + "source": [ + "import sys\n", + "sys.path.append('..')\n", + "\n", + "from fastNLP import DataSet\n", + "\n", + "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n", + "\n", + "dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'target': ins['label']}, \n", + " progress_bar=\"tqdm\")\n", + "dataset.delete_field('sentence')\n", + "dataset.delete_field('label')\n", + "dataset.delete_field('idx')\n", + "\n", + "from fastNLP import Vocabulary\n", + "\n", + "vocab = Vocabulary()\n", + "vocab.from_dataset(dataset, field_name='words')\n", + "vocab.index_dataset(dataset, field_name='words')\n", + "\n", + "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)\n", + "print(type(train_dataset), isinstance(train_dataset, DataSet))\n", + "\n", + "from fastNLP.io import DataBundle\n", + "\n", + "data_bundle = DataBundle(datasets={'train': train_dataset, 'dev': evaluate_dataset})" + ] + }, + { + "cell_type": "markdown", + "id": "57a3272f", + "metadata": {}, + "source": [ + "## 1. 
fastNLP 结合 paddle 训练模型\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "e31b3198", + "metadata": {}, + "outputs": [], + "source": [ + "import paddle\n", + "import paddle.nn as nn\n", + "import paddle.nn.functional as F\n", + "\n", + "\n", + "class ClsByPaddle(nn.Layer):\n", + " def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, dropout=0.5):\n", + " nn.Layer.__init__(self)\n", + " self.hidden_dim = hidden_dim\n", + "\n", + " self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)\n", + " \n", + " self.conv1 = nn.Sequential(nn.Conv1D(embedding_dim, 30, 1, padding=0), nn.ReLU())\n", + " self.conv2 = nn.Sequential(nn.Conv1D(embedding_dim, 40, 3, padding=1), nn.ReLU())\n", + " self.conv3 = nn.Sequential(nn.Conv1D(embedding_dim, 50, 5, padding=2), nn.ReLU())\n", + "\n", + " self.mlp = nn.Sequential(('dropout', nn.Dropout(p=dropout)),\n", + " ('linear_1', nn.Linear(120, hidden_dim)),\n", + " ('activate', nn.ReLU()),\n", + " ('linear_2', nn.Linear(hidden_dim, output_dim)))\n", + " \n", + " self.loss_fn = nn.MSELoss()\n", + "\n", + " def forward(self, words):\n", + " output = self.embedding(words).transpose([0, 2, 1])\n", + " conv1, conv2, conv3 = self.conv1(output), self.conv2(output), self.conv3(output)\n", + "\n", + " pool1 = F.max_pool1d(conv1, conv1.shape[-1]).squeeze(2)\n", + " pool2 = F.max_pool1d(conv2, conv2.shape[-1]).squeeze(2)\n", + " pool3 = F.max_pool1d(conv3, conv3.shape[-1]).squeeze(2)\n", + "\n", + " pool = paddle.concat([pool1, pool2, pool3], axis=1)\n", + " output = self.mlp(pool)\n", + " return output\n", + " \n", + " def train_step(self, words, target):\n", + " pred = self(words)\n", + " target = paddle.stack((1 - target, target), axis=1).cast(pred.dtype)\n", + " return {'loss': self.loss_fn(pred, target)}\n", + "\n", + " def evaluate_step(self, words, target):\n", + " pred = self(words)\n", + " pred = paddle.argmax(pred, axis=-1)\n", + " return {'pred': pred, 'target': target}" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "c63b030f", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "W0604 21:02:25.453869 19014 gpu_context.cc:278] Please NOTE: device: 0, GPU Compute Capability: 6.1, Driver API Version: 11.1, Runtime API Version: 10.2\n", + "W0604 21:02:26.061690 19014 gpu_context.cc:306] device: 0, cuDNN Version: 7.6.\n" + ] + }, + { + "data": { + "text/plain": [ + "ClsByPaddle(\n", + " (embedding): Embedding(8458, 100, sparse=False)\n", + " (conv1): Sequential(\n", + " (0): Conv1D(100, 30, kernel_size=[1], data_format=NCL)\n", + " (1): ReLU()\n", + " )\n", + " (conv2): Sequential(\n", + " (0): Conv1D(100, 40, kernel_size=[3], padding=1, data_format=NCL)\n", + " (1): ReLU()\n", + " )\n", + " (conv3): Sequential(\n", + " (0): Conv1D(100, 50, kernel_size=[5], padding=2, data_format=NCL)\n", + " (1): ReLU()\n", + " )\n", + " (mlp): Sequential(\n", + " (dropout): Dropout(p=0.5, axis=None, mode=upscale_in_train)\n", + " (linear_1): Linear(in_features=120, out_features=64, dtype=float32)\n", + " (activate): ReLU()\n", + " (linear_2): Linear(in_features=64, out_features=2, dtype=float32)\n", + " )\n", + " (loss_fn): MSELoss()\n", + ")" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = ClsByPaddle(vocab_size=len(vocab), embedding_dim=100, output_dim=2)\n", + "\n", + "model" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "2997c0aa", + 
"metadata": {}, + "outputs": [], + "source": [ + "from paddle.optimizer import AdamW\n", + "\n", + "optimizers = AdamW(parameters=model.parameters(), learning_rate=5e-4)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "ead35fb8", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import prepare_paddle_dataloader\n", + "\n", + "train_dataloader = prepare_paddle_dataloader(train_dataset, batch_size=16, shuffle=True)\n", + "evaluate_dataloader = prepare_paddle_dataloader(evaluate_dataset, batch_size=16)\n", + "\n", + "# dl_bundle = prepare_paddle_dataloader(data_bundle, batch_size=16, shuffle=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "25e8da83", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import Trainer, Accuracy\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver='paddle',\n", + " device='gpu', # 'cpu', 'gpu', 'gpu:x'\n", + " n_epochs=10,\n", + " optimizers=optimizers,\n", + " train_dataloader=train_dataloader, # dl_bundle['train'],\n", + " evaluate_dataloaders=evaluate_dataloader, # dl_bundle['dev'], \n", + " metrics={'acc': Accuracy()}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "d63c5d74", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[21:03:08] INFO     Running evaluator sanity check for 2 batches.              trainer.py:596\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[21:03:08]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=894986;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=567751;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+       "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+       ".get_parent()\n",
+       "  if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n",
+       "
\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+       "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+       ".get_parent()\n",
+       "  self.msg_id = ip.kernel._parent_header['header']['msg_id']\n",
+       "
\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/paddle/tensor/creation.py:\n",
+       "125: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To \n",
+       "silence this warning, use `object` by itself. Doing this will not modify any behavior and is \n",
+       "safe. \n",
+       "Deprecated in NumPy 1.20; for more details and guidance: \n",
+       "https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
+       "  if data.dtype == np.object:\n",
+       "
\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/paddle/tensor/creation.py:\n", + "125: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To \n", + "silence this warning, use `object` by itself. Doing this will not modify any behavior and is \n", + "safe. \n", + "Deprecated in NumPy 1.20; for more details and guidance: \n", + "https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", + " if data.dtype == np.object:\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.78125,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 125.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.78125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m125.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.7875,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 126.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m126.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.8,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 128.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.79375,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 127.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.79375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m127.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.81875,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 131.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.8,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 128.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.80625,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 129.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m129.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.79375,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 127.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.79375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m127.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.7875,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 126.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m126.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.8,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 128.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10) " + ] + }, + { + "cell_type": "markdown", + "id": "cb9a0b3c", + "metadata": {}, + "source": [ + "## 2. fastNLP 结合 jittor 训练模型" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "c600191d", + "metadata": {}, + "outputs": [], + "source": [ + "import jittor\n", + "import jittor.nn as nn\n", + "\n", + "from jittor import Module\n", + "\n", + "\n", + "class ClsByJittor(Module):\n", + " def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n", + " Module.__init__(self)\n", + " self.hidden_dim = hidden_dim\n", + "\n", + " self.embedding = nn.Embedding(num=vocab_size, dim=embedding_dim)\n", + " self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True, # 默认 batch_first=False\n", + " num_layers=num_layers, bidirectional=True, dropout=dropout)\n", + " self.mlp = nn.Sequential([nn.Dropout(p=dropout),\n", + " nn.Linear(hidden_dim * 2, hidden_dim * 2),\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_dim * 2, output_dim),\n", + " nn.Sigmoid(),])\n", + "\n", + " self.loss_fn = nn.MSELoss()\n", + "\n", + " def execute(self, words):\n", + " output = self.embedding(words)\n", + " output, (hidden, cell) = self.lstm(output)\n", + " output = self.mlp(jittor.concat((hidden[-1], hidden[-2]), dim=1))\n", + " return output\n", + " \n", + " def train_step(self, words, target):\n", + " pred = self(words)\n", + " target = jittor.stack((1 - target, target), dim=1)\n", + " return {'loss': self.loss_fn(pred, target)}\n", + "\n", + " def evaluate_step(self, words, target):\n", + " pred = self(words)\n", + " pred = jittor.argmax(pred, dim=-1)[0]\n", + " return {'pred': pred, 'target': target}" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "a94ed8c4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ClsByJittor(\n", + " embedding: Embedding(8458, 100)\n", + " lstm: LSTM(100, 64, 2, bias=True, batch_first=True, dropout=0.5, bidirectional=True, proj_size=0)\n", + " mlp: Sequential(\n", + " 0: Dropout(0.5, is_train=False)\n", + " 1: Linear(128, 128, float32[128,], None)\n", + " 2: relu()\n", + " 3: Linear(128, 2, float32[2,], None)\n", + " 4: Sigmoid()\n", + " )\n", + " loss_fn: MSELoss(mean)\n", + ")" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = ClsByJittor(vocab_size=len(vocab), embedding_dim=100, output_dim=2)\n", + "\n", + "model" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "6d15ebc1", + "metadata": {}, + "outputs": [], + "source": [ + "from jittor.optim import AdamW\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=5e-3)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "95d8d09e", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import prepare_jittor_dataloader\n", + "\n", + "train_dataloader = prepare_jittor_dataloader(train_dataset, batch_size=16, shuffle=True)\n", + "evaluate_dataloader = prepare_jittor_dataloader(evaluate_dataset, batch_size=16)\n", + "\n", + "# dl_bundle = prepare_jittor_dataloader(data_bundle, batch_size=16, shuffle=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "917eab81", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import Trainer, Accuracy\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " 
driver='jittor',\n", + " device='gpu', # 'cpu', 'gpu', 'cuda'\n", + " n_epochs=10,\n", + " optimizers=optimizers,\n", + " train_dataloader=train_dataloader, # dl_bundle['train'],\n", + " evaluate_dataloaders=evaluate_dataloader, # dl_bundle['dev'],\n", + " metrics={'acc': Accuracy()}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "f7c4ac5a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[21:05:51] INFO     Running evaluator sanity check for 2 batches.              trainer.py:596\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[21:05:51]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=69759;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=202322;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "Compiling Operators(5/6) used: 8.31s eta: 1.66s 6/6) used: 9.33s eta:    0s \n",
+      "\n",
+      "Compiling Operators(31/31) used: 7.31s eta:    0s \n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.61875,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 99\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.61875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m99\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.7,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 112\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m112\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.725,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 116\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.725\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m116\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.74375,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 119\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.74375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m119\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.75625,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 121\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.75625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m121\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.75625,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 121\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.75625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m121\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.73125,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 117\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.73125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m117\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.7625,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 122\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m122\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.74375,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 119\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.74375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m119\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.7625,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 122\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m122\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3df5f425", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/tutorials/fastnlp_tutorial_e1.ipynb b/docs/source/tutorials/fastnlp_tutorial_e1.ipynb new file mode 100644 index 00000000..af8e60a0 --- /dev/null +++ b/docs/source/tutorials/fastnlp_tutorial_e1.ipynb @@ -0,0 +1,1280 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "  从这篇开始,我们将开启 **fastNLP v1.0 tutorial 的 example 系列**,在接下来的\n", + "\n", + "  每篇`tutorial`里,我们将会介绍`fastNLP v1.0`在自然语言处理任务上的应用实例" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[, , ]\n" + ] + } + ], + "source": [ + "from pygments.plugin import find_plugin_lexers\n", + "print(list(find_plugin_lexers()))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# E1. 使用 Bert + fine-tuning 完成 SST-2 分类\n", + "\n", + "  1   基础介绍:`GLUE`通用语言理解评估、`SST-2`文本情感二分类数据集 \n", + "\n", + "  2   准备工作:加载`tokenizer`、预处理`dataset`、`dataloader`使用\n", + "\n", + "  3   模型训练:加载`distilbert-base`、`fastNLP`参数匹配、`fine-tuning`" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4.18.0\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from torch.optim import AdamW\n", + "from torch.utils.data import DataLoader, Dataset\n", + "\n", + "import transformers\n", + "from transformers import AutoTokenizer\n", + "from transformers import AutoModelForSequenceClassification\n", + "\n", + "import sys\n", + "sys.path.append('..')\n", + "\n", + "import fastNLP\n", + "from fastNLP import Trainer\n", + "from fastNLP import Accuracy\n", + "\n", + "print(transformers.__version__)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. 基础介绍:GLUE 通用语言理解评估、SST-2 文本情感二分类数据集\n", + "\n", + "  本示例使用`GLUE`评估基准中的`SST-2`数据集,通过`fine-tuning`方式\n", + "\n", + "    调整`distilbert-bert`分类模型,以下首先简单介绍下`GLUE`和`SST-2`\n", + "\n", + "**GLUE**,**全称 General Language Understanding Evaluation**,**通用语言理解评估**,\n", + "\n", + "  包含9个数据集,各语料的语言均为英语,涉及多个自然语言理解`NLU`任务,包括\n", + "\n", + "    **CoLA**,文本分类任务,预测单句语法正误分类;**SST-2**,文本分类任务,预测单句情感二分类\n", + "\n", + "    **MRPC**,句对分类任务,预测句对语义一致性;**STS-B**,相似度打分任务,预测句对语义相似度回归\n", + "\n", + "    **QQP**,句对分类任务,预测问题对语义一致性;**MNLI**,文本推理任务,预测句对蕴含/矛盾/中立预测\n", + "\n", + "    **QNLI / RTE / WNLI**,文本推理,预测是否蕴含二分类(其中,`QNLI`从`SQuAD`转化而来\n", + "\n", + "  诸如`BERT`、`T5`等经典模型都会在此基准上验证效果,更多参考[GLUE论文](https://arxiv.org/pdf/1804.07461v3.pdf)\n", + "\n", + "    此处,我们使用`SST-2`来训练`bert`,实现文本分类,其他任务描述见下图" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "GLUE_TASKS = ['cola', 'mnli', 'mrpc', 'qnli', 'qqp', 'rte', 'sst2', 'stsb', 'wnli']\n", + "\n", + "task = 'sst2'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "\n", + "**SST**,**全称`Stanford Sentiment Treebank**,**斯坦福情感树库**,**单句情感分类**数据集\n", + "\n", + "  包含电影评论语句和对应的情感极性,1 对应`positive` 正面情感,0 对应`negative` 负面情感\n", + "\n", + "  数据集包括三部分:训练集 67350 条,验证集 873 条,测试集 1821 条,更多参考[下载链接](https://gluebenchmark.com/tasks)\n", + "\n", + "对应到代码上,此处使用`datasets`模块中的`load_dataset`函数,指定`SST-2`数据集,自动加载\n", + "\n", + "  首次下载后会保存至`~/.cache/huggingface/modules/datasets_modules/datasets/glue/`目录下" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c5915debacf9443986b5b3b34870b303", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00[09:12:45] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[09:12:45]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. 
\u001b]8;id=408427;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=303634;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.884375,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 283.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.878125,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 281.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.884375,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 283.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.9,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 288.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.9\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m288.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.8875,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 284.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.88125,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 282.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.88125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m282.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.875,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 280.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m280.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.865625,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 277.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.865625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m277.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.884375,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 283.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.878125,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 281.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'acc#acc': 0.884174, 'total#acc': 872.0, 'correct#acc': 771.0}"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "trainer.evaluator.run()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### 附:`DistilBertForSequenceClassification`模块结构\n",
+    "\n",
+    "```\n",
+    "\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3.7.13 ('fnlp-paddle')",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.13"
+  },
+  "pycharm": {
+   "stem_cell": {
+    "cell_type": "raw",
+    "metadata": {
+     "collapsed": false
+    },
+    "source": []
+   }
+  },
+  "vscode": {
+   "interpreter": {
+    "hash": "31f2d9d3efc23c441973d7c4273acfea8b132b6a578f002629b6b44b8f65e720"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
diff --git a/docs/source/tutorials/fastnlp_tutorial_e2.ipynb b/docs/source/tutorials/fastnlp_tutorial_e2.ipynb
new file mode 100644
index 00000000..588ee8c3
--- /dev/null
+++ b/docs/source/tutorials/fastnlp_tutorial_e2.ipynb
@@ -0,0 +1,1082 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# E2. 使用 Bert + prompt 完成 SST-2 分类\n",
+    "\n",
+    "  1   基础介绍:`prompt-based model`简介、与`fastNLP`的结合\n",
+    "\n",
+    "  2   准备工作:`P-Tuning v2`原理概述、`P-Tuning v2`模型搭建\n",
+    "\n",
+    "  3   模型训练:加载`tokenizer`、预处理`dataset`、模型训练与分析"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 1. 基础介绍:prompt-based model 简介、与 fastNLP 的结合\n",
+    "\n",
+    "  本示例使用`GLUE`评估基准中的`SST-2`数据集,通过`prompt-based tuning`方式\n",
+    "\n",
+    "    微调`bert-base-uncased`模型,实现文本情感的二分类,在此之前本示例\n",
+    "\n",
+    "    将首先简单介绍提示学习模型的研究,以及与`fastNLP v1.0`结合的优势\n",
+    "\n",
+    "**prompt**,**提示词**,最早出自论文[Exploiting Cloze Questions for Few Shot TC and NLI](https://arxiv.org/pdf/2001.07676.pdf)中的 **PET 模型**\n",
+    "\n",
+    "    全称 **Pattern-Exploiting Training**,虽然文中并没有提到`prompt`的说法,但仍被视为开山之作\n",
+    "\n",
+    "  其大致思路包括,对于文本分类任务,假定输入文本为`\" X . \"`,设计**输入模板 template**,**后来被称为 prompt**\n",
+    "\n",
+    "    将输入重构为`\" X . It is [MASK] . \"`,**诱导或刺激语言模型在 [MASK] 位置生成含有情感倾向的词汇**\n",
+    "\n",
+    "    接着将该词汇**输入分类器中**,**后来被称为 verbalizer**,从而得到该语句对应的情感倾向,实现文本分类\n",
+    "\n",
+    "  其主要贡献在于,通过构造`prompt`,诱导/刺激预训练模型生成期望适应下游任务特征,适合少样本学习的需求\n",
+    "\n",
+    "\n",
+    "\n",
+    "**prompt-based tuning**,**基于提示的微调**,将`prompt`应用于**参数高效微调**,**parameter-efficient tuning**\n",
+    "\n",
+    "  通过**设计模板调整模型输入**或者**调整模型内部状态**,**固定预训练模型**,**诱导/刺激模型**调整输出以适应\n",
+    "\n",
+    "  当前任务,极大降低了训练开销,也省去了`verbalizer`的构造,更多参考[prompt综述](https://arxiv.org/pdf/2107.13586.pdf)、[DeltaTuning综述](https://arxiv.org/pdf/2203.06904.pdf)\n",
+    "\n",
+    "    以下列举些经典的`prompt-based tuning`案例,简单地介绍下`prompt-based tuning`的脉络\n",
+    "\n",
+    "  **案例一**:**PrefixTuning**,详细内容参考[PrefixTuning论文](https://arxiv.org/pdf/2101.00190.pdf)\n",
+    "\n",
+    "    其主要贡献在于,**提出连续的、非人工构造的、任务导向的 prompt**,即**前缀 prefix**,**调整**\n",
+    "\n",
+    "      **模型内部更新状态**,诱导模型在特定任务下生成期望目标,降低优化难度,提升微调效果\n",
+    "\n",
+    "    其主要研究对象,是`GPT2`和`BART`,主要面向生成任务`NLG`,如`table-to-text`和摘要\n",
+    "\n",
+    "  **案例二**:**P-Tuning v1**,详细内容参考[P-Tuning-v1论文](https://arxiv.org/pdf/2103.10385.pdf)\n",
+    "\n",
+    "    其主要贡献在于,**通过连续的、非人工构造的 prompt 调整模型输入**,取代原先基于单词设计的\n",
+    "\n",
+    "      但离散且不易于优化的`prompt`;同时也**证明了 GPT2 在语言理解任务上仍然是可以胜任的**\n",
+    "\n",
+    "    其主要研究对象,是`GPT2`,主要面向知识探测`knowledge probing`和自然语言理解`NLU`\n",
+    "\n",
+    "  **案例三**:**PromptTuning**,详细内容参考[PromptTuning论文](https://arxiv.org/pdf/2104.08691.pdf)\n",
+    "\n",
+    "    其主要贡献在于,通过连续的`prompt`调整模型输入,**证明了 prompt-based tuning 的效果**\n",
+    "\n",
+    "      **随模型参数量的增加而提升**,最终**在 10B 左右追上了全参数微调 fine-tuning 的效果**\n",
+    "\n",
+    "    其主要面向自然语言理解`NLU`,通过为每个任务定义不同的`prompt`,从而支持多任务语境\n",
+    "\n",
+    "通过上述介绍可以发现`prompt-based tuning`只是模型微调方式,独立于预训练模型基础`backbone`\n",
+    "\n",
+    "  目前,加载预训练模型的主流方法是使用**transformers 模块**,而实现微调的框架则\n",
+    "\n",
+    "    可以是`pytorch`、`paddle`、`jittor`等,而不同框架间又存在不兼容的问题\n",
+    "\n",
+    "  因此,**使用 fastNLP v1.0 实现 prompt-based tuning**,可以**很好地解决 paddle 等框架**\n",
+    "\n",
+    "    **和 transformers 模块之间的桥接**(`transformers`模块基于`pytorch`实现)\n",
+    "\n",
+    "本示例仍使用了`tutorial-E1`的`SST-2`数据集、`distilbert-base-uncased`模型(便于比较\n",
+    "\n",
+    "  使用`pytorch`框架,通过将连续的`prompt`与`model`拼接,解决`SST-2`二分类任务"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4.18.0\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from torch.optim import AdamW\n", + "from torch.utils.data import DataLoader, Dataset\n", + "\n", + "import transformers\n", + "from transformers import AutoTokenizer\n", + "from transformers import AutoModelForSequenceClassification\n", + "\n", + "import sys\n", + "sys.path.append('..')\n", + "\n", + "import fastNLP\n", + "from fastNLP import Trainer\n", + "from fastNLP.core.metrics import Accuracy\n", + "\n", + "print(transformers.__version__)\n", + "\n", + "task = 'sst2'\n", + "model_checkpoint = 'distilbert-base-uncased' # 'bert-base-uncased'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. 准备工作:P-Tuning v2 原理概述、P-Tuning v2 模型搭建\n", + "\n", + "  本示例使用`P-Tuning v2`作为`prompt-based tuning`与`fastNLP v1.0`结合的案例\n", + "\n", + "    以下首先简述`P-Tuning v2`的论文原理,并由此引出`fastNLP v1.0`的代码实践\n", + "\n", + "**P-Tuning v2**出自论文[Prompt Tuning Can Be Comparable to Fine-tuning Universally Across Scales and Tasks](https://arxiv.org/pdf/2110.07602.pdf)\n", + "\n", + "  其主要贡献在于,**在 PrefixTuning 等深度提示学习基础上**,**提升了其在分类标注等 NLU 任务的表现**\n", + "\n", + "    并使之在中等规模模型,主要是**参数量在 100M-1B 区间的模型上**,**获得与全参数微调相同的效果**\n", + "\n", + "  其结构如图所示,通过**在输入序列的分类符 [CLS] 之前**,**加入前缀序列**(**序号对应嵌入是待训练的连续值向量**\n", + "\n", + "    **刺激模型在新任务下**,从`[CLS]`对应位置,**输出符合微调任务的输出**,从而达到适应微调任务的目的\n", + "\n", + "\n", + "\n", + "本示例使用`bert-base-uncased`模型,作为`P-Tuning v2`的基础`backbone`,设置`requires_grad=False`\n", + "\n", + "    固定其参数不参与训练,**设置 pre_seq_len 长的 prefix_tokens 作为输入的提示前缀序列**\n", + "\n", + "  **使用基于 nn.Embedding 的 prefix_encoder 为提示前缀嵌入**,通过`get_prompt`函数获取,再将之\n", + "\n", + "    拼接至批量内每笔数据前得到`inputs_embeds`,同时更新自注意力掩码`attention_mask`\n", + "\n", + "  将`inputs_embeds`、`attention_mask`和`labels`输入`backbone`,**得到输出包括 loss 和 logits**" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "class SeqClsModel(nn.Module):\n", + " def __init__(self, model_checkpoint, num_labels, pre_seq_len):\n", + " nn.Module.__init__(self)\n", + " self.num_labels = num_labels\n", + " self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n", + " num_labels=num_labels)\n", + " self.embeddings = self.back_bone.get_input_embeddings()\n", + "\n", + " for param in self.back_bone.parameters():\n", + " param.requires_grad = False\n", + " \n", + " self.pre_seq_len = pre_seq_len\n", + " self.prefix_tokens = torch.arange(self.pre_seq_len).long()\n", + " self.prefix_encoder = nn.Embedding(self.pre_seq_len, self.embeddings.embedding_dim)\n", + " \n", + " def get_prompt(self, batch_size):\n", + " prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.back_bone.device)\n", + " prompts = self.prefix_encoder(prefix_tokens)\n", + " return prompts\n", + "\n", + " def forward(self, input_ids, attention_mask, labels=None):\n", + " \n", + " batch_size = input_ids.shape[0]\n", + " raw_embedding = self.embeddings(input_ids)\n", + " \n", + " prompts = self.get_prompt(batch_size=batch_size)\n", + " inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)\n", + " prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.back_bone.device)\n", + " attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)\n", + "\n", + " outputs = 
self.back_bone(inputs_embeds=inputs_embeds, \n", + " attention_mask=attention_mask, labels=labels)\n", + " return outputs\n", + "\n", + " def train_step(self, input_ids, attention_mask, labels):\n", + " loss = self(input_ids, attention_mask, labels).loss\n", + " return {'loss': loss}\n", + "\n", + " def evaluate_step(self, input_ids, attention_mask, labels):\n", + " pred = self(input_ids, attention_mask, labels).logits\n", + " pred = torch.max(pred, dim=-1)[1]\n", + " return {'pred': pred, 'target': labels}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "接着,通过确定分类数量初始化模型实例,同时调用`torch.optim.AdamW`模块初始化优化器\n", + "\n", + "  根据`P-Tuning v2`论文:`Generally, simple classification tasks prefer shorter prompts (less than 20)`\n", + "\n", + "  此处`pre_seq_len`参数设定为`20`,学习率相应做出调整,其他内容和`tutorial-E1`中的内容一致" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_projector.bias']\n", + "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'pre_classifier.bias', 'classifier.bias']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + } + ], + "source": [ + "model = SeqClsModel(model_checkpoint=model_checkpoint, num_labels=2, pre_seq_len=20)\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=1e-2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. 模型训练:加载 tokenizer、预处理 dataset、模型训练与分析\n", + "\n", + "  本示例沿用`tutorial-E1`中的数据集,即使用`GLUE`评估基准中的`SST-2`数据集\n", + "\n", + "    以`bert-base-uncased`模型作为基准,基于`P-Tuning v2`方式微调\n", + "\n", + "    数据集加载相关代码流程见下,内容和`tutorial-E1`中的内容基本一致\n", + "\n", + "首先,使用`datasets.load_dataset`加载数据集,使用`transformers.AutoTokenizer`\n", + "\n", + "  构建`tokenizer`实例,通过`dataset.map`使用`tokenizer`将文本替换为词素序号序列" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "21cbd92c3397497d84dc10f017ec96f4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00[22:53:00] INFO Running evaluator sanity check for 2 batches. 
trainer.py:592\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[22:53:00]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=406635;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=951504;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.540625,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 173.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.540625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m173.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.5,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 160.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.5\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.509375,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 163.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.509375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m163.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.634375,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 203.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.634375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m203.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.6125,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 196.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.6125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m196.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.675,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 216.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.675\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m216.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.64375,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 206.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.64375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m206.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.665625,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 213.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.665625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m213.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.659375,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 211.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.659375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m211.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.696875,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 223.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.696875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m223.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "可以发现,其效果远远逊色于`fine-tuning`,这是因为`P-Tuning v2`虽然能够适应参数量\n", + "\n", + "  在`100M-1B`区间的模型,但是,**distilbert-base 的参数量仅为 66M**,无法触及其下限\n", + "\n", + "另一方面,**fastNLP v1.0 不支持 jupyter 多卡**,所以无法在笔者的电脑/服务器上,完成\n", + "\n", + "  合适规模模型的学习,例如`110M`的`bert-base`模型,以及`340M`的`bert-large`模型" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'acc#acc': 0.737385, 'total#acc': 872.0, 'correct#acc': 643.0}"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "trainer.evaluator.run()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.13"
+  },
+  "pycharm": {
+   "stem_cell": {
+    "cell_type": "raw",
+    "metadata": {
+     "collapsed": false
+    },
+    "source": []
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
diff --git a/docs/source/tutorials/fastnlp_tutorial_paddle_e1.ipynb b/docs/source/tutorials/fastnlp_tutorial_paddle_e1.ipynb
new file mode 100644
index 00000000..a5883416
--- /dev/null
+++ b/docs/source/tutorials/fastnlp_tutorial_paddle_e1.ipynb
@@ -0,0 +1,1086 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# E3. 使用 paddlenlp 和 fastNLP 实现中文文本情感分析\n",
+    "\n",
+    "本篇教程属于 **fastNLP v1.0 tutorial 的 paddle examples 系列**。在本篇教程中,我们将为您展示如何使用 `paddlenlp` 自然语言处理库和 `fastNLP` 来完成比较简单的情感分析任务。\n",
+    "\n",
+    "1. 基础介绍:飞桨自然语言处理库 ``paddlenlp`` 和语义理解框架 ``ERNIE``\n",
+    "\n",
+    "2. 准备工作:使用 ``tokenizer`` 处理数据并构造 ``dataloader``\n",
+    "\n",
+    "3. 模型训练:加载 ``ERNIE`` 预训练模型,使用 ``fastNLP`` 进行训练"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 1. 基础介绍:飞桨自然语言处理库 paddlenlp 和语义理解框架 ERNIE\n",
+    "\n",
+    "#### 1.1 飞桨自然语言处理库 paddlenlp\n",
+    "\n",
+    "``paddlenlp`` 是由百度以飞桨 ``PaddlePaddle`` 为核心开发的自然语言处理库,集成了多个数据集和 NLP 模型,包括百度自研的语义理解框架 ``ERNIE`` 。在本篇教程中,我们会以 ``paddlenlp`` 为基础,使用模型 ``ERNIE`` 完成中文情感分析任务。"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "2.3.3\n"
+     ]
+    }
+   ],
+   "source": [
+    "import sys\n",
+    "sys.path.append(\"../\")\n",
+    "\n",
+    "import paddle\n",
+    "import paddlenlp\n",
+    "from paddlenlp.transformers import AutoTokenizer\n",
+    "from paddlenlp.transformers import AutoModelForSequenceClassification\n",
+    "\n",
+    "print(paddlenlp.__version__)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### 1.2 语义理解框架 ERNIE\n",
+    "\n",
+    "``ERNIE(Enhanced Representation from kNowledge IntEgration)`` 是百度提出的基于知识增强的持续学习语义理解框架,至今已有 ``ERNIE 2.0``、``ERNIE 3.0``、``ERNIE-M``、``ERNIE-tiny`` 等多种预训练模型。``ERNIE 1.0`` 采用``Transformer Encoder`` 作为其语义表示的骨架,并改进了两种 ``mask`` 策略,分别为基于**短语**和**实体**(人名、组织等)的策略。在 ``ERNIE`` 中,由多个字组成的短语或者实体将作为一个统一单元,在训练的时候被统一地 ``mask`` 掉,这样可以潜在地学习到知识的依赖以及更长的语义依赖来让模型更具泛化性。\n",
+    "\n",
+    "\n",
+    "\n",
+    "\n",
+    "\n",
+    "``ERNIE 2.0`` 则提出了连续学习(``Continual Learning``)的概念,即首先用一个简单的任务来初始化模型,在更新时用前一个任务训练好的参数作为下一个任务模型初始化的参数。这样在训练新的任务时,模型便可以记住之前学习到的知识,使得模型在新任务上获得更好的表现。``ERNIE 2.0`` 分别构建了词法、语法、语义不同级别的预训练任务,并使用不同的 task id 来标示不同的任务,在共计16个中英文任务上都取得了SOTA效果。\n",
+    "\n",
+    "\n",
+    "\n",
+    "``ERNIE 3.0`` 将自回归和自编码网络融合在一起进行预训练,其中自编码网络采用 ``ERNIE 2.0`` 的多任务学习增量式构建预训练任务,持续进行语义理解学习。其中自编码网络增加了知识增强的预训练任务。自回归网络则基于 ``Tranformer-XL`` 结构,支持长文本语言模型建模,并在多个自然语言处理任务中取得了SOTA的效果。\n",
+    "\n",
+    "\n",
+    "\n",
+    "接下来,我们将展示如何在 ``fastNLP`` 中使用基于 ``paddle`` 的 ``ERNIE 1.0`` 框架进行中文情感分析。"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 2. 使用 tokenizer 处理数据并构造 dataloader\n",
+    "\n",
+    "#### 2.1 加载中文数据集 ChnSentiCorp\n",
+    "\n",
+    "``ChnSentiCorp`` 数据集是由中国科学院发布的中文句子级情感分析数据集,包含了从网络上获取的酒店、电影、书籍等多个领域的评论,每条评论都被划分为两个标签:消极(``0``)和积极(``1``),可以用于二分类的中文情感分析任务。通过 ``paddlenlp.datasets.load_dataset`` 函数,我们可以加载并查看 ``ChnSentiCorp`` 数据集的内容。"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "训练集大小: 9600\n",
+      "{'text': '选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般', 'label': 1, 'qid': ''}\n",
+      "{'text': '15.4寸笔记本的键盘确实爽,基本跟台式机差不多了,蛮喜欢数字小键盘,输数字特方便,样子也很美观,做工也相当不错', 'label': 1, 'qid': ''}\n",
+      "{'text': '房间太小。其他的都一般。。。。。。。。。', 'label': 0, 'qid': ''}\n"
+     ]
+    }
+   ],
+   "source": [
+    "from paddlenlp.datasets import load_dataset\n",
+    "\n",
+    "train_dataset, val_dataset, test_dataset = load_dataset(\"chnsenticorp\", splits=[\"train\", \"dev\", \"test\"])\n",
+    "print(\"训练集大小:\", len(train_dataset))\n",
+    "for i in range(3):\n",
+    "    print(train_dataset[i])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### 2.2 处理数据\n",
+    "\n",
+    "可以看到,原本的数据集仅包含中文的文本和标签,这样的数据是无法被模型识别的。同英文文本分类任务一样,我们需要使用 ``tokenizer`` 对文本进行分词并转换为数字形式的结果。我们可以加载已经预训练好的中文分词模型 ``ernie-1.0-base-zh``,将分词的过程写在函数 ``_process`` 中,然后调用数据集的 ``map`` 函数对每一条数据进行分词。其中:\n",
+    "- 参数 ``max_length`` 代表句子的最大长度;\n",
+    "- ``padding=\"max_length\"`` 表示将长度不足的结果 padding 至和最大长度相同;\n",
+    "- ``truncation=True`` 表示将长度过长的句子进行截断。\n",
+    "\n",
+    "至此,我们得到了每条数据长度均相同的数据集。"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[32m[2022-06-22 21:31:04,168] [    INFO]\u001b[0m - We are using  to load 'ernie-1.0-base-zh'.\u001b[0m\n",
+      "\u001b[32m[2022-06-22 21:31:04,171] [    INFO]\u001b[0m - Already cached /remote-home/shxing/.paddlenlp/models/ernie-1.0-base-zh/vocab.txt\u001b[0m\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'text': '选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般', 'label': 1, 'qid': '', 'input_ids': [1, 352, 790, 1252, 409, 283, 509, 5, 250, 196, 113, 10, 58, 518, 4, 9, 128, 70, 1495, 1855, 339, 293, 45, 302, 233, 554, 4, 544, 637, 1134, 774, 6, 494, 2068, 6, 278, 191, 6, 634, 99, 6, 2678, 144, 7, 149, 1573, 62, 12043, 661, 737, 371, 435, 7, 689, 4, 255, 201, 559, 407, 1308, 12043, 2275, 1110, 11, 19, 842, 5, 1207, 878, 4, 196, 198, 321, 96, 4, 16, 93, 291, 464, 1099, 10, 692, 811, 12043, 392, 5, 748, 1134, 10, 213, 220, 5, 4, 201, 559, 723, 595, 12043, 231, 112, 1114, 4, 7, 689, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}\n"
+     ]
+    }
+   ],
+   "source": [
+    "max_len = 128\n",
+    "model_checkpoint = \"ernie-1.0-base-zh\"\n",
+    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n",
+    "def _process(data):\n",
+    "    data.update(tokenizer(\n",
+    "        data[\"text\"],\n",
+    "        max_length=max_len,\n",
+    "        padding=\"max_length\",\n",
+    "        truncation=True,\n",
+    "        return_attention_mask=True,\n",
+    "    ))\n",
+    "    return data\n",
+    "\n",
+    "train_dataset.map(_process, num_workers=5)\n",
+    "val_dataset.map(_process, num_workers=5)\n",
+    "test_dataset.map(_process, num_workers=5)\n",
+    "\n",
+    "print(train_dataset[0])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "得到数据集之后,我们便可以将数据集包裹在 ``PaddleDataLoader`` 中,用于之后的训练。``fastNLP`` 提供的 ``PaddleDataLoader`` 拓展了 ``paddle.io.DataLoader`` 的功能,详情可以查看相关的文档。"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from fastNLP.core import PaddleDataLoader\n",
+    "import paddle.nn as nn\n",
+    "\n",
+    "train_dataloader = PaddleDataLoader(train_dataset, batch_size=32, shuffle=True)\n",
+    "val_dataloader = PaddleDataLoader(val_dataset, batch_size=32, shuffle=False)\n",
+    "test_dataloader = PaddleDataLoader(test_dataset, batch_size=1, shuffle=False)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 3. 模型训练:加载 ERNIE 预训练模型,使用 fastNLP 进行训练\n",
+    "\n",
+    "#### 3.1 使用 ERNIE 预训练模型\n",
+    "\n",
+    "为了实现文本分类,我们首先需要定义文本分类的模型。``paddlenlp.transformers`` 提供了模型 ``AutoModelForSequenceClassification``,我们可以利用它来加载不同权重的文本分类模型。在 ``fastNLP`` 中,我们可以定义 ``train_step`` 和 ``evaluate_step`` 函数来实现训练和验证过程中的不同行为。\n",
+    "\n",
+    "- ``train_step`` 函数在获得返回值 ``logits`` (大小为 ``(batch_size, num_labels)``)后计算交叉熵损失 ``CrossEntropyLoss``,然后将 ``loss`` 放在字典中返回。``fastNLP`` 也支持返回 ``dataclass`` 类型的训练结果,但二者都需要包含名为 **loss** 的键或成员。\n",
+    "- ``evaluate_step`` 函数在获得返回值 ``logits`` 后,将 ``logits`` 和标签 ``label`` 放在字典中返回。\n",
+    "\n",
+    "这两个函数的参数均为数据集中字典**键**的子集,``fastNLP`` 会自动进行参数匹配然后输入到模型中。"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[32m[2022-06-22 21:31:15,577] [    INFO]\u001b[0m - We are using  to load 'ernie-1.0-base-zh'.\u001b[0m\n",
+      "\u001b[32m[2022-06-22 21:31:15,580] [    INFO]\u001b[0m - Already cached /remote-home/shxing/.paddlenlp/models/ernie-1.0-base-zh/ernie_v1_chn_base.pdparams\u001b[0m\n"
+     ]
+    }
+   ],
+   "source": [
+    "import paddle.nn as nn\n",
+    "\n",
+    "class SeqClsModel(nn.Layer):\n",
+    "    def __init__(self, model_checkpoint, num_labels):\n",
+    "        super(SeqClsModel, self).__init__()\n",
+    "        self.model = AutoModelForSequenceClassification.from_pretrained(\n",
+    "            model_checkpoint,\n",
+    "            num_classes=num_labels,\n",
+    "        )\n",
+    "\n",
+    "    def forward(self, input_ids, attention_mask, token_type_ids):\n",
+    "        logits = self.model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)\n",
+    "        return logits\n",
+    "\n",
+    "    def train_step(self, input_ids, attention_mask, token_type_ids, label):\n",
+    "        logits = self(input_ids, attention_mask, token_type_ids)\n",
+    "        loss = nn.CrossEntropyLoss()(logits, label)\n",
+    "        return {\"loss\": loss}\n",
+    "\n",
+    "    def evaluate_step(self, input_ids, attention_mask, token_type_ids, label):\n",
+    "        logits = self(input_ids, attention_mask, token_type_ids)\n",
+    "        return {'pred': logits, 'target': label}\n",
+    "\n",
+    "model = SeqClsModel(model_checkpoint, num_labels=2)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### 3.2 设置参数并使用 Trainer 开始训练\n",
+    "\n",
+    "现在我们可以着手使用 ``fastNLP.Trainer`` 进行训练了。\n",
+    "\n",
+    "首先,为了高效地训练 ``ERNIE`` 模型,我们最好为学习率指定一定的策略。``paddlenlp`` 提供的 ``LinearDecayWithWarmup`` 可以令学习率在一段时间内从 0 开始线性地增长(预热),然后再线性地衰减至 0 。在本篇教程中,我们将学习率设置为 ``5e-5``,预热时间为 ``0.1``,然后将得到的的 ``lr_scheduler`` 赋值给 ``AdamW`` 优化器。\n",
+    "\n",
+    "其次,我们还可以为 ``Trainer`` 指定多个 ``Callback`` 来在基础的训练过程之外进行额外的定制操作。在本篇教程中,我们使用的 ``Callback`` 有以下三种:\n",
+    "\n",
+    "- ``LRSchedCallback`` - 由于我们使用了 ``Scheduler``,因此需要将 ``lr_scheduler`` 传给该 ``Callback`` 以在训练中进行更新。\n",
+    "- ``LoadBestModelCallback`` - 该 ``Callback`` 会评估结果中的 ``'acc#accuracy'`` 值,保存训练中出现的正确率最高的模型,并在训练结束时加载到模型上,方便对模型进行测试和评估。\n",
+    "\n",
+    "在 ``Trainer`` 中,我们还可以设置 ``metrics`` 来衡量模型的表现。``Accuracy`` 能够根据传入的预测值和真实值计算出模型预测的正确率。还记得模型中 ``evaluate_step`` 函数的返回值吗?键 ``pred`` 和 ``target`` 分别为 ``Accuracy.update`` 的参数名,在验证过程中 ``fastNLP`` 会自动将键和参数名匹配从而计算出正确率,这也是我们规定模型需要返回字典类型数据的原因。\n",
+    "\n",
+    "``Accuracy`` 的返回值包含三个部分:``acc``、``total`` 和 ``correct``,分别代表 ``正确率``、 ``数据总数`` 和 ``预测正确的数目``,这让您能够直观地知晓训练中模型的变化,``LoadBestModelCallback`` 的参数 ``'acc#accuracy'`` 也正是代表了 ``accuracy`` 指标的 ``acc`` 结果。\n",
+    "\n",
+    "在设定好参数之后,调用 ``run`` 函数便可以进行训练和验证了。"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "
[21:31:16] INFO     Running evaluator sanity check for 2 batches.              trainer.py:631\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[21:31:16]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=4641;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=822054;file://../fastNLP/core/controllers/trainer.py#631\u001b\\\u001b[2m631\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:60 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m60\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.895833,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1075.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.895833\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1075.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:120 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m120\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.8975,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1077.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.8975\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1077.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:180 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m180\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.911667,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1094.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.911667\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1094.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:240 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m240\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.9225,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1107.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.9225\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1107.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:300 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m300\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.9275,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1113.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.9275\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1113.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:1, Batch:60 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m60\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.930833,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1117.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.930833\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1117.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:1, Batch:120 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m120\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.935833,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1123.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.935833\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1123.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:1, Batch:180 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m180\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.935833,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1123.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.935833\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1123.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:1, Batch:240 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m240\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.9375,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1125.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.9375\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1125.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:1, Batch:300 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m300\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.941667,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1130.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.941667\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1130.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
[21:34:28] INFO     Loading best model from fnlp-ernie/2022-0 load_best_model_callback.py:111\n",
+       "                    6-22-21_29_12_898095/best_so_far with                                    \n",
+       "                    acc#accuracy: 0.941667...                                                \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[21:34:28]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Loading best model from fnlp-ernie/\u001b[1;36m2022\u001b[0m-\u001b[1;36m0\u001b[0m \u001b]8;id=340364;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=763898;file://../fastNLP/core/callbacks/load_best_model_callback.py#111\u001b\\\u001b[2m111\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m \u001b[1;36m6\u001b[0m-\u001b[1;36m22\u001b[0m-21_29_12_898095/best_so_far with \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m acc#accuracy: \u001b[1;36m0.941667\u001b[0m\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
[21:34:34] INFO     Deleting fnlp-ernie/2022-06-22-21_29_12_8 load_best_model_callback.py:131\n",
+       "                    98095/best_so_far...                                                     \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[21:34:34]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Deleting fnlp-ernie/\u001b[1;36m2022\u001b[0m-\u001b[1;36m06\u001b[0m-\u001b[1;36m22\u001b[0m-21_29_12_8 \u001b]8;id=430330;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=508566;file://../fastNLP/core/callbacks/load_best_model_callback.py#131\u001b\\\u001b[2m131\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m 98095/best_so_far\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from fastNLP import LRSchedCallback, LoadBestModelCallback\n", + "from fastNLP import Trainer, Accuracy\n", + "from paddlenlp.transformers import LinearDecayWithWarmup\n", + "\n", + "n_epochs = 2\n", + "num_training_steps = len(train_dataloader) * n_epochs\n", + "lr_scheduler = LinearDecayWithWarmup(5e-5, num_training_steps, 0.1)\n", + "optimizer = paddle.optimizer.AdamW(\n", + " learning_rate=lr_scheduler,\n", + " parameters=model.parameters(),\n", + ")\n", + "callbacks = [\n", + " LRSchedCallback(lr_scheduler, step_on=\"batch\"),\n", + " LoadBestModelCallback(\"acc#accuracy\", larger_better=True, save_folder=\"fnlp-ernie\"),\n", + "]\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver=\"paddle\",\n", + " optimizers=optimizer,\n", + " device=0,\n", + " n_epochs=n_epochs,\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=val_dataloader,\n", + " evaluate_every=60,\n", + " metrics={\"accuracy\": Accuracy()},\n", + " callbacks=callbacks,\n", + ")\n", + "trainer.run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.3 测试和评估\n", + "\n", + "现在我们已经得到了一个表现良好的 ``ERNIE`` 模型,接下来可以在测试集上测试模型的效果了。``fastNLP.Evaluator`` 提供了定制函数的功能。我们以 ``test_dataloader`` 初始化一个 ``Evaluator``,然后将写好的测试函数 ``test_batch_step_fn`` 传给参数 ``evaluate_batch_step_fn``,``Evaluate`` 在对每个 batch 进行评估时就会调用我们自定义的 ``test_batch_step_fn`` 函数而不是 ``evaluate_step`` 函数。在这里,我们仅测试 5 条数据并输出文本和对应的标签。" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
text: ['这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般']\n",
+       "
\n" + ], + "text/plain": [ + "text: ['这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
labels: 0\n",
+       "
\n" + ], + "text/plain": [ + "labels: 0\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
text: ['怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片!开始\n",
+       "还怀疑是不是赠送的个别现象,可是后来发现每张DVD后面都有!真不知道生产商怎么想的,我想看的是猫\n",
+       "和老鼠,不是米老鼠!如果厂家是想赠送的话,那就全套米老鼠和唐老鸭都赠送,只在每张DVD后面添加一\n",
+       "集算什么??简直是画蛇添足!!']\n",
+       "
\n" + ], + "text/plain": [ + "text: ['怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片!开始\n", + "还怀疑是不是赠送的个别现象,可是后来发现每张DVD后面都有!真不知道生产商怎么想的,我想看的是猫\n", + "和老鼠,不是米老鼠!如果厂家是想赠送的话,那就全套米老鼠和唐老鸭都赠送,只在每张DVD后面添加一\n", + "集算什么??简直是画蛇添足!!']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
labels: 0\n",
+       "
\n" + ], + "text/plain": [ + "labels: 0\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
text: ['还稍微重了点,可能是硬盘大的原故,还要再轻半斤就好了。其他要进一步验证。贴的几种膜气\n",
+       "泡较多,用不了多久就要更换了,屏幕膜稍好点,但比没有要强多了。建议配赠几张膜让用用户自己贴。'\n",
+       "]\n",
+       "
\n" + ], + "text/plain": [ + "text: ['还稍微重了点,可能是硬盘大的原故,还要再轻半斤就好了。其他要进一步验证。贴的几种膜气\n", + "泡较多,用不了多久就要更换了,屏幕膜稍好点,但比没有要强多了。建议配赠几张膜让用用户自己贴。'\n", + "]\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
labels: 0\n",
+       "
\n" + ], + "text/plain": [ + "labels: 0\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
text: ['交通方便;环境很好;服务态度很好 房间较小']\n",
+       "
\n" + ], + "text/plain": [ + "text: ['交通方便;环境很好;服务态度很好 房间较小']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
labels: 1\n",
+       "
\n" + ], + "text/plain": [ + "labels: 1\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
text: ['不错,作者的观点很颠覆目前中国父母的教育方式,其实古人们对于教育已经有了很系统的体系\n",
+       "了,可是现在的父母以及祖父母们更多的娇惯纵容孩子,放眼看去自私的孩子是大多数,父母觉得自己的\n",
+       "孩子在外面只要不吃亏就是好事,完全把古人几千年总结的教育古训抛在的九霄云外。所以推荐准妈妈们\n",
+       "可以在等待宝宝降临的时候,好好学习一下,怎么把孩子教育成一个有爱心、有责任心、宽容、大度的人\n",
+       "。']\n",
+       "
\n" + ], + "text/plain": [ + "text: ['不错,作者的观点很颠覆目前中国父母的教育方式,其实古人们对于教育已经有了很系统的体系\n", + "了,可是现在的父母以及祖父母们更多的娇惯纵容孩子,放眼看去自私的孩子是大多数,父母觉得自己的\n", + "孩子在外面只要不吃亏就是好事,完全把古人几千年总结的教育古训抛在的九霄云外。所以推荐准妈妈们\n", + "可以在等待宝宝降临的时候,好好学习一下,怎么把孩子教育成一个有爱心、有责任心、宽容、大度的人\n", + "。']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
labels: 1\n",
+       "
\n" + ], + "text/plain": [ + "labels: 1\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{}"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from fastNLP import Evaluator\n",
+    "def test_batch_step_fn(evaluator, batch):\n",
+    "    input_ids = batch[\"input_ids\"]\n",
+    "    attention_mask = batch[\"attention_mask\"]\n",
+    "    token_type_ids = batch[\"token_type_ids\"]\n",
+    "    logits = model(input_ids, attention_mask, token_type_ids)\n",
+    "    predict = logits.argmax().item()\n",
+    "    print(\"text:\", batch['text'])\n",
+    "    print(\"labels:\", predict)\n",
+    "\n",
+    "evaluator = Evaluator(\n",
+    "    model=model,\n",
+    "    dataloaders=test_dataloader,\n",
+    "    driver=\"paddle\",\n",
+    "    device=0,\n",
+    "    evaluate_batch_step_fn=test_batch_step_fn,\n",
+    ")\n",
+    "evaluator.run(5)    "
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3.7.13 ('fnlp-paddle')",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.13"
+  },
+  "orig_nbformat": 4,
+  "vscode": {
+   "interpreter": {
+    "hash": "31f2d9d3efc23c441973d7c4273acfea8b132b6a578f002629b6b44b8f65e720"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/source/tutorials/fastnlp_tutorial_paddle_e2.ipynb b/docs/source/tutorials/fastnlp_tutorial_paddle_e2.ipynb
new file mode 100644
index 00000000..439d7f9f
--- /dev/null
+++ b/docs/source/tutorials/fastnlp_tutorial_paddle_e2.ipynb
@@ -0,0 +1,1510 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# E4. 使用 paddlenlp 和 fastNLP 训练中文阅读理解任务\n",
+    "\n",
+    "本篇教程属于 **fastNLP v1.0 tutorial 的 paddle examples 系列**。在本篇教程中,我们将为您展示如何在 `fastNLP` 中通过自定义 `Metric` 和 损失函数来完成进阶的问答任务。\n",
+    "\n",
+    "1. 基础介绍:自然语言处理中的阅读理解任务\n",
+    "\n",
+    "2. 准备工作:加载 `DuReader-robust` 数据集,并使用 `tokenizer` 处理数据\n",
+    "\n",
+    "3. 模型训练:自己定义评测用的 `Metric` 实现更加自由的任务评测"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 1. 基础介绍:自然语言处理中的阅读理解任务\n",
+    "\n",
+    "阅读理解任务,顾名思义,就是给出一段文字,然后让模型理解这段文字所含的语义。大部分机器阅读理解任务都采用问答式测评,即设计与文章内容相关的自然语言式问题,让模型理解问题并根据文章作答。与文本分类任务不同的是,在阅读理解任务中我们有时需要需要输入“一对”句子,分别代表问题和上下文;答案的格式也分为多种:\n",
+    "\n",
+    "- 多项选择:让模型从多个答案选项中选出正确答案\n",
+    "- 区间答案:答案为上下文的一段子句,需要模型给出答案的起始位置\n",
+    "- 自由回答:不做限制,让模型自行生成答案\n",
+    "- 完形填空:在原文中挖空部分关键词,让模型补全;这类答案往往不需要问题\n",
+    "\n",
+    "如果您对 `transformers` 有所了解的话,其中的 `ModelForQuestionAnswering` 系列模型就可以用于这项任务。阅读理解模型的泛用性是衡量该技术能否在实际应用中大规模落地的重要指标之一,随着当前技术的进步,许多模型虽然能够在一些测试集上取得较好的性能,但在实际应用中,这些模型仍然难以让人满意。在本篇教程中,我们将会为您展示如何训练一个问答模型。\n",
+    "\n",
+    "在这一领域,`SQuAD` 数据集是一个影响深远的数据集。它的全称是斯坦福问答数据集(Stanford Question Answering Dataset),每条数据包含 `(问题,上下文,答案)` 三部分,规模大(约十万条,2.0又新增了五万条),在提出之后很快成为训练问答任务的经典数据集之一。`SQuAD` 数据集有两个指标来衡量模型的表现:`EM`(Exact Match,精确匹配)和 `F1`(模糊匹配)。前者反应了模型给出的答案中有多少和正确答案完全一致,后者则反应了模型给出的答案中与正确答案重叠的部分,均为越高越好。"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 2. 准备工作:加载 DuReader-robust 数据集,并使用 tokenizer 处理数据"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/remote-home/shxing/anaconda3/envs/fnlp-paddle/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "2.3.3\n"
+     ]
+    }
+   ],
+   "source": [
+    "import sys\n",
+    "sys.path.append(\"../\")\n",
+    "import paddle\n",
+    "import paddlenlp\n",
+    "\n",
+    "print(paddlenlp.__version__)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "在数据集方面,我们选用 `DuReader-robust` 中文数据集作为训练数据。它是一种抽取式问答数据集,采用 `SQuAD` 数据格式,能够评估真实应用场景下模型的泛用性。"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Reusing dataset dureader_robust (/remote-home/shxing/.cache/huggingface/datasets/dureader_robust/plain_text/1.0.0/d462ecadc8c010cee20f57632f1413f272867cd802a91a602df48c7d34eb0c27)\n",
+      "Reusing dataset dureader_robust (/remote-home/shxing/.cache/huggingface/datasets/dureader_robust/plain_text/1.0.0/d462ecadc8c010cee20f57632f1413f272867cd802a91a602df48c7d34eb0c27)\n",
+      "\u001b[32m[2022-06-27 19:22:46,998] [    INFO]\u001b[0m - Already cached /remote-home/shxing/.paddlenlp/models/ernie-1.0-base-zh/vocab.txt\u001b[0m\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'id': '0a25cb4bc1ab6f474c699884e04601e4', 'title': '', 'context': '第35集雪见缓缓张开眼睛,景天又惊又喜之际,长卿和紫萱的仙船驶至,见众人无恙,也十分高兴。众人登船,用尽合力把自身的真气和水分输给她。雪见终于醒过来了,但却一脸木然,全无反应。众人向常胤求助,却发现人世界竟没有雪见的身世纪录。长卿询问清微的身世,清微语带双关说一切上了天界便有答案。长卿驾驶仙船,众人决定立马动身,往天界而去。众人来到一荒山,长卿指出,魔界和天界相连。由魔界进入通过神魔之井,便可登天。众人至魔界入口,仿若一黑色的蝙蝠洞,但始终无法进入。后来花楹发现只要有翅膀便能飞入。于是景天等人打下许多乌鸦,模仿重楼的翅膀,制作数对翅膀状巨物。刚佩戴在身,便被吸入洞口。众人摔落在地,抬头发现魔界守卫。景天和众魔套交情,自称和魔尊重楼相熟,众魔不理,打了起来。', 'question': '仙剑奇侠传3第几集上天界', 'answers': {'text': ['第35集'], 'answer_start': [0]}}\n",
+      "{'id': '7de192d6adf7d60ba73ba25cf590cc1e', 'title': '', 'context': '选择燃气热水器时,一定要关注这几个问题:1、出水稳定性要好,不能出现忽热忽冷的现象2、快速到达设定的需求水温3、操作要智能、方便4、安全性要好,要装有安全报警装置 市场上燃气热水器品牌众多,购买时还需多加对比和仔细鉴别。方太今年主打的磁化恒温热水器在使用体验方面做了全面升级:9秒速热,可快速进入洗浴模式;水温持久稳定,不会出现忽热忽冷的现象,并通过水量伺服技术将出水温度精确控制在±0.5℃,可满足家里宝贝敏感肌肤洗护需求;配备CO和CH4双气体报警装置更安全(市场上一般多为CO单气体报警)。另外,这款热水器还有智能WIFI互联功能,只需下载个手机APP即可用手机远程操作热水器,实现精准调节水温,满足家人多样化的洗浴需求。当然方太的磁化恒温系列主要的是增加磁化功能,可以有效吸附水中的铁锈、铁屑等微小杂质,防止细菌滋生,使沐浴水质更洁净,长期使用磁化水沐浴更利于身体健康。', 'question': '燃气热水器哪个牌子好', 'answers': {'text': ['方太'], 'answer_start': [110]}}\n",
+      "{'id': 'b9e74d4b9228399b03701d1fe6d52940', 'title': '', 'context': '迈克尔.乔丹在NBA打了15个赛季。他在84年进入nba,期间在1993年10月6日第一次退役改打棒球,95年3月18日重新回归,在99年1月13日第二次退役,后于2001年10月31日复出,在03年最终退役。迈克尔·乔丹(Michael Jordan),1963年2月17日生于纽约布鲁克林,美国著名篮球运动员,司职得分后卫,历史上最伟大的篮球运动员。1984年的NBA选秀大会,乔丹在首轮第3顺位被芝加哥公牛队选中。 1986-87赛季,乔丹场均得到37.1分,首次获得分王称号。1990-91赛季,乔丹连夺常规赛MVP和总决赛MVP称号,率领芝加哥公牛首次夺得NBA总冠军。 1997-98赛季,乔丹获得个人职业生涯第10个得分王,并率领公牛队第六次夺得总冠军。2009年9月11日,乔丹正式入选NBA名人堂。', 'question': '乔丹打了多少个赛季', 'answers': {'text': ['15个'], 'answer_start': [12]}}\n",
+      "训练集大小: 14520\n",
+      "验证集大小: 1417\n"
+     ]
+    }
+   ],
+   "source": [
+    "from paddlenlp.datasets import load_dataset\n",
+    "train_dataset = load_dataset(\"PaddlePaddle/dureader_robust\", splits=\"train\")\n",
+    "val_dataset = load_dataset(\"PaddlePaddle/dureader_robust\", splits=\"validation\")\n",
+    "for i in range(3):\n",
+    "    print(train_dataset[i])\n",
+    "print(\"训练集大小:\", len(train_dataset))\n",
+    "print(\"验证集大小:\", len(val_dataset))\n",
+    "\n",
+    "MODEL_NAME = \"ernie-1.0-base-zh\"\n",
+    "from paddlenlp.transformers import ErnieTokenizer\n",
+    "tokenizer =ErnieTokenizer.from_pretrained(MODEL_NAME)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### 2.1 处理训练集\n",
+    "\n",
+    "对于阅读理解任务,数据处理的方式较为麻烦。接下来我们会为您详细讲解处理函数 `_process_train` 的功能,同时也将通过实践展示关于 `tokenizer` 的更多功能,让您更加深入地了解自然语言处理任务。首先让我们向 `tokenizer` 输入一条数据(以列表的形式):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "2\n",
+      "dict_keys(['offset_mapping', 'input_ids', 'token_type_ids', 'overflow_to_sample'])\n"
+     ]
+    }
+   ],
+   "source": [
+    "result = tokenizer(\n",
+    "    [train_dataset[0][\"question\"]],\n",
+    "    [train_dataset[0][\"context\"]],\n",
+    "    stride=128,\n",
+    "    max_length=256,\n",
+    "    padding=\"max_length\",\n",
+    "    return_dict=False\n",
+    ")\n",
+    "\n",
+    "print(len(result))\n",
+    "print(result[0].keys())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "首先不难理解的是,模型必须要同时接受问题(`question`)和上下文(`context`)才能够进行阅读理解,因此我们需要将二者同时进行分词(`tokenize`)。所幸,`Tokenizer` 提供了这一功能,当我们调用 `tokenizer` 的时候,其第一个参数名为 `text`,第二个参数名为 `text_pair`,这使得我们可以同时对一对文本进行分词。同时,`tokenizer` 还需要标记出一条数据中哪些属于问题,哪些属于上下文,这一功能则由 `token_type_ids` 完成。`token_type_ids` 会将输入的第一个文本(问题)标记为 `0`,第二个文本(上下文)标记为 `1`,这样模型在训练时便可以将问题和上下文区分开来:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[1, 1034, 1189, 734, 2003, 241, 284, 131, 553, 271, 28, 125, 280, 2, 131, 1773, 271, 1097, 373, 1427, 1427, 501, 88, 662, 1906, 4, 561, 125, 311, 1168, 311, 692, 46, 430, 4, 84, 2073, 14, 1264, 3967, 5, 1034, 1020, 1829, 268, 4, 373, 539, 8, 154, 5210, 4, 105, 167, 59, 69, 685, 12043, 539, 8, 883, 1020, 4, 29, 720, 95, 90, 427, 67, 262, 5, 384, 266, 14, 101, 59, 789, 416, 237, 12043, 1097, 373, 616, 37, 1519, 93, 61, 15, 4, 255, 535, 7, 1529, 619, 187, 4, 62, 154, 451, 149, 12043, 539, 8, 253, 223, 3679, 323, 523, 4, 535, 34, 87, 8, 203, 280, 1186, 340, 9, 1097, 373, 5, 262, 203, 623, 704, 12043, 84, 2073, 1137, 358, 334, 702, 5, 262, 203, 4, 334, 702, 405, 360, 653, 129, 178, 7, 568, 28, 15, 125, 280, 518, 9, 1179, 487, 12043, 84, 2073, 1621, 1829, 1034, 1020, 4, 539, 8, 448, 91, 202, 466, 70, 262, 4, 638, 125, 280, 83, 299, 12043, 539, 8, 61, 45, 7, 1537, 176, 4, 84, 2073, 288, 39, 4, 889, 280, 14, 125, 280, 156, 538, 12043, 190, 889, 280, 71, 109, 124, 93, 292, 889, 46, 1248, 4, 518, 48, 883, 125, 12043, 539, 8, 268, 889, 280, 109, 270, 4, 1586, 845, 7, 669, 199, 5, 3964, 3740, 1084, 4, 255, 440, 616, 154, 72, 71, 109, 12043, 49, 61, 283, 3591, 34, 87, 297, 41, 9, 1993, 2602, 518, 52, 706, 109, 2]\n",
+      "['[CLS]', '仙', '剑', '奇', '侠', '传', '3', '第', '几', '集', '上', '天', '界', '[SEP]', '第', '35', '集', '雪', '见', '缓', '缓', '张', '开', '眼', '睛', ',', '景', '天', '又', '惊', '又', '喜', '之', '际', ',', '长', '卿', '和', '紫', '萱', '的', '仙', '船', '驶', '至', ',', '见', '众', '人', '无', '恙', ',', '也', '十', '分', '高', '兴', '。', '众', '人', '登', '船', ',', '用', '尽', '合', '力', '把', '自', '身', '的', '真', '气', '和', '水', '分', '输', '给', '她', '。', '雪', '见', '终', '于', '醒', '过', '来', '了', ',', '但', '却', '一', '脸', '木', '然', ',', '全', '无', '反', '应', '。', '众', '人', '向', '常', '胤', '求', '助', ',', '却', '发', '现', '人', '世', '界', '竟', '没', '有', '雪', '见', '的', '身', '世', '纪', '录', '。', '长', '卿', '询', '问', '清', '微', '的', '身', '世', ',', '清', '微', '语', '带', '双', '关', '说', '一', '切', '上', '了', '天', '界', '便', '有', '答', '案', '。', '长', '卿', '驾', '驶', '仙', '船', ',', '众', '人', '决', '定', '立', '马', '动', '身', ',', '往', '天', '界', '而', '去', '。', '众', '人', '来', '到', '一', '荒', '山', ',', '长', '卿', '指', '出', ',', '魔', '界', '和', '天', '界', '相', '连', '。', '由', '魔', '界', '进', '入', '通', '过', '神', '魔', '之', '井', ',', '便', '可', '登', '天', '。', '众', '人', '至', '魔', '界', '入', '口', ',', '仿', '若', '一', '黑', '色', '的', '蝙', '蝠', '洞', ',', '但', '始', '终', '无', '法', '进', '入', '。', '后', '来', '花', '楹', '发', '现', '只', '要', '有', '翅', '膀', '便', '能', '飞', '入', '[SEP]']\n",
+      "[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(result[0][\"input_ids\"])\n",
+    "print(tokenizer.convert_ids_to_tokens(result[0][\"input_ids\"]))\n",
+    "print(result[0][\"token_type_ids\"])"
+   ]
+  },
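+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To make the role of `token_type_ids` concrete, here is a short sketch (reusing the `result` from above) that splits the token sequence back into its question part and its context part by filtering on the type id:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch: separate question tokens (type id 0) from context tokens (type id 1)\n",
+    "tokens = tokenizer.convert_ids_to_tokens(result[0][\"input_ids\"])\n",
+    "type_ids = result[0][\"token_type_ids\"]\n",
+    "question_part = \"\".join(t for t, k in zip(tokens, type_ids) if k == 0)\n",
+    "context_part = \"\".join(t for t, k in zip(tokens, type_ids) if k == 1)\n",
+    "print(question_part)      # [CLS] + question + the first [SEP]\n",
+    "print(context_part[:30])  # the first 30 characters of the context part"
+   ]
+  },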
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "根据上面的输出我们可以看出,`tokenizer` 会将数据开头用 `[CLS]` 标记,用 `[SEP]` 来分割句子。同时,根据 `token_type_ids` 得到的 0、1 串,我们也很容易将问题和上下文区分开。顺带一提,如果一条数据进行了 `padding`,那么这部分会被标记为 `0` 。\n",
+    "\n",
+    "在输出的 `keys` 中还有一项名为 `offset_mapping` 的键。该项数据能够表示分词后的每个 `token` 在原文中对应文字或词语的位置。比如我们可以像下面这样将数据打印出来:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[(0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (0, 0), (0, 1), (1, 3), (3, 4), (4, 5), (5, 6), (6, 7)]\n",
+      "[1, 1034, 1189, 734, 2003, 241, 284, 131, 553, 271, 28, 125, 280, 2, 131, 1773, 271, 1097, 373, 1427]\n",
+      "['[CLS]', '仙', '剑', '奇', '侠', '传', '3', '第', '几', '集', '上', '天', '界', '[SEP]', '第', '35', '集', '雪', '见', '缓']\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(result[0][\"offset_mapping\"][:20])\n",
+    "print(result[0][\"input_ids\"][:20])\n",
+    "print(tokenizer.convert_ids_to_tokens(result[0][\"input_ids\"])[:20])"
+   ]
+  },
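+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a quick sanity check (a sketch reusing `result` and the still-unprocessed `train_dataset[0]`), an offset pair can be used to slice the original string directly:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch: slice the raw context with the offsets of the first context token\n",
+    "context = train_dataset[0][\"context\"]\n",
+    "first_context_token = result[0][\"token_type_ids\"].index(1)\n",
+    "start, end = result[0][\"offset_mapping\"][first_context_token]\n",
+    "print(context[start:end])  # the first character of the context: 第"
+   ]
+  },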
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "`[CLS]` 由于是 `tokenizer` 自己添加进去用于标记数据的 `token`,因此它在原文中找不到任何对应的词语,所以给出的位置范围就是 `(0, 0)`;第二个 `token` 对应第一个 `“仙”` 字,因此映射的位置就是 `(0, 1)`;同理,后面的 `[SEP]` 也不对应任何文字,映射的位置为 `(0, 0)`;而接下来的 `token` 对应 **上下文** 中的第一个字 `“第”`,映射出的位置为 `(0, 1)`;再后面的 `token` 对应原文中的两个字符 `35`,因此其位置映射为 `(1, 3)` 。通过这种手段,我们可以更方便地获取 `token` 与原文的对应关系。\n",
+    "\n",
+    "最后,您也许会注意到我们获取的 `result` 长度为 2 。这是文本在分词后长度超过了 `max_length` 256 ,`tokenizer` 将数据分成了两部分所致。在阅读理解任务中,我们不可能像文本分类那样轻易地将一条数据截断,因为答案很可能就出现在后面被丢弃的那部分数据中,因此,我们需要保留所有的数据(当然,您也可以直接丢弃这些超长的数据)。`overflow_to_sample` 则可以标识当前数据在原数据的索引:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[CLS]仙剑奇侠传3第几集上天界[SEP]第35集雪见缓缓张开眼睛,景天又惊又喜之际,长卿和紫萱的仙船驶至,见众人无恙,也十分高兴。众人登船,用尽合力把自身的真气和水分输给她。雪见终于醒过来了,但却一脸木然,全无反应。众人向常胤求助,却发现人世界竟没有雪见的身世纪录。长卿询问清微的身世,清微语带双关说一切上了天界便有答案。长卿驾驶仙船,众人决定立马动身,往天界而去。众人来到一荒山,长卿指出,魔界和天界相连。由魔界进入通过神魔之井,便可登天。众人至魔界入口,仿若一黑色的蝙蝠洞,但始终无法进入。后来花楹发现只要有翅膀便能飞入[SEP]\n",
+      "overflow_to_sample:  0\n",
+      "[CLS]仙剑奇侠传3第几集上天界[SEP]说一切上了天界便有答案。长卿驾驶仙船,众人决定立马动身,往天界而去。众人来到一荒山,长卿指出,魔界和天界相连。由魔界进入通过神魔之井,便可登天。众人至魔界入口,仿若一黑色的蝙蝠洞,但始终无法进入。后来花楹发现只要有翅膀便能飞入。于是景天等人打下许多乌鸦,模仿重楼的翅膀,制作数对翅膀状巨物。刚佩戴在身,便被吸入洞口。众人摔落在地,抬头发现魔界守卫。景天和众魔套交情,自称和魔尊重楼相熟,众魔不理,打了起来。[SEP][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]\n",
+      "overflow_to_sample:  0\n"
+     ]
+    }
+   ],
+   "source": [
+    "for res in result:\n",
+    "    tokens = tokenizer.convert_ids_to_tokens(res[\"input_ids\"])\n",
+    "    print(\"\".join(tokens))\n",
+    "    print(\"overflow_to_sample: \", res[\"overflow_to_sample\"])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "将两条数据均输出之后可以看到,它们都出自我们传入的数据,并且存在一部分重合。`tokenizer` 的 `stride` 参数可以设置重合部分的长度,这也可以帮助模型识别被分割开的两条数据;`overflow_to_sample` 的 `0` 则代表它们来自于第 `0` 条数据。\n",
+    "\n",
+    "基于以上信息,我们处理训练集的思路如下:\n",
+    "\n",
+    "1. 通过 `overflow_to_sample` 来获取原来的数据\n",
+    "2. 通过原数据的 `answers` 找到答案的起始位置\n",
+    "3. 通过 `offset_mapping` 给出的映射关系在分词处理后的数据中找到答案的起始位置,分别记录在 `start_pos` 和 `end_pos` 中;如果没有找到答案(比如答案被截断了),那么答案的起始位置就被标记为 `[CLS]` 的位置。\n",
+    "\n",
+    "这样 `_process_train` 函数就呼之欲出了,我们调用 `train_dataset.map` 函数,并将 `batched` 参数设置为 `True` ,将所有数据批量地进行更新。有一点需要注意的是,**在处理过后数据量会增加**。"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'offset_mapping': [(0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (0, 0), (0, 1), (1, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 14), (14, 15), (15, 16), (16, 17), (17, 18), (18, 19), (19, 20), (20, 21), (21, 22), (22, 23), (23, 24), (24, 25), (25, 26), (26, 27), (27, 28), (28, 29), (29, 30), (30, 31), (31, 32), (32, 33), (33, 34), (34, 35), (35, 36), (36, 37), (37, 38), (38, 39), (39, 40), (40, 41), (41, 42), (42, 43), (43, 44), (44, 45), (45, 46), (46, 47), (47, 48), (48, 49), (49, 50), (50, 51), (51, 52), (52, 53), (53, 54), (54, 55), (55, 56), (56, 57), (57, 58), (58, 59), (59, 60), (60, 61), (61, 62), (62, 63), (63, 64), (64, 65), (65, 66), (66, 67), (67, 68), (68, 69), (69, 70), (70, 71), (71, 72), (72, 73), (73, 74), (74, 75), (75, 76), (76, 77), (77, 78), (78, 79), (79, 80), (80, 81), (81, 82), (82, 83), (83, 84), (84, 85), (85, 86), (86, 87), (87, 88), (88, 89), (89, 90), (90, 91), (91, 92), (92, 93), (93, 94), (94, 95), (95, 96), (96, 97), (97, 98), (98, 99), (99, 100), (100, 101), (101, 102), (102, 103), (103, 104), (104, 105), (105, 106), (106, 107), (107, 108), (108, 109), (109, 110), (110, 111), (111, 112), (112, 113), (113, 114), (114, 115), (115, 116), (116, 117), (117, 118), (118, 119), (119, 120), (120, 121), (121, 122), (122, 123), (123, 124), (124, 125), (125, 126), (126, 127), (127, 128), (128, 129), (129, 130), (130, 131), (131, 132), (132, 133), (133, 134), (134, 135), (135, 136), (136, 137), (137, 138), (138, 139), (139, 140), (140, 141), (141, 142), (142, 143), (143, 144), (144, 145), (145, 146), (146, 147), (147, 148), (148, 149), (149, 150), (150, 151), (151, 152), (152, 153), (153, 154), (154, 155), (155, 156), (156, 157), (157, 158), (158, 159), (159, 160), (160, 161), (161, 162), (162, 163), (163, 164), (164, 165), (165, 166), (166, 167), (167, 168), (168, 169), (169, 170), (170, 171), (171, 172), (172, 173), (173, 174), (174, 175), (175, 176), (176, 177), (177, 178), (178, 179), (179, 180), (180, 181), (181, 182), (182, 183), (183, 184), (184, 185), (185, 186), (186, 187), (187, 188), (188, 189), (189, 190), (190, 191), (191, 192), (192, 193), (193, 194), (194, 195), (195, 196), (196, 197), (197, 198), (198, 199), (199, 200), (200, 201), (201, 202), (202, 203), (203, 204), (204, 205), (205, 206), (206, 207), (207, 208), (208, 209), (209, 210), (210, 211), (211, 212), (212, 213), (213, 214), (214, 215), (215, 216), (216, 217), (217, 218), (218, 219), (219, 220), (220, 221), (221, 222), (222, 223), (223, 224), (224, 225), (225, 226), (226, 227), (227, 228), (228, 229), (229, 230), (230, 231), (231, 232), (232, 233), (233, 234), (234, 235), (235, 236), (236, 237), (237, 238), (238, 239), (239, 240), (240, 241), (241, 242), (0, 0)], 'input_ids': [1, 1034, 1189, 734, 2003, 241, 284, 131, 553, 271, 28, 125, 280, 2, 131, 1773, 271, 1097, 373, 1427, 1427, 501, 88, 662, 1906, 4, 561, 125, 311, 1168, 311, 692, 46, 430, 4, 84, 2073, 14, 1264, 3967, 5, 1034, 1020, 1829, 268, 4, 373, 539, 8, 154, 5210, 4, 105, 167, 59, 69, 685, 12043, 539, 8, 883, 1020, 4, 29, 720, 95, 90, 427, 67, 262, 5, 384, 266, 14, 101, 59, 789, 416, 237, 12043, 1097, 373, 616, 37, 1519, 93, 61, 15, 4, 255, 535, 7, 1529, 619, 187, 4, 62, 154, 451, 149, 12043, 539, 8, 253, 223, 3679, 323, 523, 4, 535, 34, 87, 8, 203, 280, 1186, 340, 9, 1097, 373, 5, 262, 203, 623, 704, 12043, 84, 2073, 1137, 358, 334, 702, 5, 262, 203, 4, 334, 702, 405, 360, 653, 129, 178, 7, 568, 28, 15, 125, 
280, 518, 9, 1179, 487, 12043, 84, 2073, 1621, 1829, 1034, 1020, 4, 539, 8, 448, 91, 202, 466, 70, 262, 4, 638, 125, 280, 83, 299, 12043, 539, 8, 61, 45, 7, 1537, 176, 4, 84, 2073, 288, 39, 4, 889, 280, 14, 125, 280, 156, 538, 12043, 190, 889, 280, 71, 109, 124, 93, 292, 889, 46, 1248, 4, 518, 48, 883, 125, 12043, 539, 8, 268, 889, 280, 109, 270, 4, 1586, 845, 7, 669, 199, 5, 3964, 3740, 1084, 4, 255, 440, 616, 154, 72, 71, 109, 12043, 49, 61, 283, 3591, 34, 87, 297, 41, 9, 1993, 2602, 518, 52, 706, 109, 2], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'overflow_to_sample': 0, 'start_pos': 14, 'end_pos': 16}\n",
+      "处理后的训练集大小: 26198\n"
+     ]
+    }
+   ],
+   "source": [
+    "max_length = 256\n",
+    "doc_stride = 128\n",
+    "def _process_train(data):\n",
+    "\n",
+    "    contexts = [data[i][\"context\"] for i in range(len(data))]\n",
+    "    questions = [data[i][\"question\"] for i in range(len(data))]\n",
+    "\n",
+    "    tokenized_data_list = tokenizer(\n",
+    "        questions,\n",
+    "        contexts,\n",
+    "        stride=doc_stride,\n",
+    "        max_length=max_length,\n",
+    "        padding=\"max_length\",\n",
+    "        return_dict=False\n",
+    "    )\n",
+    "\n",
+    "    for i, tokenized_data in enumerate(tokenized_data_list):\n",
+    "        # 获取 [CLS] 对应的位置\n",
+    "        input_ids = tokenized_data[\"input_ids\"]\n",
+    "        cls_index = input_ids.index(tokenizer.cls_token_id)\n",
+    "\n",
+    "        # 在 tokenize 的过程中,汉字和 token 在位置上并非一一对应的\n",
+    "        # 而 offset mapping 记录了每个 token 在原文中对应的起始位置\n",
+    "        offsets = tokenized_data[\"offset_mapping\"]\n",
+    "        # token_type_ids 记录了一条数据中哪些是问题,哪些是上下文\n",
+    "        token_type_ids = tokenized_data[\"token_type_ids\"]\n",
+    "\n",
+    "        # 一条数据可能因为长度过长而在 tokenized_data 中存在多个结果\n",
+    "        # overflow_to_sample 表示了当前 tokenize_example 属于 data 中的哪一条数据\n",
+    "        sample_index = tokenized_data[\"overflow_to_sample\"]\n",
+    "        answers = data[sample_index][\"answers\"]\n",
+    "\n",
+    "        # answers 和 answer_starts 均为长度为 1 的 list\n",
+    "        # 我们可以计算出答案的结束位置\n",
+    "        start_char = answers[\"answer_start\"][0]\n",
+    "        end_char = start_char + len(answers[\"text\"][0])\n",
+    "\n",
+    "        token_start_index = 0\n",
+    "        while token_type_ids[token_start_index] != 1:\n",
+    "            token_start_index += 1\n",
+    "\n",
+    "        token_end_index = len(input_ids) - 1\n",
+    "        while token_type_ids[token_end_index] != 1:\n",
+    "            token_end_index -= 1\n",
+    "        # 分词后一条数据的结尾一定是 [SEP],因此还需要减一\n",
+    "        token_end_index -= 1\n",
+    "\n",
+    "        if not (offsets[token_start_index][0] <= start_char and\n",
+    "                offsets[token_end_index][1] >= end_char):\n",
+    "            # 如果答案不在这条数据中,则将答案位置标记为 [CLS] 的位置\n",
+    "            tokenized_data_list[i][\"start_pos\"] = cls_index\n",
+    "            tokenized_data_list[i][\"end_pos\"] = cls_index\n",
+    "        else:\n",
+    "            # 否则,我们可以找到答案对应的 token 的起始位置,记录在 start_pos 和 end_pos 中\n",
+    "            while token_start_index < len(offsets) and offsets[\n",
+    "                    token_start_index][0] <= start_char:\n",
+    "                token_start_index += 1\n",
+    "            tokenized_data_list[i][\"start_pos\"] = token_start_index - 1\n",
+    "            while offsets[token_end_index][1] >= end_char:\n",
+    "                token_end_index -= 1\n",
+    "            tokenized_data_list[i][\"end_pos\"] = token_end_index + 1\n",
+    "\n",
+    "    return tokenized_data_list\n",
+    "\n",
+    "train_dataset.map(_process_train, batched=True, num_workers=5)\n",
+    "print(train_dataset[0])\n",
+    "print(\"处理后的训练集大小:\", len(train_dataset))"
+   ]
+  },
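+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The new labels are easy to sanity-check. The sketch below (assuming the `train_dataset` mapped in the cell above) decodes the tokens between `start_pos` and `end_pos` back into text; for the first example it should reproduce the answer `第35集`:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch: decode the labeled answer span of the first processed example\n",
+    "sample = train_dataset[0]\n",
+    "tokens = tokenizer.convert_ids_to_tokens(sample[\"input_ids\"])\n",
+    "print(\"decoded span:\", \"\".join(tokens[sample[\"start_pos\"]:sample[\"end_pos\"] + 1]))"
+   ]
+  },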
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### 2.2 处理验证集\n",
+    "\n",
+    "对于验证集的处理则简单得多,我们只需要保存原数据的 `id` 并将 `offset_mapping` 中不属于上下文的部分设置为 `None` 即可。"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       ""
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "def _process_val(data):\n",
+    "\n",
+    "    contexts = [data[i][\"context\"] for i in range(len(data))]\n",
+    "    questions = [data[i][\"question\"] for i in range(len(data))]\n",
+    "\n",
+    "    tokenized_data_list = tokenizer(\n",
+    "        questions,\n",
+    "        contexts,\n",
+    "        stride=doc_stride,\n",
+    "        max_length=max_length,\n",
+    "        return_dict=False\n",
+    "    )\n",
+    "\n",
+    "    for i, tokenized_data in enumerate(tokenized_data_list):\n",
+    "        token_type_ids = tokenized_data[\"token_type_ids\"]\n",
+    "        # 保存数据对应的 id\n",
+    "        sample_index = tokenized_data[\"overflow_to_sample\"]\n",
+    "        tokenized_data_list[i][\"example_id\"] = data[sample_index][\"id\"]\n",
+    "\n",
+    "        # 将不属于 context 的 offset 设置为 None\n",
+    "        tokenized_data_list[i][\"offset_mapping\"] = [\n",
+    "            (o if token_type_ids[k] == 1 else None)\n",
+    "            for k, o in enumerate(tokenized_data[\"offset_mapping\"])\n",
+    "        ]\n",
+    "\n",
+    "    return tokenized_data_list\n",
+    "\n",
+    "val_dataset.map(_process_val, batched=True, num_workers=5)"
+   ]
+  },
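+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A quick look at one processed example (a sketch assuming the `val_dataset` mapped above) confirms that the `example_id` is kept and that the offsets of the question tokens have been masked out:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch: inspect one processed validation example\n",
+    "sample = val_dataset[0]\n",
+    "print(sample[\"example_id\"])           # id of the original example\n",
+    "print(sample[\"offset_mapping\"][:16])  # question positions should be None"
+   ]
+  },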
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### 2.3 DataLoader\n",
+    "\n",
+    "最后使用 `PaddleDataLoader` 将数据集包裹起来即可。"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from fastNLP.core import PaddleDataLoader\n", + "\n", + "train_dataloader = PaddleDataLoader(train_dataset, batch_size=32, shuffle=True)\n", + "val_dataloader = PaddleDataLoader(val_dataset, batch_size=16)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. 模型训练:自己定义评测用的 Metric 实现更加自由的任务评测\n", + "\n", + "#### 3.1 损失函数\n", + "\n", + "对于阅读理解任务,我们使用的是 `ErnieForQuestionAnswering` 模型。该模型在接受输入后会返回两个值:`start_logits` 和 `end_logits` ,大小均为 `(batch_size, sequence_length)`,反映了每条数据每个词语为答案起始位置的可能性,因此我们需要自定义一个损失函数来计算 `loss`。 `CrossEntropyLossForSquad` 会分别对答案起始位置的预测值和真实值计算交叉熵,最后返回其平均值作为最终的损失。" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "class CrossEntropyLossForSquad(paddle.nn.Layer):\n", + " def __init__(self):\n", + " super(CrossEntropyLossForSquad, self).__init__()\n", + "\n", + " def forward(self, start_logits, end_logits, start_pos, end_pos):\n", + " start_pos = paddle.unsqueeze(start_pos, axis=-1)\n", + " end_pos = paddle.unsqueeze(end_pos, axis=-1)\n", + " start_loss = paddle.nn.functional.softmax_with_cross_entropy(\n", + " logits=start_logits, label=start_pos)\n", + " start_loss = paddle.mean(start_loss)\n", + " end_loss = paddle.nn.functional.softmax_with_cross_entropy(\n", + " logits=end_logits, label=end_pos)\n", + " end_loss = paddle.mean(end_loss)\n", + "\n", + " loss = (start_loss + end_loss) / 2\n", + " return loss" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.2 定义模型\n", + "\n", + "模型的核心则是 `ErnieForQuestionAnswering` 的 `ernie-1.0-base-zh` 预训练模型,同时按照 `fastNLP` 的规定定义 `train_step` 和 `evaluate_step` 函数。这里 `evaluate_step` 函数并没有像文本分类那样直接返回该批次数据的评测结果,这一点我们将在下面为您讲解。" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m[2022-06-27 19:00:15,825] [ INFO]\u001b[0m - Already cached /remote-home/shxing/.paddlenlp/models/ernie-1.0-base-zh/ernie_v1_chn_base.pdparams\u001b[0m\n", + "W0627 19:00:15.831080 21543 gpu_context.cc:278] Please NOTE: device: 0, GPU Compute Capability: 7.5, Driver API Version: 11.2, Runtime API Version: 11.2\n", + "W0627 19:00:15.843276 21543 gpu_context.cc:306] device: 0, cuDNN Version: 8.1.\n" + ] + } + ], + "source": [ + "from paddlenlp.transformers import ErnieForQuestionAnswering\n", + "\n", + "class QAModel(paddle.nn.Layer):\n", + " def __init__(self, model_checkpoint):\n", + " super(QAModel, self).__init__()\n", + " self.model = ErnieForQuestionAnswering.from_pretrained(model_checkpoint)\n", + " self.loss_func = CrossEntropyLossForSquad()\n", + "\n", + " def forward(self, input_ids, token_type_ids):\n", + " start_logits, end_logits = self.model(input_ids, token_type_ids)\n", + " return start_logits, end_logits\n", + "\n", + " def train_step(self, input_ids, token_type_ids, start_pos, end_pos):\n", + " start_logits, end_logits = self(input_ids, token_type_ids)\n", + " loss = self.loss_func(start_logits, end_logits, start_pos, end_pos)\n", + " return {\"loss\": loss}\n", + "\n", + " def evaluate_step(self, input_ids, token_type_ids):\n", + " start_logits, end_logits = self(input_ids, token_type_ids)\n", + " return {\"start_logits\": start_logits, \"end_logits\": end_logits}\n", + "\n", + "model = QAModel(MODEL_NAME)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.3 自定义 Metric 
进行数据的评估\n", + "\n", + "`paddlenlp` 为我们提供了评测 `SQuAD` 格式数据集的函数 `compute_prediction` 和 `squad_evaluate`:\n", + "- `compute_prediction` 函数要求传入原数据 `examples` 、处理后的数据 `features` 和 `features` 对应的结果 `predictions`(一个包含所有数据 `start_logits` 和 `end_logits` 的元组)\n", + "- `squad_evaluate` 要求传入原数据 `examples` 和预测结果 `all_predictions`(通常来自于 `compute_prediction`)\n", + "\n", + "在使用这两个函数的时候,我们需要向其中传入数据集,但显然根据 `fastNLP` 的设计,我们无法在 `evaluate_step` 里实现这一过程,并且 `fastNLP` 也并没有提供计算 `F1` 和 `EM` 的 `Metric`,故我们需要自己定义用于评测的 `Metric`。\n", + "\n", + "在初始化之外,一个 `Metric` 还需要实现三个函数:\n", + "\n", + "1. `reset` - 该函数会在验证数据集的迭代之前被调用,用于清空数据;在我们自定义的 `Metric` 中,我们需要将 `all_start_logits` 和 `all_end_logits` 清空,重新收集每个 `batch` 的结果。\n", + "2. `update` - 该函数会在在每个 `batch` 得到结果后被调用,用于更新 `Metric` 的状态;它的参数即为 `evaluate_step` 返回的内容。我们在这里将得到的 `start_logits` 和 `end_logits` 收集起来。\n", + "3. `get_metric` - 该函数会在数据集被迭代完毕后调用,用于计算评测的结果。现在我们有了整个验证集的 `all_start_logits` 和 `all_end_logits` ,将他们传入 `compute_predictions` 函数得到预测的结果,并继续使用 `squad_evaluate` 函数得到评测的结果。\n", + " - 注:`suqad_evaluate` 函数会自己输出评测结果,为了不让其干扰 `fastNLP` 输出,这里我们使用 `contextlib.redirect_stdout(None)` 将函数的标准输出屏蔽掉。\n", + "\n", + "综上,`SquadEvaluateMetric` 实现的评估过程是:将验证集中所有数据的 `logits` 收集起来,然后统一传入 `compute_prediction` 和 `squad_evaluate` 中进行评估。值得一提的是,`paddlenlp.datasets.load_dataset` 返回的结果是一个 `MapDataset` 类型,其 `data` 成员为加载时的数据,`new_data` 为经过 `map` 函数处理后更新的数据,因此可以分别作为 `examples` 和 `features` 传入。" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP.core import Metric\n", + "from paddlenlp.metrics.squad import squad_evaluate, compute_prediction\n", + "import contextlib\n", + "\n", + "class SquadEvaluateMetric(Metric):\n", + " def __init__(self, examples, features, testing=False):\n", + " super(SquadEvaluateMetric, self).__init__(\"paddle\", False)\n", + " self.examples = examples\n", + " self.features = features\n", + " self.all_start_logits = []\n", + " self.all_end_logits = []\n", + " self.testing = testing\n", + "\n", + " def reset(self):\n", + " self.all_start_logits = []\n", + " self.all_end_logits = []\n", + "\n", + " def update(self, start_logits, end_logits):\n", + " for start, end in zip(start_logits, end_logits):\n", + " self.all_start_logits.append(start.numpy())\n", + " self.all_end_logits.append(end.numpy())\n", + "\n", + " def get_metric(self):\n", + " all_predictions, _, _ = compute_prediction(\n", + " self.examples, self.features[:len(self.all_start_logits)],\n", + " (self.all_start_logits, self.all_end_logits),\n", + " False, 20, 30\n", + " )\n", + " with contextlib.redirect_stdout(None):\n", + " result = squad_evaluate(\n", + " examples=self.examples,\n", + " preds=all_predictions,\n", + " is_whitespace_splited=False\n", + " )\n", + "\n", + " if self.testing:\n", + " self.print_predictions(all_predictions)\n", + " return result\n", + "\n", + " def print_predictions(self, preds):\n", + " for i, data in enumerate(self.examples):\n", + " if i >= 5:\n", + " break\n", + " print()\n", + " print(\"原文:\", data[\"context\"])\n", + " print(\"问题:\", data[\"question\"], \\\n", + " \"答案:\", preds[data[\"id\"]], \\\n", + " \"正确答案:\", data[\"answers\"][\"text\"])\n", + "\n", + "metric = SquadEvaluateMetric(\n", + " val_dataloader.dataset.data,\n", + " val_dataloader.dataset.new_data,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.4 训练\n", + "\n", + "至此所有的准备工作已经完成,可以使用 `Trainer` 进行训练了。学习率我们依旧采用线性预热策略 `LinearDecayWithWarmup`,优化器为 `AdamW`;回调模块我们选择 `LRSchedCallback` 更新学习率和 
`LoadBestModelCallback` 监视评测结果的 `f1` 分数。初始化好 `Trainer` 之后,就将训练的过程交给 `fastNLP` 吧。" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[19:04:54] INFO     Running evaluator sanity check for 2 batches.              trainer.py:631\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[19:04:54]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=367046;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=96810;file://../fastNLP/core/controllers/trainer.py#631\u001b\\\u001b[2m631\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:100 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m100\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"exact#squad\": 49.25899788285109,\n",
+       "  \"f1#squad\": 66.55559127349602,\n",
+       "  \"total#squad\": 1417,\n",
+       "  \"HasAns_exact#squad\": 49.25899788285109,\n",
+       "  \"HasAns_f1#squad\": 66.55559127349602,\n",
+       "  \"HasAns_total#squad\": 1417\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m49.25899788285109\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m66.55559127349602\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m49.25899788285109\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m66.55559127349602\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:200 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m200\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"exact#squad\": 57.37473535638673,\n",
+       "  \"f1#squad\": 70.93036525200617,\n",
+       "  \"total#squad\": 1417,\n",
+       "  \"HasAns_exact#squad\": 57.37473535638673,\n",
+       "  \"HasAns_f1#squad\": 70.93036525200617,\n",
+       "  \"HasAns_total#squad\": 1417\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m57.37473535638673\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m70.93036525200617\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m57.37473535638673\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m70.93036525200617\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:300 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m300\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"exact#squad\": 63.86732533521524,\n",
+       "  \"f1#squad\": 78.62546663568186,\n",
+       "  \"total#squad\": 1417,\n",
+       "  \"HasAns_exact#squad\": 63.86732533521524,\n",
+       "  \"HasAns_f1#squad\": 78.62546663568186,\n",
+       "  \"HasAns_total#squad\": 1417\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m63.86732533521524\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m78.62546663568186\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m63.86732533521524\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m78.62546663568186\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:400 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m400\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"exact#squad\": 64.92589978828511,\n",
+       "  \"f1#squad\": 79.36746074079691,\n",
+       "  \"total#squad\": 1417,\n",
+       "  \"HasAns_exact#squad\": 64.92589978828511,\n",
+       "  \"HasAns_f1#squad\": 79.36746074079691,\n",
+       "  \"HasAns_total#squad\": 1417\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m64.92589978828511\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m79.36746074079691\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m64.92589978828511\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m79.36746074079691\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:500 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m500\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"exact#squad\": 65.70218772053634,\n",
+       "  \"f1#squad\": 80.33295482054824,\n",
+       "  \"total#squad\": 1417,\n",
+       "  \"HasAns_exact#squad\": 65.70218772053634,\n",
+       "  \"HasAns_f1#squad\": 80.33295482054824,\n",
+       "  \"HasAns_total#squad\": 1417\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:600 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m600\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"exact#squad\": 65.41990119971771,\n",
+       "  \"f1#squad\": 79.7483487059053,\n",
+       "  \"total#squad\": 1417,\n",
+       "  \"HasAns_exact#squad\": 65.41990119971771,\n",
+       "  \"HasAns_f1#squad\": 79.7483487059053,\n",
+       "  \"HasAns_total#squad\": 1417\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m65.41990119971771\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m79.7483487059053\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m65.41990119971771\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m79.7483487059053\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:700 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m700\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"exact#squad\": 66.61961891319689,\n",
+       "  \"f1#squad\": 80.32432238994133,\n",
+       "  \"total#squad\": 1417,\n",
+       "  \"HasAns_exact#squad\": 66.61961891319689,\n",
+       "  \"HasAns_f1#squad\": 80.32432238994133,\n",
+       "  \"HasAns_total#squad\": 1417\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m66.61961891319689\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m80.32432238994133\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m66.61961891319689\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m80.32432238994133\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:800 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m800\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"exact#squad\": 65.84333098094567,\n",
+       "  \"f1#squad\": 79.23169801265415,\n",
+       "  \"total#squad\": 1417,\n",
+       "  \"HasAns_exact#squad\": 65.84333098094567,\n",
+       "  \"HasAns_f1#squad\": 79.23169801265415,\n",
+       "  \"HasAns_total#squad\": 1417\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m65.84333098094567\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m79.23169801265415\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m65.84333098094567\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m79.23169801265415\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
[19:20:28] INFO     Loading best model from fnlp-ernie-squad/ load_best_model_callback.py:111\n",
+       "                    2022-06-27-19_00_15_388554/best_so_far                                   \n",
+       "                    with f1#squad: 80.33295482054824...                                      \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[19:20:28]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Loading best model from fnlp-ernie-squad/ \u001b]8;id=163935;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=31503;file://../fastNLP/core/callbacks/load_best_model_callback.py#111\u001b\\\u001b[2m111\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m \u001b[1;36m2022\u001b[0m-\u001b[1;36m06\u001b[0m-\u001b[1;36m27\u001b[0m-19_00_15_388554/best_so_far \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m with f1#squad: \u001b[1;36m80.33295482054824\u001b[0m\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
           INFO     Deleting fnlp-ernie-squad/2022-06-27-19_0 load_best_model_callback.py:131\n",
+       "                    0_15_388554/best_so_far...                                               \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Deleting fnlp-ernie-squad/\u001b[1;36m2022\u001b[0m-\u001b[1;36m06\u001b[0m-\u001b[1;36m27\u001b[0m-19_0 \u001b]8;id=560859;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=573263;file://../fastNLP/core/callbacks/load_best_model_callback.py#131\u001b\\\u001b[2m131\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m 0_15_388554/best_so_far\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from fastNLP import Trainer, LRSchedCallback, LoadBestModelCallback\n", + "from paddlenlp.transformers import LinearDecayWithWarmup\n", + "\n", + "n_epochs = 1\n", + "num_training_steps = len(train_dataloader) * n_epochs\n", + "lr_scheduler = LinearDecayWithWarmup(3e-5, num_training_steps, 0.1)\n", + "optimizer = paddle.optimizer.AdamW(\n", + " learning_rate=lr_scheduler,\n", + " parameters=model.parameters(),\n", + ")\n", + "callbacks=[\n", + " LRSchedCallback(lr_scheduler, step_on=\"batch\"),\n", + " LoadBestModelCallback(\"f1#squad\", larger_better=True, save_folder=\"fnlp-ernie-squad\")\n", + "]\n", + "trainer = Trainer(\n", + " model=model,\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=val_dataloader,\n", + " device=1,\n", + " optimizers=optimizer,\n", + " n_epochs=n_epochs,\n", + " callbacks=callbacks,\n", + " evaluate_every=100,\n", + " metrics={\"squad\": metric},\n", + ")\n", + "trainer.run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.5 测试\n", + "\n", + "最后,我们可以使用 `Evaluator` 查看我们训练的结果。我们在之前为 `SquadEvaluateMetric` 设置了 `testing` 参数来在测试阶段进行输出,可以看到,训练的结果还是比较不错的。" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
原文: 爬行垫根据中间材料的不同可以分为:XPE爬行垫、EPE爬行垫、EVA爬行垫、PVC爬行垫;其中XPE爬\n",
+       "行垫、EPE爬行垫都属于PE材料加保鲜膜复合而成,都是无异味的环保材料,但是XPE爬行垫是品质较好的爬\n",
+       "行垫,韩国进口爬行垫都是这种爬行垫,而EPE爬行垫是国内厂家为了减低成本,使用EPE(珍珠棉)作为原料生\n",
+       "产的一款爬行垫,该材料弹性差,易碎,开孔发泡防水性弱。EVA爬行垫、PVC爬行垫是用EVA或PVC作为原材料\n",
+       "与保鲜膜复合的而成的爬行垫,或者把图案转印在原材料上,这两款爬行垫通常有异味,如果是图案转印的爬\n",
+       "行垫,油墨外露容易脱落。 \n",
+       "当时我儿子爬的时候,我们也买了垫子,但是始终有味。最后就没用了,铺的就的薄毯子让他爬。\n",
+       "
\n" + ], + "text/plain": [ + "原文: 爬行垫根据中间材料的不同可以分为:XPE爬行垫、EPE爬行垫、EVA爬行垫、PVC爬行垫;其中XPE爬\n", + "行垫、EPE爬行垫都属于PE材料加保鲜膜复合而成,都是无异味的环保材料,但是XPE爬行垫是品质较好的爬\n", + "行垫,韩国进口爬行垫都是这种爬行垫,而EPE爬行垫是国内厂家为了减低成本,使用EPE(珍珠棉)作为原料生\n", + "产的一款爬行垫,该材料弹性差,易碎,开孔发泡防水性弱。EVA爬行垫、PVC爬行垫是用EVA或PVC作为原材料\n", + "与保鲜膜复合的而成的爬行垫,或者把图案转印在原材料上,这两款爬行垫通常有异味,如果是图案转印的爬\n", + "行垫,油墨外露容易脱落。 \n", + "当时我儿子爬的时候,我们也买了垫子,但是始终有味。最后就没用了,铺的就的薄毯子让他爬。\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
问题: 爬行垫什么材质的好 答案: EPE(珍珠棉 正确答案: ['XPE']\n",
+       "
\n" + ], + "text/plain": [ + "问题: 爬行垫什么材质的好 答案: EPE(珍珠棉 正确答案: ['XPE']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
原文: 真实情况是160-162。她平时谎报的168是因为不离脚穿高水台恨天高(15厘米) 图1她穿着高水台恨\n",
+       "天高和刘亦菲一样高,(刘亦菲对外报身高172)范冰冰礼服下厚厚的高水台暴露了她的心机,对比一下两者的\n",
+       "鞋子吧 图2 穿着高水台恨天高才和刘德华谢霆锋持平,如果她真的有168,那么加上鞋高,刘和谢都要有180?\n",
+       "明显是不可能的。所以刘德华对外报的身高174减去10-15厘米才是范冰冰的真实身高 图3,范冰冰有一次脱\n",
+       "鞋上场,这个最说明问题了,看看她的身体比例吧。还有目测一下她手上鞋子的鞋跟有多高多厚吧,至少超过\n",
+       "10厘米。\n",
+       "
\n" + ], + "text/plain": [ + "原文: 真实情况是160-162。她平时谎报的168是因为不离脚穿高水台恨天高(15厘米) 图1她穿着高水台恨\n", + "天高和刘亦菲一样高,(刘亦菲对外报身高172)范冰冰礼服下厚厚的高水台暴露了她的心机,对比一下两者的\n", + "鞋子吧 图2 穿着高水台恨天高才和刘德华谢霆锋持平,如果她真的有168,那么加上鞋高,刘和谢都要有180?\n", + "明显是不可能的。所以刘德华对外报的身高174减去10-15厘米才是范冰冰的真实身高 图3,范冰冰有一次脱\n", + "鞋上场,这个最说明问题了,看看她的身体比例吧。还有目测一下她手上鞋子的鞋跟有多高多厚吧,至少超过\n", + "10厘米。\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
问题: 范冰冰多高真实身高 答案: 160-162 正确答案: ['160-162']\n",
+       "
\n" + ], + "text/plain": [ + "问题: 范冰冰多高真实身高 答案: 160-162 正确答案: ['160-162']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
原文: 防水作为目前高端手机的标配,特别是苹果也支持防水之后,国产大多数高端旗舰手机都已经支持防\n",
+       "水。虽然我们真的不会故意把手机放入水中,但是有了防水之后,用户心里会多一重安全感。那么近日最为\n",
+       "火热的小米6防水吗?小米6的防水级别又是多少呢? 小编查询了很多资料发现,小米6确实是防水的,但是为\n",
+       "了保持低调,同时为了不被别人说防水等级不够,很多资料都没有标注小米是否防水。根据评测资料显示,小\n",
+       "米6是支持IP68级的防水,是绝对能够满足日常生活中的防水需求的。\n",
+       "
\n" + ], + "text/plain": [ + "原文: 防水作为目前高端手机的标配,特别是苹果也支持防水之后,国产大多数高端旗舰手机都已经支持防\n", + "水。虽然我们真的不会故意把手机放入水中,但是有了防水之后,用户心里会多一重安全感。那么近日最为\n", + "火热的小米6防水吗?小米6的防水级别又是多少呢? 小编查询了很多资料发现,小米6确实是防水的,但是为\n", + "了保持低调,同时为了不被别人说防水等级不够,很多资料都没有标注小米是否防水。根据评测资料显示,小\n", + "米6是支持IP68级的防水,是绝对能够满足日常生活中的防水需求的。\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
问题: 小米6防水等级 答案: IP68级 正确答案: ['IP68级']\n",
+       "
\n" + ], + "text/plain": [ + "问题: 小米6防水等级 答案: IP68级 正确答案: ['IP68级']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
原文: 这位朋友你好,女性出现妊娠反应一般是从6-12周左右,也就是女性怀孕1个多月就会开始出现反应,\n",
+       "第3个月的时候,妊辰反应基本结束。 而大部分女性怀孕初期都会出现恶心、呕吐的感觉,这些症状都是因\n",
+       "人而异的,除非恶心、呕吐的非常厉害,才需要就医,否则这些都是刚怀孕的的正常症状。1-3个月的时候可\n",
+       "以观察一下自己的皮肤,一般女性怀孕初期可能会产生皮肤色素沉淀或是腹壁产生妊娠纹,特别是在怀孕的\n",
+       "后期更加明显。 还有很多女性怀孕初期会出现疲倦、嗜睡的情况。怀孕三个月的时候,膀胱会受到日益胀\n",
+       "大的子宫的压迫,容量会变小,所以怀孕期间也会有尿频的现象出现。月经停止也是刚怀孕最容易出现的症\n",
+       "状,只要是平时月经正常的女性,在性行为后超过正常经期两周,就有可能是怀孕了。 如果你想判断自己是\n",
+       "否怀孕,可以看看自己有没有这些反应。当然这也只是多数人的怀孕表现,也有部分女性怀孕表现并不完全\n",
+       "是这样,如果你无法确定自己是否怀孕,最好去医院检查一下。\n",
+       "
\n" + ], + "text/plain": [ + "原文: 这位朋友你好,女性出现妊娠反应一般是从6-12周左右,也就是女性怀孕1个多月就会开始出现反应,\n", + "第3个月的时候,妊辰反应基本结束。 而大部分女性怀孕初期都会出现恶心、呕吐的感觉,这些症状都是因\n", + "人而异的,除非恶心、呕吐的非常厉害,才需要就医,否则这些都是刚怀孕的的正常症状。1-3个月的时候可\n", + "以观察一下自己的皮肤,一般女性怀孕初期可能会产生皮肤色素沉淀或是腹壁产生妊娠纹,特别是在怀孕的\n", + "后期更加明显。 还有很多女性怀孕初期会出现疲倦、嗜睡的情况。怀孕三个月的时候,膀胱会受到日益胀\n", + "大的子宫的压迫,容量会变小,所以怀孕期间也会有尿频的现象出现。月经停止也是刚怀孕最容易出现的症\n", + "状,只要是平时月经正常的女性,在性行为后超过正常经期两周,就有可能是怀孕了。 如果你想判断自己是\n", + "否怀孕,可以看看自己有没有这些反应。当然这也只是多数人的怀孕表现,也有部分女性怀孕表现并不完全\n", + "是这样,如果你无法确定自己是否怀孕,最好去医院检查一下。\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
问题: 怀孕多久会有反应 答案: 6-12周左右 正确答案: ['6-12周左右', '6-12周', '1个多月']\n",
+       "
\n" + ], + "text/plain": [ + "问题: 怀孕多久会有反应 答案: 6-12周左右 正确答案: ['6-12周左右', '6-12周', '1个多月']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
原文: 【东奥会计在线——中级会计职称频道推荐】根据《关于提高科技型中小企业研究开发费用税前加计\n",
+       "扣除比例的通知》的规定,研发费加计扣除比例提高到75%。|财政部、国家税务总局、科技部发布《关于提\n",
+       "高科技型中小企业研究开发费用税前加计扣除比例的通知》。|通知称,为进一步激励中小企业加大研发投\n",
+       "入,支持科技创新,就提高科技型中小企业研究开发费用(以下简称研发费用)税前加计扣除比例有关问题发\n",
+       "布通知。|通知明确,科技型中小企业开展研发活动中实际发生的研发费用,未形成无形资产计入当期损益的\n",
+       ",在按规定据实扣除的基础上,在2017年1月1日至2019年12月31日期间,再按照实际发生额的75%在税前加计\n",
+       "扣除;形成无形资产的,在上述期间按照无形资产成本的175%在税前摊销。|科技型中小企业享受研发费用税\n",
+       "前加计扣除政策的其他政策口径按照《财政部国家税务总局科技部关于完善研究开发费用税前加计扣除政\n",
+       "策的通知》(财税〔2015〕119号)规定执行。|科技型中小企业条件和管理办法由科技部、财政部和国家税\n",
+       "务总局另行发布。科技、财政和税务部门应建立信息共享机制,及时共享科技型中小企业的相关信息,加强\n",
+       "协调配合,保障优惠政策落实到位。|上一篇文章:关于2016年度企业研究开发费用税前加计扣除政策企业所\n",
+       "得税纳税申报问题的公告 下一篇文章:关于提高科技型中小企业研究开发费用税前加计扣除比例的通知\n",
+       "
\n" + ], + "text/plain": [ + "原文: 【东奥会计在线——中级会计职称频道推荐】根据《关于提高科技型中小企业研究开发费用税前加计\n", + "扣除比例的通知》的规定,研发费加计扣除比例提高到75%。|财政部、国家税务总局、科技部发布《关于提\n", + "高科技型中小企业研究开发费用税前加计扣除比例的通知》。|通知称,为进一步激励中小企业加大研发投\n", + "入,支持科技创新,就提高科技型中小企业研究开发费用(以下简称研发费用)税前加计扣除比例有关问题发\n", + "布通知。|通知明确,科技型中小企业开展研发活动中实际发生的研发费用,未形成无形资产计入当期损益的\n", + ",在按规定据实扣除的基础上,在2017年1月1日至2019年12月31日期间,再按照实际发生额的75%在税前加计\n", + "扣除;形成无形资产的,在上述期间按照无形资产成本的175%在税前摊销。|科技型中小企业享受研发费用税\n", + "前加计扣除政策的其他政策口径按照《财政部国家税务总局科技部关于完善研究开发费用税前加计扣除政\n", + "策的通知》(财税〔2015〕119号)规定执行。|科技型中小企业条件和管理办法由科技部、财政部和国家税\n", + "务总局另行发布。科技、财政和税务部门应建立信息共享机制,及时共享科技型中小企业的相关信息,加强\n", + "协调配合,保障优惠政策落实到位。|上一篇文章:关于2016年度企业研究开发费用税前加计扣除政策企业所\n", + "得税纳税申报问题的公告 下一篇文章:关于提高科技型中小企业研究开发费用税前加计扣除比例的通知\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
问题: 研发费用加计扣除比例 答案: 75% 正确答案: ['75%']\n",
+       "
\n" + ], + "text/plain": [ + "问题: 研发费用加计扣除比例 答案: 75% 正确答案: ['75%']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
{\n",
+       "    'exact#squad': 65.70218772053634,\n",
+       "    'f1#squad': 80.33295482054824,\n",
+       "    'total#squad': 1417,\n",
+       "    'HasAns_exact#squad': 65.70218772053634,\n",
+       "    'HasAns_f1#squad': 80.33295482054824,\n",
+       "    'HasAns_total#squad': 1417\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[32m'exact#squad'\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n", + " \u001b[32m'f1#squad'\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n", + " \u001b[32m'total#squad'\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[32m'HasAns_exact#squad'\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n", + " \u001b[32m'HasAns_f1#squad'\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n", + " \u001b[32m'HasAns_total#squad'\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from fastNLP import Evaluator\n", + "evaluator = Evaluator(\n", + " model=model,\n", + " dataloaders=val_dataloader,\n", + " device=1,\n", + " metrics={\n", + " \"squad\": SquadEvaluateMetric(\n", + " val_dataloader.dataset.data,\n", + " val_dataloader.dataset.new_data,\n", + " testing=True,\n", + " ),\n", + " },\n", + ")\n", + "result = evaluator.run()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.7.13 ('fnlp-paddle')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "31f2d9d3efc23c441973d7c4273acfea8b132b6a578f002629b6b44b8f65e720" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/source/tutorials/figures/E1-fig-glue-benchmark.png b/docs/source/tutorials/figures/E1-fig-glue-benchmark.png new file mode 100644 index 00000000..515db700 Binary files /dev/null and b/docs/source/tutorials/figures/E1-fig-glue-benchmark.png differ diff --git a/docs/source/tutorials/figures/E2-fig-p-tuning-v2-model.png b/docs/source/tutorials/figures/E2-fig-p-tuning-v2-model.png new file mode 100644 index 00000000..b5a9c1b8 Binary files /dev/null and b/docs/source/tutorials/figures/E2-fig-p-tuning-v2-model.png differ diff --git a/docs/source/tutorials/figures/E2-fig-pet-model.png b/docs/source/tutorials/figures/E2-fig-pet-model.png new file mode 100644 index 00000000..c3c377c0 Binary files /dev/null and b/docs/source/tutorials/figures/E2-fig-pet-model.png differ diff --git a/docs/source/tutorials/figures/T0-fig-parameter-matching.png b/docs/source/tutorials/figures/T0-fig-parameter-matching.png new file mode 100644 index 00000000..24013cc1 Binary files /dev/null and b/docs/source/tutorials/figures/T0-fig-parameter-matching.png differ diff --git a/docs/source/tutorials/figures/T0-fig-trainer-and-evaluator.png b/docs/source/tutorials/figures/T0-fig-trainer-and-evaluator.png new file mode 100644 index 00000000..38222ee8 Binary files /dev/null and b/docs/source/tutorials/figures/T0-fig-trainer-and-evaluator.png differ diff --git a/docs/source/tutorials/figures/T0-fig-training-structure.png b/docs/source/tutorials/figures/T0-fig-training-structure.png new file mode 100644 index 00000000..edc2e2ff Binary files /dev/null and b/docs/source/tutorials/figures/T0-fig-training-structure.png differ diff --git a/docs/source/tutorials/figures/T1-fig-dataset-and-vocabulary.png b/docs/source/tutorials/figures/T1-fig-dataset-and-vocabulary.png new file mode 100644 index 00000000..803cf34a Binary files /dev/null and b/docs/source/tutorials/figures/T1-fig-dataset-and-vocabulary.png differ diff --git 
a/docs/source/tutorials/figures/paddle-ernie-1.0-masking-levels.png b/docs/source/tutorials/figures/paddle-ernie-1.0-masking-levels.png new file mode 100644 index 00000000..ff2519c4 Binary files /dev/null and b/docs/source/tutorials/figures/paddle-ernie-1.0-masking-levels.png differ diff --git a/docs/source/tutorials/figures/paddle-ernie-1.0-masking.png b/docs/source/tutorials/figures/paddle-ernie-1.0-masking.png new file mode 100644 index 00000000..ed003a2f Binary files /dev/null and b/docs/source/tutorials/figures/paddle-ernie-1.0-masking.png differ diff --git a/docs/source/tutorials/figures/paddle-ernie-2.0-continual-pretrain.png b/docs/source/tutorials/figures/paddle-ernie-2.0-continual-pretrain.png new file mode 100644 index 00000000..d45f65d8 Binary files /dev/null and b/docs/source/tutorials/figures/paddle-ernie-2.0-continual-pretrain.png differ diff --git a/docs/source/tutorials/figures/paddle-ernie-3.0-framework.png b/docs/source/tutorials/figures/paddle-ernie-3.0-framework.png new file mode 100644 index 00000000..f50ddb1c Binary files /dev/null and b/docs/source/tutorials/figures/paddle-ernie-3.0-framework.png differ diff --git a/fastNLP/__init__.py b/fastNLP/__init__.py index 9885a175..31249c80 100644 --- a/fastNLP/__init__.py +++ b/fastNLP/__init__.py @@ -2,4 +2,4 @@ from fastNLP.envs import * from fastNLP.core import * -__version__ = '0.8.0beta' +__version__ = '1.0.0alpha' diff --git a/fastNLP/core/collators/padders/oneflow_padder.py b/fastNLP/core/collators/padders/oneflow_padder.py index 5e235a0f..30d73e26 100644 --- a/fastNLP/core/collators/padders/oneflow_padder.py +++ b/fastNLP/core/collators/padders/oneflow_padder.py @@ -7,6 +7,7 @@ from inspect import isclass import numpy as np from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW +from fastNLP.envs.utils import _module_available if _NEED_IMPORT_ONEFLOW: import oneflow diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index c3c658b8..ac934bd7 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -83,13 +83,13 @@ class Trainer(TrainerEventTrigger): .. warning:: 当使用分布式训练时, **fastNLP** 会默认将 ``dataloader`` 中的 ``Sampler`` 进行处理,以使得在一个 epoch 中,不同卡 - 用以训练的数据是不重叠的。如果你对 sampler 有特殊处理,那么请将 ``use_dist_sampler`` 参数设置为 ``False`` ,此刻需要由 - 你自身保证每张卡上所使用的数据是不同的。 + 用以训练的数据是不重叠的。如果您对 sampler 有特殊处理,那么请将 ``use_dist_sampler`` 参数设置为 ``False`` ,此刻需要由 + 您自身保证每张卡上所使用的数据是不同的。 :param optimizers: 训练所需要的优化器;可以是单独的一个优化器实例,也可以是多个优化器组成的 List; :param device: 该参数用来指定具体训练时使用的机器;注意当该参数仅当您通过 ``torch.distributed.launch/run`` 启动时可以为 ``None``, - 此时 fastNLP 不会对模型和数据进行设备之间的移动处理,但是你可以通过参数 ``input_mapping`` 和 ``output_mapping`` 来实现设备之间 - 数据迁移的工作(通过这两个参数传入两个处理数据的函数);同时你也可以通过在 kwargs 添加参数 ``data_device`` 来让我们帮助您将数据 + 此时 fastNLP 不会对模型和数据进行设备之间的移动处理,但是您可以通过参数 ``input_mapping`` 和 ``output_mapping`` 来实现设备之间 + 数据迁移的工作(通过这两个参数传入两个处理数据的函数);同时您也可以通过在 kwargs 添加参数 ``data_device`` 来让我们帮助您将数据 迁移到指定的机器上(注意这种情况理应只出现在用户在 Trainer 实例化前自己构造 DDP 的场景); device 的可选输入如下所示: @@ -195,7 +195,7 @@ class Trainer(TrainerEventTrigger): 3. 如果此时 batch 此时是其它类型,那么我们将会直接报错; 2. 
如果 ``input_mapping`` 是一个函数,那么对于取出的 batch,我们将不会做任何处理,而是直接将其传入该函数里; - 注意该参数会被传进 ``Evaluator`` 中;因此你可以通过该参数来实现将训练数据 batch 移到对应机器上的工作(例如当参数 ``device`` 为 ``None`` 时); + 注意该参数会被传进 ``Evaluator`` 中;因此您可以通过该参数来实现将训练数据 batch 移到对应机器上的工作(例如当参数 ``device`` 为 ``None`` 时); 如果 ``Trainer`` 和 ``Evaluator`` 需要使用不同的 ``input_mapping``, 请使用 ``train_input_mapping`` 与 ``evaluate_input_mapping`` 分别进行设置。 :param output_mapping: 应当为一个字典或者函数。作用和 ``input_mapping`` 类似,区别在于其用于转换输出: @@ -366,7 +366,7 @@ class Trainer(TrainerEventTrigger): .. note:: ``Trainer`` 是通过在内部直接初始化一个 ``Evaluator`` 来进行验证; - ``Trainer`` 内部的 ``Evaluator`` 默认是 None,如果您需要在训练过程中进行验证,你需要保证这几个参数得到正确的传入: + ``Trainer`` 内部的 ``Evaluator`` 默认是 None,如果您需要在训练过程中进行验证,您需要保证这几个参数得到正确的传入: 必须的参数:``metrics`` 与 ``evaluate_dataloaders``; @@ -896,7 +896,7 @@ class Trainer(TrainerEventTrigger): 这段代码意味着 ``fn1`` 和 ``fn2`` 会被加入到 ``trainer1``,``fn3`` 会被加入到 ``trainer2``; - 注意如果你使用该函数修饰器来为你的训练添加 callback,请务必保证你加入 callback 函数的代码在实例化 `Trainer` 之前; + 注意如果您使用该函数修饰器来为您的训练添加 callback,请务必保证您加入 callback 函数的代码在实例化 `Trainer` 之前; 补充性的解释见 :meth:`~fastNLP.core.controllers.Trainer.add_callback_fn`; diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index 8512fcdb..53390409 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -584,7 +584,7 @@ class DataSet: 将 :class:`DataSet` 每个 ``instance`` 中为 ``field_name`` 的 field 传给函数 ``func``,并写入到 ``new_field_name`` 中。 - :param func: 对指定 fiel` 进行处理的函数,注意其输入应为 ``instance`` 中名为 ``field_name`` 的 field 的内容; + :param func: 对指定 field 进行处理的函数,注意其输入应为 ``instance`` 中名为 ``field_name`` 的 field 的内容; :param field_name: 传入 ``func`` 的 field 名称; :param new_field_name: 函数执行结果写入的 ``field`` 名称。该函数会将 ``func`` 返回的内容放入到 ``new_field_name`` 对 应的 ``field`` 中,注意如果名称与已有的 field 相同则会进行覆盖。如果为 ``None`` 则不会覆盖和创建 field ; @@ -624,10 +624,9 @@ class DataSet: ``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`~fastNLP.core.dataset.DataSet.apply_more` 中关于 ``apply_more`` 与 ``apply`` 区别的介绍。 - :param func: 对指定 fiel` 进行处理的函数,注意其输入应为 ``instance`` 中名为 ``field_name`` 的 field 的内容; - :param field_name: 传入 ``func`` 的 fiel` 名称; - :param new_field_name: 函数执行结果写入的 ``field`` 名称。该函数会将 ``func`` 返回的内容放入到 ``new_field_name`` 对 - 应的 ``field`` 中,注意如果名称与已有的 field 相同则会进行覆盖。如果为 ``None`` 则不会覆盖和创建 field ; + :param func: 对指定 field 进行处理的函数,注意其输入应为 ``instance`` 中名为 ``field_name`` 的 field 的内容; + :param field_name: 传入 ``func`` 的 field 名称; + :param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 ``True`` :param num_proc: 使用进程的数量。 .. note:: @@ -751,8 +750,8 @@ class DataSet: 3. ``apply_more`` 默认修改 ``DataSet`` 中的 field ,``apply`` 默认不修改。 - :param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 + :param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 ``True`` :param num_proc: 使用进程的数量。 .. 
note:: diff --git a/fastNLP/core/drivers/torch_driver/torch_fsdp.py b/fastNLP/core/drivers/torch_driver/torch_fsdp.py index 9359615a..0b1948e8 100644 --- a/fastNLP/core/drivers/torch_driver/torch_fsdp.py +++ b/fastNLP/core/drivers/torch_driver/torch_fsdp.py @@ -1,15 +1,17 @@ -from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_12 +from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_12, _NEED_IMPORT_TORCH if _TORCH_GREATER_EQUAL_1_12: from torch.distributed.fsdp import FullyShardedDataParallel, StateDictType, FullStateDictConfig, OptimStateKeyType +if _NEED_IMPORT_TORCH: + import torch + import torch.distributed as dist + from torch.nn.parallel import DistributedDataParallel + import os -import torch -import torch.distributed as dist -from torch.nn.parallel import DistributedDataParallel from typing import Optional, Union, List, Dict, Mapping from pathlib import Path diff --git a/fastNLP/embeddings/torch/static_embedding.py b/fastNLP/embeddings/torch/static_embedding.py index 12e7294c..6980c851 100644 --- a/fastNLP/embeddings/torch/static_embedding.py +++ b/fastNLP/embeddings/torch/static_embedding.py @@ -86,7 +86,7 @@ class StaticEmbedding(TokenEmbedding): :param requires_grad: 是否需要梯度。 :param init_method: 如何初始化没有找到的值。可以使用 :mod:`torch.nn.init` 中的各种方法,传入的方法应该接受一个 tensor,并 inplace 地修改其值。 - :param lower: 是否将 ``vocab`` 中的词语小写后再和预训练的词表进行匹配。如果你的词表中包含大写的词语,或者就是需要单独 + :param lower: 是否将 ``vocab`` 中的词语小写后再和预训练的词表进行匹配。如果您的词表中包含大写的词语,或者就是需要单独 为大写的词语开辟一个 vector 表示,则将 ``lower`` 设置为 ``False``。 :param dropout: 以多大的概率对 embedding 的表示进行 Dropout。0.1 即随机将 10% 的值置为 0。 :param word_dropout: 按照一定概率随机将 word 设置为 ``unk_index`` ,这样可以使得 ```` 这个 token 得到足够的训练, diff --git a/fastNLP/transformers/__init__.py b/fastNLP/transformers/__init__.py index 3b375020..6b175b28 100644 --- a/fastNLP/transformers/__init__.py +++ b/fastNLP/transformers/__init__.py @@ -1,4 +1,3 @@ """ :mod:`transformers` 模块,包含了常用的预训练模型。 """ -import sphinx-multiversion \ No newline at end of file