From 2a31cf831fda900162c3c6d1eb002a0bbcce9d17 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Thu, 7 Jul 2022 11:29:57 +0000 Subject: [PATCH] =?UTF-8?q?=E5=85=B6=E5=AE=83=E8=B0=83=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../source/tutorials/fastnlp_tutorial_0.ipynb | 1352 +++++++++ .../source/tutorials/fastnlp_tutorial_1.ipynb | 1333 +++++++++ .../source/tutorials/fastnlp_tutorial_2.ipynb | 884 ++++++ .../source/tutorials/fastnlp_tutorial_3.ipynb | 621 ++++ .../source/tutorials/fastnlp_tutorial_4.ipynb | 2614 +++++++++++++++++ .../source/tutorials/fastnlp_tutorial_5.ipynb | 1242 ++++++++ .../source/tutorials/fastnlp_tutorial_6.ipynb | 1646 +++++++++++ .../tutorials/fastnlp_tutorial_e1.ipynb | 1280 ++++++++ .../tutorials/fastnlp_tutorial_e2.ipynb | 1082 +++++++ .../fastnlp_tutorial_paddle_e1.ipynb | 1086 +++++++ .../fastnlp_tutorial_paddle_e2.ipynb | 1510 ++++++++++ .../figures/E1-fig-glue-benchmark.png | Bin 0 -> 158817 bytes .../figures/E2-fig-p-tuning-v2-model.png | Bin 0 -> 50517 bytes .../tutorials/figures/E2-fig-pet-model.png | Bin 0 -> 57162 bytes .../figures/T0-fig-parameter-matching.png | Bin 0 -> 95584 bytes .../figures/T0-fig-trainer-and-evaluator.png | Bin 0 -> 71418 bytes .../figures/T0-fig-training-structure.png | Bin 0 -> 80296 bytes .../figures/T1-fig-dataset-and-vocabulary.png | Bin 0 -> 138905 bytes .../paddle-ernie-1.0-masking-levels.png | Bin 0 -> 59022 bytes .../figures/paddle-ernie-1.0-masking.png | Bin 0 -> 46898 bytes .../paddle-ernie-2.0-continual-pretrain.png | Bin 0 -> 128680 bytes .../figures/paddle-ernie-3.0-framework.png | Bin 0 -> 202018 bytes fastNLP/__init__.py | 2 +- .../core/collators/padders/oneflow_padder.py | 1 + fastNLP/core/controllers/trainer.py | 14 +- fastNLP/core/dataset/dataset.py | 11 +- .../core/drivers/torch_driver/torch_fsdp.py | 10 +- fastNLP/embeddings/torch/static_embedding.py | 2 +- fastNLP/transformers/__init__.py | 1 - 29 
files changed, 14671 insertions(+), 20 deletions(-) create mode 100644 docs/source/tutorials/fastnlp_tutorial_0.ipynb create mode 100644 docs/source/tutorials/fastnlp_tutorial_1.ipynb create mode 100644 docs/source/tutorials/fastnlp_tutorial_2.ipynb create mode 100644 docs/source/tutorials/fastnlp_tutorial_3.ipynb create mode 100644 docs/source/tutorials/fastnlp_tutorial_4.ipynb create mode 100644 docs/source/tutorials/fastnlp_tutorial_5.ipynb create mode 100644 docs/source/tutorials/fastnlp_tutorial_6.ipynb create mode 100644 docs/source/tutorials/fastnlp_tutorial_e1.ipynb create mode 100644 docs/source/tutorials/fastnlp_tutorial_e2.ipynb create mode 100644 docs/source/tutorials/fastnlp_tutorial_paddle_e1.ipynb create mode 100644 docs/source/tutorials/fastnlp_tutorial_paddle_e2.ipynb create mode 100644 docs/source/tutorials/figures/E1-fig-glue-benchmark.png create mode 100644 docs/source/tutorials/figures/E2-fig-p-tuning-v2-model.png create mode 100644 docs/source/tutorials/figures/E2-fig-pet-model.png create mode 100644 docs/source/tutorials/figures/T0-fig-parameter-matching.png create mode 100644 docs/source/tutorials/figures/T0-fig-trainer-and-evaluator.png create mode 100644 docs/source/tutorials/figures/T0-fig-training-structure.png create mode 100644 docs/source/tutorials/figures/T1-fig-dataset-and-vocabulary.png create mode 100644 docs/source/tutorials/figures/paddle-ernie-1.0-masking-levels.png create mode 100644 docs/source/tutorials/figures/paddle-ernie-1.0-masking.png create mode 100644 docs/source/tutorials/figures/paddle-ernie-2.0-continual-pretrain.png create mode 100644 docs/source/tutorials/figures/paddle-ernie-3.0-framework.png diff --git a/docs/source/tutorials/fastnlp_tutorial_0.ipynb b/docs/source/tutorials/fastnlp_tutorial_0.ipynb new file mode 100644 index 00000000..09667794 --- /dev/null +++ b/docs/source/tutorials/fastnlp_tutorial_0.ipynb @@ -0,0 +1,1352 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "aec0fde7", + "metadata": {}, 
+ "source": [ + "# T0. trainer 和 evaluator 的基本使用\n", + "\n", + " 1 trainer 和 evaluator 的基本关系\n", + " \n", + " 1.1 trainer 和 evaluater 的初始化\n", + "\n", + " 1.2 driver 的含义与使用要求\n", + "\n", + " 1.3 trainer 内部初始化 evaluater\n", + "\n", + " 2 使用 fastNLP 搭建 argmax 模型\n", + "\n", + " 2.1 trainer_step 和 evaluator_step\n", + "\n", + " 2.2 trainer 和 evaluator 的参数匹配\n", + "\n", + " 2.3 示例:argmax 模型的搭建\n", + "\n", + " 3 使用 fastNLP 训练 argmax 模型\n", + " \n", + " 3.1 trainer 外部初始化的 evaluator\n", + "\n", + " 3.2 trainer 内部初始化的 evaluator " + ] + }, + { + "cell_type": "markdown", + "id": "09ea669a", + "metadata": {}, + "source": [ + "## 1. trainer 和 evaluator 的基本关系\n", + "\n", + "### 1.1 trainer 和 evaluator 的初始化\n", + "\n", + "在`fastNLP 1.0`中,`Trainer`模块和`Evaluator`模块分别表示 **“训练器”和“评测器”**\n", + "\n", + " 对应于之前的`fastNLP`版本中的`Trainer`模块和`Tester`模块,其定义方法如下所示\n", + "\n", + "在`fastNLP 1.0`中,需要注意,在同个`python`脚本中先使用`Trainer`训练,然后使用`Evaluator`评测\n", + "\n", + " 非常关键的问题在于**如何正确设置二者的 driver**。这就引入了另一个问题:什么是 `driver`?\n", + "\n", + "\n", + "```python\n", + "trainer = Trainer(\n", + " model=model, # 模型基于 torch.nn.Module\n", + " train_dataloader=train_dataloader, # 加载模块基于 torch.utils.data.DataLoader \n", + " optimizers=optimizer, # 优化模块基于 torch.optim.*\n", + " ...\n", + " driver=\"torch\", # 使用 pytorch 模块进行训练 \n", + " device='cuda', # 使用 GPU:0 显卡执行训练\n", + " ...\n", + " )\n", + "...\n", + "evaluator = Evaluator(\n", + " model=model, # 模型基于 torch.nn.Module\n", + " dataloaders=evaluate_dataloader, # 加载模块基于 torch.utils.data.DataLoader\n", + " metrics={'acc': Accuracy()}, # 测评方法使用 fastNLP.core.metrics.Accuracy \n", + " ...\n", + " driver=trainer.driver, # 保持同 trainer 的 driver 一致\n", + " device=None,\n", + " ...\n", + " )\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "3c11fe1a", + "metadata": {}, + "source": [ + "### 1.2 driver 的含义与使用要求\n", + "\n", + "在`fastNLP 1.0`中,**driver**这一概念被用来表示**控制具体训练的各个步骤的最终执行部分**\n", + "\n", + " 例如神经网络前向、后向传播的具体执行、网络参数的优化和数据在设备间的迁移等\n", + "\n", + "在`fastNLP 
1.0`中,**Trainer 和 Evaluator 都依赖于具体的 driver 来完成整体的工作流程**\n", + "\n", + "  具体`driver`与`Trainer`以及`Evaluator`之间的关系将在之后的`tutorial 4`中详细介绍\n", + "\n", + "注:这里给出一条建议:**在同一脚本中**,**所有的** Trainer **和** Evaluator **使用的** driver **应当保持一致**\n", + "\n", + "  尽量不出现,之前使用单卡的`driver`,后面又使用多卡的`driver`,这是因为,当脚本执行至\n", + "\n", + "  多卡`driver`处时,会重启一个进程执行之前所有内容,如此一来可能会造成一些意想不到的麻烦" + ] + }, + { + "cell_type": "markdown", + "id": "2cac4a1a", + "metadata": {}, + "source": [ + "### 1.3 Trainer 内部初始化 Evaluator\n", + "\n", + "在`fastNLP 1.0`中,如果在**初始化 Trainer 时**,**传入参数 evaluate_dataloaders 和 metrics**\n", + "\n", + "  则在`Trainer`内部,也会初始化单独的`Evaluator`来帮助训练过程中对验证集的评测\n", + "\n", + "```python\n", + "trainer = Trainer(\n", + " model=model,\n", + " train_dataloader=train_dataloader,\n", + " optimizers=optimizer,\n", + " ...\n", + " driver=\"torch\",\n", + " device='cuda',\n", + " ...\n", + " evaluate_dataloaders=evaluate_dataloader, # 传入参数 evaluate_dataloaders\n", + " metrics={'acc': Accuracy()}, # 传入参数 metrics\n", + " ...\n", + " )\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "0c9c7dda", + "metadata": {}, + "source": [ + "## 2. 
argmax 模型的搭建实例" + ] + }, + { + "cell_type": "markdown", + "id": "524ac200", + "metadata": {}, + "source": [ + "### 2.1 train_step 和 evaluate_step\n", + "\n", + "在`fastNLP 1.0`中,使用`torch.nn.Module`搭建需要训练的模型,在搭建模型过程中,除了\n", + "\n", + "  添加`pytorch`要求的`forward`方法外,还需要添加 `train_step` 和 `evaluate_step` 这两个方法\n", + "\n", + "```python\n", + "class Model(torch.nn.Module):\n", + " def __init__(self):\n", + " super(Model, self).__init__()\n", + " self.loss_fn = torch.nn.CrossEntropyLoss()\n", + " pass\n", + "\n", + " def forward(self, x):\n", + " pass\n", + "\n", + " def train_step(self, x, y):\n", + " pred = self(x)\n", + " return {\"loss\": self.loss_fn(pred, y)}\n", + "\n", + " def evaluate_step(self, x, y):\n", + " pred = self(x)\n", + " pred = torch.max(pred, dim=-1)[1]\n", + " return {\"pred\": pred, \"target\": y}\n", + "```\n", + "***\n", + "在`fastNLP 1.0`中,**函数 train_step 是 Trainer 中参数 train_fn 的默认值**\n", + "\n", + "  由于,在`Trainer`训练时,**Trainer 通过参数 train_fn 对应的模型方法获得当前数据批次的损失值**\n", + "\n", + "  因此,在`Trainer`训练时,`Trainer`首先会寻找模型是否定义了`train_step`这一方法\n", + "\n", + "  如果没有找到,那么`Trainer`会默认使用模型的`forward`函数来进行训练的前向传播过程\n", + "\n", + "注:在`fastNLP 1.0`中,**Trainer 要求模型通过 train_step 来返回一个字典**,**满足如 {\"loss\": loss} 的形式**\n", + "\n", + "  此外,这里也可以通过传入`Trainer`的参数`output_mapping`来实现输出的转换,详见(trainer的详细讲解,待补充)\n", + "\n", + "同样,在`fastNLP 1.0`中,**函数 evaluate_step 是 Evaluator 中参数 evaluate_fn 的默认值**\n", + "\n", + "  在`Evaluator`测试时,**Evaluator 通过参数 evaluate_fn 对应的模型方法获得当前数据批次的评测结果**\n", + "\n", + "  从用户角度,模型通过`evaluate_step`方法来返回一个字典,内容与传入`Evaluator`的`metrics`一致\n", + "\n", + "  从模块角度,该字典的键值和`metric`中的`update`函数的签名一致,这样的机制在传参时被称为“**参数匹配**”\n", + "\n", + "" + ] + }, + { + "cell_type": "markdown", + "id": "fb3272eb", + "metadata": {}, + "source": [ + "### 2.2 trainer 和 evaluator 的参数匹配\n", + "\n", + "在`fastNLP 1.0`中,参数匹配涉及到两个方面,分别是在\n", + "\n", + "  一方面,**在模型的前向传播中**,**dataloader 向 train_step 或 evaluate_step 函数传递 batch**\n", + "\n", + "  另一方面,**在模型的评测过程中**,**evaluate_dataloader 向 metric 的 
update 函数传递 batch**\n", + "\n", + "对于前者,在`Trainer`和`Evaluator`中的参数`model_wo_auto_param_call`被设置为`False`时\n", + "\n", + "  **fastNLP 1.0 要求 dataloader 生成的每个 batch**,**满足如 {\"x\": x, \"y\": y} 的形式**\n", + "\n", + "  同时,`fastNLP 1.0`会查看模型的`train_step`和`evaluate_step`方法的参数签名,并为对应参数传入对应数值\n", + "\n", + "  **字典形式的定义**,**对应在 Dataset 定义的 \\_\\_getitem\\_\\_ 方法中**,例如下方的`ArgMaxDataset`\n", + "\n", + "  而在`Trainer`和`Evaluator`中的参数`model_wo_auto_param_call`被设置为`True`时\n", + "\n", + "  `fastNLP 1.0`会将`batch`直接传给模型的`train_step`、`evaluate_step`或`forward`函数\n", + "\n", + "```python\n", + "class Dataset(torch.utils.data.Dataset):\n", + " def __init__(self, x, y):\n", + " self.x = x\n", + " self.y = y\n", + "\n", + " def __len__(self):\n", + " return len(self.x)\n", + "\n", + " def __getitem__(self, item):\n", + " return {\"x\": self.x[item], \"y\": self.y[item]}\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "f5f1a6aa", + "metadata": {}, + "source": [ + "对于后者,首先要明确,在`Trainer`和`Evaluator`中,`metrics`的计算分为`update`和`get_metric`两步\n", + "\n", + "  **update 函数**,**针对一个 batch 的预测结果**,计算其累计的评价指标\n", + "\n", + "  **get_metric 函数**,**统计 update 函数累计的评价指标**,来计算最终的评价结果\n", + "\n", + "  例如对于`Accuracy`来说,`update`函数会更新一个`batch`中预测正确的数量`right_num`和样本总数`total_num`\n", + "\n", + "  而`get_metric`函数则会返回所有`batch`的评测值`right_num / total_num`\n", + "\n", + "  在此基础上,**fastNLP 1.0 要求 evaluate_dataloader 生成的每个 batch 传递给对应的 metric**\n", + "\n", + "  **以 {\"pred\": y_pred, \"target\": y_true} 的形式**,对应其`update`函数的函数签名\n", + "\n", + "" + ] + }, + { + "cell_type": "markdown", + "id": "f62b7bb1", + "metadata": {}, + "source": [ + "### 2.3 示例:argmax 模型的搭建\n", + "\n", + "下文将通过训练`argmax`模型,简单介绍`Trainer`模块的使用方式\n", + "\n", + "  首先,使用`torch.nn.Module`定义`argmax`模型,目标是输入一组固定维度的向量,输出其中数值最大的数的索引" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "5314482b", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + 
"\n", + "class ArgMaxModel(nn.Module):\n", + " def __init__(self, num_labels, feature_dimension):\n", + " nn.Module.__init__(self)\n", + " self.num_labels = num_labels\n", + "\n", + " self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10)\n", + " self.ac1 = nn.ReLU()\n", + " self.linear2 = nn.Linear(in_features=10, out_features=10)\n", + " self.ac2 = nn.ReLU()\n", + " self.output = nn.Linear(in_features=10, out_features=num_labels)\n", + " self.loss_fn = nn.CrossEntropyLoss()\n", + "\n", + " def forward(self, x):\n", + " pred = self.ac1(self.linear1(x))\n", + " pred = self.ac2(self.linear2(pred))\n", + " pred = self.output(pred)\n", + " return pred\n", + "\n", + " def train_step(self, x, y):\n", + " pred = self(x)\n", + " return {\"loss\": self.loss_fn(pred, y)}\n", + "\n", + " def evaluate_step(self, x, y):\n", + " pred = self(x)\n", + " pred = torch.max(pred, dim=-1)[1]\n", + " return {\"pred\": pred, \"target\": y}" + ] + }, + { + "cell_type": "markdown", + "id": "71f3fa6b", + "metadata": {}, + "source": [ + " 接着,使用`torch.utils.data.Dataset`定义`ArgMaxDataset`数据集\n", + "\n", + " 数据集包含三个参数:维度`feature_dimension`、数据量`data_num`和随机种子`seed`\n", + "\n", + " 数据及初始化是,自动生成指定维度的向量,并为每个向量标注出其中最大值的索引作为预测标签" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "fe612e61", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [], + "source": [ + "from torch.utils.data import Dataset\n", + "\n", + "class ArgMaxDataset(Dataset):\n", + " def __init__(self, feature_dimension, data_num=1000, seed=0):\n", + " self.num_labels = feature_dimension\n", + " self.feature_dimension = feature_dimension\n", + " self.data_num = data_num\n", + " self.seed = seed\n", + "\n", + " g = torch.Generator()\n", + " g.manual_seed(1000)\n", + " self.x = torch.randint(low=-100, high=100, size=[data_num, feature_dimension], generator=g).float()\n", + " self.y = torch.max(self.x, dim=-1)[1]\n", + "\n", + " def __len__(self):\n", + " return 
self.data_num\n", + "\n", + " def __getitem__(self, item):\n", + " return {\"x\": self.x[item], \"y\": self.y[item]}" + ] + }, + { + "cell_type": "markdown", + "id": "2cb96332", + "metadata": {}, + "source": [ + "  然后,根据`ArgMaxModel`类初始化模型实例,保持输入维度`feature_dimension`和输出标签数量`num_labels`一致\n", + "\n", + "  再根据`ArgMaxDataset`类初始化两个数据集实例,分别用于模型训练和模型评测,数据量分别为1000笔和100笔" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "76172ef8", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [], + "source": [ + "model = ArgMaxModel(num_labels=10, feature_dimension=10)\n", + "\n", + "train_dataset = ArgMaxDataset(feature_dimension=10, data_num=1000)\n", + "evaluate_dataset = ArgMaxDataset(feature_dimension=10, data_num=100)" + ] + }, + { + "cell_type": "markdown", + "id": "4e7d25ee", + "metadata": {}, + "source": [ + "  此外,使用`torch.utils.data.DataLoader`初始化两个数据加载模块,批量大小同为8,分别用于训练和测评" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "363b5b09", + "metadata": {}, + "outputs": [], + "source": [ + "from torch.utils.data import DataLoader\n", + "\n", + "train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)\n", + "evaluate_dataloader = DataLoader(evaluate_dataset, batch_size=8)" + ] + }, + { + "cell_type": "markdown", + "id": "c8d4443f", + "metadata": {}, + "source": [ + "  最后,使用`torch.optim.SGD`初始化一个优化模块,基于随机梯度下降法" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "dc28a2d9", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [], + "source": [ + "from torch.optim import SGD\n", + "\n", + "optimizer = SGD(model.parameters(), lr=0.001)" + ] + }, + { + "cell_type": "markdown", + "id": "eb8ca6cf", + "metadata": {}, + "source": [ + "## 3. 
使用 fastNLP 1.0 训练 argmax 模型\n", + "\n", + "### 3.1 trainer 外部初始化的 evaluator" + ] + }, + { + "cell_type": "markdown", + "id": "55145553", + "metadata": {}, + "source": [ + "通过从`fastNLP`库中导入`Trainer`类,初始化`trainer`实例,对模型进行训练\n", + "\n", + "  需要导入预先定义好的模型`model`、对应的数据加载模块`train_dataloader`、优化模块`optimizer`\n", + "\n", + "  通过`progress_bar`设定进度条格式,默认为`\"auto\"`,此外还有`\"rich\"`、`\"raw\"`和`None`\n", + "\n", + "    但对于`\"auto\"`和`\"rich\"`格式,在`jupyter`中,进度条在训练结束后会被丢弃\n", + "\n", + "  通过`n_epochs`设定优化迭代轮数,默认为20;`Trainer`的全部变量与函数可以通过`dir(trainer)`查询" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b51b7a2d", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import sys\n", + "sys.path.append('..')\n", + "\n", + "from fastNLP import Trainer\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver=\"torch\",\n", + " device='cuda',\n", + " train_dataloader=train_dataloader,\n", + " optimizers=optimizer,\n", + " n_epochs=10, # 设定迭代轮数 \n", + " progress_bar=\"auto\" # 设定进度条格式\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "6e202d6e", + "metadata": {}, + "source": [ + "通过使用`Trainer`类的`run`函数,进行训练\n", + "\n", + " 其中,可以通过参数`num_train_batch_per_epoch`决定每个`epoch`运行多少个`batch`后停止,默认全部\n", + "\n", + " `run`函数完成后在`jupyter`中没有输出保留,此外,通过`help(trainer.run)`可以查询`run`函数的详细内容" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "ba047ead", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run()" + ] + }, + { + "cell_type": "markdown", + "id": "c16c5fa4", + "metadata": {}, + "source": [ + "通过从`fastNLP`库中导入`Evaluator`类,初始化`evaluator`实例,对模型进行评测\n", + "\n", + "  需要导入预先定义好的模型`model`、对应的数据加载模块`evaluate_dataloader`\n", + "\n", + "  需要注意的是评测方法`metrics`,设定为形如`{'acc': fastNLP.core.metrics.Accuracy()}`的字典\n", + "\n", + "  类似地,也可以通过`progress_bar`限定进度条格式,默认为`\"auto\"`" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "1c6b6b36", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "from fastNLP import Evaluator\n", + "from fastNLP import Accuracy\n", + "\n", + "evaluator = Evaluator(\n", + " model=model,\n", + " driver=trainer.driver, # 需要使用 trainer 已经启动的 driver\n", + " device=None,\n", + " dataloaders=evaluate_dataloader,\n", + " metrics={'acc': Accuracy()} # 需要严格使用此种形式的字典\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "8157bb9b", + "metadata": {}, + "source": [ + "通过使用`Evaluator`类的`run`函数,进行评测\n", + "\n", + "  其中,可以通过参数`num_eval_batch_per_dl`决定每个`evaluate_dataloader`运行多少个`batch`停止,默认全部\n", + "\n", + "  最终,输出形如`{'acc#acc': acc}`的字典,在`jupyter`中,进度条在评测结束后会被丢弃" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "f7cb0165", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{'acc#acc': 0.31, 'total#acc': 100.0, 'correct#acc': 31.0}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.31\u001b[0m, \u001b[32m'total#acc'\u001b[0m: \u001b[1;36m100.0\u001b[0m, \u001b[32m'correct#acc'\u001b[0m: \u001b[1;36m31.0\u001b[0m\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'acc#acc': 0.31, 'total#acc': 100.0, 'correct#acc': 31.0}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "evaluator.run()" + ] + }, + { + "cell_type": "markdown", + "id": "dd9f68fa", + "metadata": {}, + "source": [ + "### 3.2 trainer 内部初始化的 evaluator \n", + "\n", + "通过在初始化`trainer`实例时加入`evaluate_dataloaders`和`metrics`,可以实现在训练过程中进行评测\n", + "\n", + " 通过`progress_bar`同时设定训练和评估进度条格式,在`jupyter`中,在进度条训练结束后会被丢弃\n", + "\n", + " 但是中间的评估结果仍会保留;**通过 evaluate_every 设定评估频率**,可以为负数、正数或者函数:\n", + "\n", + " **为负数时**,**表示每隔几个 epoch 评估一次**;**为正数时**,**则表示每隔几个 batch 评估一次**" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "183c7d19", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "trainer = Trainer(\n", + " model=model,\n", + " driver=trainer.driver, # 因为是在同个脚本中,这里的 driver 同样需要重用\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=evaluate_dataloader,\n", + " metrics={'acc': Accuracy()},\n", + " optimizers=optimizer,\n", + " n_epochs=10, \n", + " evaluate_every=-1, # 表示每个 epoch 的结束进行评估\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "714cc404", + "metadata": {}, + "source": [ + "通过使用`Trainer`类的`run`函数,进行训练\n", + "\n", + " 还可以通过**参数 num_eval_sanity_batch 决定每次训练前运行多少个 evaluate_batch 进行评测**,**默认为 2 **\n", + "\n", + " 之所以“先评测后训练”,是为了保证训练很长时间的数据,不会在评测阶段出问题,故作此**试探性评测**" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "2e4daa2c", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + 
"outputs": [ + { + "data": { + "text/html": [ + "
[18:28:25] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[18:28:25]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=549287;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=645362;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.31,\n", + " \"total#acc\": 100.0,\n", + " \"correct#acc\": 31.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.31\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m31.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.33,\n", + " \"total#acc\": 100.0,\n", + " \"correct#acc\": 33.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.33\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m33.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.34,\n", + " \"total#acc\": 100.0,\n", + " \"correct#acc\": 34.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.34\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m34.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.36,\n", + " \"total#acc\": 100.0,\n", + " \"correct#acc\": 36.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.36,\n", + " \"total#acc\": 100.0,\n", + " \"correct#acc\": 36.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.36,\n", + " \"total#acc\": 100.0,\n", + " \"correct#acc\": 36.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.36,\n", + " \"total#acc\": 100.0,\n", + " \"correct#acc\": 36.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.36,\n", + " \"total#acc\": 100.0,\n", + " \"correct#acc\": 36.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.37,\n", + " \"total#acc\": 100.0,\n", + " \"correct#acc\": 37.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.37\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m37.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.4,\n", + " \"total#acc\": 100.0,\n", + " \"correct#acc\": 40.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.4\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m40.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "c4e9c619", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'acc#acc': 0.4, 'total#acc': 100.0, 'correct#acc': 40.0}" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.evaluator.run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1bc7cb4a", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/tutorials/fastnlp_tutorial_1.ipynb b/docs/source/tutorials/fastnlp_tutorial_1.ipynb new file mode 100644 index 00000000..cff81a21 --- /dev/null +++ b/docs/source/tutorials/fastnlp_tutorial_1.ipynb @@ -0,0 +1,1333 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "cdc25fcd", + "metadata": {}, + "source": [ + "# T1. 
dataset 和 vocabulary 的基本使用\n", + "\n", + " 1 dataset 的使用与结构\n", + " \n", + " 1.1 dataset 的结构与创建\n", + "\n", + " 1.2 dataset 的数据预处理\n", + "\n", + " 1.3 延伸:instance 和 field\n", + "\n", + " 2 vocabulary 的结构与使用\n", + "\n", + " 2.1 vocabulary 的创建与修改\n", + "\n", + " 2.2 vocabulary 与 OOV 问题\n", + "\n", + " 3 dataset 和 vocabulary 的组合使用\n", + " \n", + " 3.1 从 dataframe 中加载 dataset\n", + "\n", + " 3.2 从 dataset 中获取 vocabulary" + ] + }, + { + "cell_type": "markdown", + "id": "0eb18a22", + "metadata": {}, + "source": [ + "## 1. dataset 的基本使用\n", + "\n", + "### 1.1 dataset 的结构与创建\n", + "\n", + "在`fastNLP 1.0`中,使用`DataSet`模块表示数据集,**dataset 类似于关系型数据库中的数据表**(下文统一为小写 `dataset`)\n", + "\n", + " **主要包含 field 字段和 instance 实例两个元素**,对应 table 中的 field 字段和`record`记录\n", + "\n", + "在`fastNLP 1.0`中,`DataSet`模块被定义在`fastNLP.core.dataset`路径下,导入该模块后,最简单的\n", + "\n", + " 初始化方法,即将字典形式的表格 **{'field1': column1, 'field2': column2, ...}** 传入构造函数" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "a1d69ad2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 1 | I like apples . | ['I', 'like', 'appl... | 4 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n", + "+-----+------------------------+------------------------+-----+\n" + ] + } + ], + "source": [ + "from fastNLP import DataSet\n", + "\n", + "data = {'idx': [0, 1, 2], \n", + " 'sentence':[\"This is an apple .\", \"I like apples .\", \"Apples are good for our health .\"],\n", + " 'words': [['This', 'is', 'an', 'apple', '.'], \n", + " ['I', 'like', 'apples', '.'], \n", + " ['Apples', 'are', 'good', 'for', 'our', 'health', '.']],\n", + " 'num': [5, 4, 7]}\n", + "\n", + "dataset = DataSet(data)\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "9260fdc6", + "metadata": {}, + "source": [ + " 在`dataset`的实例中,字段`field`的名称和实例`instance`中的字符串也可以中文" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "3d72ef00", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+------+--------------------+------------------------+------+\n", + "| 序号 | 句子 | 字符 | 长度 |\n", + "+------+--------------------+------------------------+------+\n", + "| 0 | 生活就像海洋, | ['生', '活', '就', ... | 7 |\n", + "| 1 | 只有意志坚强的人, | ['只', '有', '意', ... | 9 |\n", + "| 2 | 才能到达彼岸。 | ['才', '能', '到', ... 
| 7 |\n", + "+------+--------------------+------------------------+------+\n" + ] + } + ], + "source": [ + "temp = {'序号': [0, 1, 2], \n", + " '句子':[\"生活就像海洋,\", \"只有意志坚强的人,\", \"才能到达彼岸。\"],\n", + " '字符': [['生', '活', '就', '像', '海', '洋', ','], \n", + " ['只', '有', '意', '志', '坚', '强', '的', '人', ','], \n", + " ['才', '能', '到', '达', '彼', '岸', '。']],\n", + " '长度': [7, 9, 7]}\n", + "\n", + "chinese = DataSet(temp)\n", + "print(chinese)" + ] + }, + { + "cell_type": "markdown", + "id": "202e5490", + "metadata": {}, + "source": [ + "在`dataset`中,使用`drop`方法可以删除满足条件的实例,这里使用了python中的`lambda`表达式\n", + "\n", + " 注一:在`drop`方法中,通过设置`inplace`参数将删除对应实例后的`dataset`作为一个新的实例生成" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "09b478f8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2492313174344 2491986424200\n", + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n", + "+-----+------------------------+------------------------+-----+\n", + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 1 | I like apples . | ['I', 'like', 'appl... | 4 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... 
| 7 |\n", + "+-----+------------------------+------------------------+-----+\n" + ] + } + ], + "source": [ + "dropped = dataset\n", + "dropped = dropped.drop(lambda ins:ins['num'] < 5, inplace=False)\n", + "print(id(dropped), id(dataset))\n", + "print(dropped)\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "aa277674", + "metadata": {}, + "source": [ + " 注二:**对对象使用等号一般表示传引用**,所以对`dataset`使用等号,是传引用而不是赋值\n", + "\n", + " 如下所示,**dropped 和 dataset 具有相同 id**,**对 dropped 执行删除操作 dataset 同时会被修改**" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "77c8583a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2491986424200 2491986424200\n", + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n", + "+-----+------------------------+------------------------+-----+\n", + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... 
| 7 |\n", + "+-----+------------------------+------------------------+-----+\n" + ] + } + ], + "source": [ + "dropped = dataset\n", + "dropped.drop(lambda ins:ins['num'] < 5)\n", + "print(id(dropped), id(dataset))\n", + "print(dropped)\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "a76199dc", + "metadata": {}, + "source": [ + "在`dataset`中,使用`delete_instance`方法可以删除对应序号的`instance`实例,序号从0开始" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "d8824b40", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+--------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+--------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 1 | I like apples . | ['I', 'like', 'appl... | 4 |\n", + "+-----+--------------------+------------------------+-----+\n" + ] + } + ], + "source": [ + "dataset = DataSet(data)\n", + "dataset.delete_instance(2)\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "f4fa9f33", + "metadata": {}, + "source": [ + "在`dataset`中,使用`delete_field`方法可以删除对应名称的`field`字段" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "f68ddb40", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+--------------------+------------------------------+\n", + "| idx | sentence | words |\n", + "+-----+--------------------+------------------------------+\n", + "| 0 | This is an apple . | ['This', 'is', 'an', 'app... |\n", + "| 1 | I like apples . | ['I', 'like', 'apples', '... 
|\n", + "+-----+--------------------+------------------------------+\n" + ] + } + ], + "source": [ + "dataset.delete_field('num')\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "b1e9d42c", + "metadata": {}, + "source": [ + "### 1.2 dataset 的数据预处理\n", + "\n", + "在`dataset`模块中,`apply`、`apply_field`、`apply_more`和`apply_field_more`函数可以进行简单的数据预处理\n", + "\n", + " **apply 和 apply_more 输入整条实例**,**apply_field 和 apply_field_more 仅输入实例的部分字段**\n", + "\n", + " **apply 和 apply_field 仅输出单个字段**,**apply_more 和 apply_field_more 则是输出多个字段**\n", + "\n", + " **apply 和 apply_field 返回的是个列表**,**apply_more 和 apply_field_more 返回的是个字典**\n", + "\n", + " 预处理过程中,通过`progress_bar`参数设置显示进度条类型,通过`num_proc`设置多进程\n", + "***\n", + "\n", + "`apply`的参数包括一个函数`func`和一个新字段名`new_field_name`,函数`func`的处理对象是`dataset`模块中\n", + "\n", + " 的每个`instance`实例,函数`func`的处理结果存放在`new_field_name`对应的新建字段内" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "72a0b5f9", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/3 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+------------------------------+------------------------------+\n", + "| idx | sentence | words |\n", + "+-----+------------------------------+------------------------------+\n", + "| 0 | This is an apple . | ['This', 'is', 'an', 'app... |\n", + "| 1 | I like apples . | ['I', 'like', 'apples', '... |\n", + "| 2 | Apples are good for our h... | ['Apples', 'are', 'good',... 
|\n", + "+-----+------------------------------+------------------------------+\n" + ] + } + ], + "source": [ + "from fastNLP import DataSet\n", + "\n", + "data = {'idx': [0, 1, 2], \n", + " 'sentence':[\"This is an apple .\", \"I like apples .\", \"Apples are good for our health .\"], }\n", + "dataset = DataSet(data)\n", + "dataset.apply(lambda ins: ins['sentence'].split(), new_field_name='words', progress_bar=\"tqdm\") #\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "c10275ee", + "metadata": {}, + "source": [ + " **apply 使用的函数可以是一个基于 lambda 表达式的匿名函数**,**也可以是一个自定义的函数**" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "b1a8631f", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/3 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+------------------------------+------------------------------+\n", + "| idx | sentence | words |\n", + "+-----+------------------------------+------------------------------+\n", + "| 0 | This is an apple . | ['This', 'is', 'an', 'app... |\n", + "| 1 | I like apples . | ['I', 'like', 'apples', '... |\n", + "| 2 | Apples are good for our h... | ['Apples', 'are', 'good',... 
|\n", + "+-----+------------------------------+------------------------------+\n" + ] + } + ], + "source": [ + "dataset = DataSet(data)\n", + "\n", + "def get_words(instance):\n", + " sentence = instance['sentence']\n", + " words = sentence.split()\n", + " return words\n", + "\n", + "dataset.apply(get_words, new_field_name='words', progress_bar=\"tqdm\")\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "64abf745", + "metadata": {}, + "source": [ + "`apply_field`的参数,除了函数`func`外还有`field_name`和`new_field_name`,该函数`func`的处理对象仅\n", + "\n", + " 是`dataset`模块中的每个`field_name`对应的字段内容,处理结果存放在`new_field_name`对应的新建字段内" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "057c1d2c", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/3 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+------------------------------+------------------------------+\n", + "| idx | sentence | words |\n", + "+-----+------------------------------+------------------------------+\n", + "| 0 | This is an apple . | ['This', 'is', 'an', 'app... |\n", + "| 1 | I like apples . | ['I', 'like', 'apples', '... |\n", + "| 2 | Apples are good for our h... | ['Apples', 'are', 'good',... 
|\n", + "+-----+------------------------------+------------------------------+\n" + ] + } + ], + "source": [ + "dataset = DataSet(data)\n", + "dataset.apply_field(lambda sent:sent.split(), field_name='sentence', new_field_name='words', \n", + " progress_bar=\"tqdm\")\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "5a9cc8b2", + "metadata": {}, + "source": [ + "`apply_more`的参数只有函数`func`,函数`func`的处理对象是`dataset`模块中的每个`instance`实例\n", + "\n", + " 要求函数`func`返回一个字典,根据字典的`key-value`确定存储在`dataset`中的字段名称与内容" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "51e2f02c", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/3 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 1 | I like apples . | ['I', 'like', 'appl... | 4 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... 
| 7 |\n", + "+-----+------------------------+------------------------+-----+\n" + ] + } + ], + "source": [ + "dataset = DataSet(data)\n", + "dataset.apply_more(lambda ins:{'words': ins['sentence'].split(), 'num': len(ins['sentence'].split())}, \n", + " progress_bar=\"tqdm\")\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "02d2b7ef", + "metadata": {}, + "source": [ + "`apply_field_more`的参数,除了函数`func`外还有`field_name`,函数`func`的处理对象仅是`dataset`模块中每个`field_name`对应的字段内容\n", + "\n", + "  要求函数`func`返回一个字典,根据字典的`key-value`确定存储在`dataset`中的字段名称与内容" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "db4295d5", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/3 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 1 | I like apples . | ['I', 'like', 'appl... | 4 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... 
| 7 |\n", + "+-----+------------------------+------------------------+-----+\n" + ] + } + ], + "source": [ + "dataset = DataSet(data)\n", + "dataset.apply_field_more(lambda sent:{'words': sent.split(), 'num': len(sent.split())}, \n", + " field_name='sentence', progress_bar=\"tqdm\")\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "9c09e592", + "metadata": {}, + "source": [ + "### 1.3 延伸:instance 和 field\n", + "\n", + "在`fastNLP 1.0`中,使用`Instance`模块表示数据集`dataset`中的每条数据,被称为实例\n", + "\n", + "  构造方式类似于构造一个字典,通过键值相同的`Instance`列表,也可以初始化一个`dataset`,代码如下" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "012f537c", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import DataSet\n", + "from fastNLP import Instance\n", + "\n", + "dataset = DataSet([\n", + " Instance(sentence=\"This is an apple .\",\n", + " words=['This', 'is', 'an', 'apple', '.'],\n", + " num=5),\n", + " Instance(sentence=\"I like apples .\",\n", + " words=['I', 'like', 'apples', '.'],\n", + " num=4),\n", + " Instance(sentence=\"Apples are good for our health .\",\n", + " words=['Apples', 'are', 'good', 'for', 'our', 'health', '.'],\n", + " num=7),\n", + " ])" + ] + }, + { + "cell_type": "markdown", + "id": "2fafb1ef", + "metadata": {}, + "source": [ + "  通过`items`、`keys`和`values`方法,可以分别获得`Instance`实例的`item`列表、`key`列表、`value`列表" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "a4c1c10d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dict_items([('sentence', 'This is an apple .'), ('words', ['This', 'is', 'an', 'apple', '.']), ('num', 5)])\n", + "dict_keys(['sentence', 'words', 'num'])\n", + "dict_values(['This is an apple .', ['This', 'is', 'an', 'apple', '.'], 5])\n" + ] + } + ], + "source": [ + "ins = Instance(sentence=\"This is an apple .\", words=['This', 'is', 'an', 'apple', '.'], num=5)\n", + "\n", + "print(ins.items())\n", + "print(ins.keys())\n", + "print(ins.values())" + ] 
+ }, + { + "cell_type": "markdown", + "id": "b5459a2d", + "metadata": {}, + "source": [ + " 通过`add_field`方法,可以在`Instance`实例中,通过参数`field_name`添加字段,通过参数`field`赋值" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "55376402", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+--------------------+------------------------+-----+-----+\n", + "| sentence | words | num | idx |\n", + "+--------------------+------------------------+-----+-----+\n", + "| This is an apple . | ['This', 'is', 'an'... | 5 | 0 |\n", + "+--------------------+------------------------+-----+-----+\n" + ] + } + ], + "source": [ + "ins.add_field(field_name='idx', field=0)\n", + "print(ins)" + ] + }, + { + "cell_type": "markdown", + "id": "49caaa9c", + "metadata": {}, + "source": [ + "在`fastNLP 1.0`中,使用`FieldArray`模块表示数据集`dataset`中的每条字段名(注:没有`field`类)\n", + "\n", + " 通过`get_all_fields`方法可以获取`dataset`的字段列表\n", + "\n", + " 通过`get_field_names`方法可以获取`dataset`的字段名称列表,代码如下" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "fe15f4c1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'sentence':
\n", + " | SentenceId | \n", + "Sentence | \n", + "Sentiment | \n", + "
---|---|---|---|
0 | \n", + "1 | \n", + "A series of escapades demonstrating the adage ... | \n", + "negative | \n", + "
1 | \n", + "2 | \n", + "This quiet , introspective and entertaining in... | \n", + "positive | \n", + "
2 | \n", + "3 | \n", + "Even fans of Ismail Merchant 's work , I suspe... | \n", + "negative | \n", + "
3 | \n", + "4 | \n", + "A positively thrilling combination of ethnogra... | \n", + "neutral | \n", + "
4 | \n", + "5 | \n", + "A comedy-drama of nearly epic proportions root... | \n", + "positive | \n", + "
5 | \n", + "6 | \n", + "The Importance of Being Earnest , so thick wit... | \n", + "neutral | \n", + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/6 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+------------------------------------------+----------+\n", + "| text | label |\n", + "+------------------------------------------+----------+\n", + "| ['a', 'series', 'of', 'escapades', 'd... | negative |\n", + "| ['this', 'quiet', ',', 'introspective... | positive |\n", + "| ['even', 'fans', 'of', 'ismail', 'mer... | negative |\n", + "| ['the', 'importance', 'of', 'being', ... | neutral |\n", + "+------------------------------------------+----------+\n", + "+------------------------------------------+----------+\n", + "| text | label |\n", + "+------------------------------------------+----------+\n", + "| ['a', 'comedy-drama', 'of', 'nearly',... | positive |\n", + "| ['a', 'positively', 'thrilling', 'com... | neutral |\n", + "+------------------------------------------+----------+\n", + "{'
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/4 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/2 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/2 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+------------+----------------+-----------+----------------+--------------------+--------------------+--------+\n", + "| SentenceId | Sentence | Sentiment | input_ids | token_type_ids | attention_mask | target |\n", + "+------------+----------------+-----------+----------------+--------------------+--------------------+--------+\n", + "| 1 | A series of... | negative | [101, 1037,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 1 |\n", + "| 4 | A positivel... | neutral | [101, 1037,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 2 |\n", + "| 3 | Even fans o... | negative | [101, 2130,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 1 |\n", + "| 5 | A comedy-dr... | positive | [101, 1037,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... 
| 0 |\n", + "+------------+----------------+-----------+----------------+--------------------+--------------------+--------+\n" + ] + } + ], + "source": [ + "import sys\n", + "sys.path.append('..')\n", + "\n", + "import pandas as pd\n", + "from functools import partial\n", + "from fastNLP.transformers.torch import BertTokenizer\n", + "\n", + "from fastNLP import DataSet\n", + "from fastNLP import Vocabulary\n", + "from fastNLP.io import DataBundle\n", + "\n", + "\n", + "class PipeDemo:\n", + " def __init__(self, tokenizer='bert-base-uncased'):\n", + " self.tokenizer = BertTokenizer.from_pretrained(tokenizer)\n", + "\n", + " def process_from_file(self, path='./data/test4dataset.tsv'):\n", + " datasets = DataSet.from_pandas(pd.read_csv(path, sep='\\t'))\n", + " train_ds, test_ds = datasets.split(ratio=0.7)\n", + " train_ds, dev_ds = datasets.split(ratio=0.8)\n", + " data_bundle = DataBundle(datasets={'train': train_ds, 'dev': dev_ds, 'test': test_ds})\n", + "\n", + " encode = partial(self.tokenizer.encode_plus, max_length=100, truncation=True,\n", + " return_attention_mask=True)\n", + " data_bundle.apply_field_more(encode, field_name='Sentence', progress_bar='tqdm')\n", + " \n", + " target_vocab = Vocabulary(padding=None, unknown=None)\n", + "\n", + " target_vocab.from_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment')\n", + " target_vocab.index_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment',\n", + " new_field_name='target')\n", + "\n", + " data_bundle.set_pad('input_ids', pad_val=self.tokenizer.pad_token_id)\n", + " data_bundle.set_ignore('SentenceId', 'Sentence', 'Sentiment') \n", + " return data_bundle\n", + "\n", + " \n", + "pipe = PipeDemo(tokenizer='bert-base-uncased')\n", + "\n", + "data_bundle = pipe.process_from_file('./data/test4dataset.tsv')\n", + "\n", + "print(data_bundle.get_dataset('train'))" + ] + }, + { + "cell_type": "markdown", + "id": "76e6b8ab", + "metadata": {}, + "source": [ + "### 
1.2 dataloader 的函数创建\n", + "\n", + "在`fastNLP 1.0`中,**更方便、可能更常用的 dataloader 创建方法是通过 prepare_xx_dataloader 函数**\n", + "\n", + "  例如下方的`prepare_torch_dataloader`函数,指定必要参数,读取数据集,生成对应`dataloader`\n", + "\n", + "  类型为`TorchDataLoader`,只能适用于`pytorch`框架,因此对应`trainer`初始化时`driver='torch'`\n", + "\n", + "同时我们还可以发现,在`fastNLP 1.0`中,**batch 表示为字典 dict 类型**,**key 值就是原先数据集中各个字段**\n", + "\n", + "  **除去经过 DataBundle.set_ignore 函数隐去的部分**,而`value`值为`pytorch`框架对应的`torch.Tensor`类型" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "5fd60e42", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/6000 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import sys\n", + "sys.path.append('..')\n", + "\n", + "from fastNLP import DataSet\n", + "\n", + "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n", + "\n", + "dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'target': ins['label']}, \n", + " progress_bar=\"tqdm\")\n", + "dataset.delete_field('sentence')\n", + "dataset.delete_field('label')\n", + "dataset.delete_field('idx')\n", + "\n", + "from fastNLP import Vocabulary\n", + "\n", + "vocab = Vocabulary()\n", + "vocab.from_dataset(dataset, field_name='words')\n", + "vocab.index_dataset(dataset, field_name='words')\n", + "\n", + "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)" + ] + }, + { + "cell_type": "markdown", + "id": "96380c67", + "metadata": {}, + "source": [ + " 然后,使用`tutorial-3`中的知识,**通过 prepare_torch_dataloader 处理数据集得到 dataloader**" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b9dd1273", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import prepare_torch_dataloader\n", + "\n", + "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n", + "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)" + ] + }, + { + "cell_type": "markdown", + "id": "eb75aaba", + "metadata": {}, + "source": [ + "模型使用方面,这里使用`Embedding`、`LSTM`、`MLP`等模块搭建模型,方法类似`pytorch`,结构如下所示\n", + "\n", + "```\n", + "ClsByModules(\n", + " (embedding): Embedding(\n", + " (embed): Embedding(8458, 100)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (lstm): LSTM(\n", + " (lstm): LSTM(100, 64, 
num_layers=2, batch_first=True, bidirectional=True)\n", + " )\n", + " (mlp): MLP(\n", + " (hiddens): ModuleList()\n", + " (output): Linear(in_features=128, out_features=2, bias=True)\n", + " (dropout): Dropout(p=0.5, inplace=False)\n", + " )\n", + " (loss_fn): CrossEntropyLoss()\n", + ")\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "0b25b25c", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "\n", + "from fastNLP.modules.torch import LSTM, MLP\n", + "from fastNLP.embeddings.torch import Embedding\n", + "\n", + "\n", + "class ClsByModules(nn.Module):\n", + " def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n", + " nn.Module.__init__(self)\n", + "\n", + " self.embedding = Embedding((vocab_size, embedding_dim))\n", + " self.lstm = LSTM(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=True)\n", + " self.mlp = MLP([hidden_dim * 2, output_dim], dropout=dropout)\n", + " \n", + " self.loss_fn = nn.CrossEntropyLoss()\n", + "\n", + " def forward(self, words):\n", + " output = self.embedding(words)\n", + " output, (hidden, cell) = self.lstm(output)\n", + " output = self.mlp(torch.cat((hidden[-1], hidden[-2]), dim=1))\n", + " return output\n", + " \n", + " def train_step(self, words, target):\n", + " pred = self(words)\n", + " return {\"loss\": self.loss_fn(pred, target)}\n", + "\n", + " def evaluate_step(self, words, target):\n", + " pred = self(words)\n", + " pred = torch.max(pred, dim=-1)[1]\n", + " return {\"pred\": pred, \"target\": target}" + ] + }, + { + "cell_type": "markdown", + "id": "4890de5a", + "metadata": {}, + "source": [ + " 接着,初始化模型`model`实例,同时,使用`torch.optim.AdamW`初始化`optimizer`实例" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "9dbbf50d", + "metadata": {}, + "outputs": [], + "source": [ + "model = ClsByModules(vocab_size=len(vocab), embedding_dim=100, output_dim=2)\n", + "\n", + "from 
torch.optim import AdamW\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=5e-5)" + ] + }, + { + "cell_type": "markdown", + "id": "054538f5", + "metadata": {}, + "source": [ + " 最后,使用`trainer`模块,集成`model`、`optimizer`、`dataloader`、`metric`训练" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "7a93432f", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import Trainer, Accuracy\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver='torch',\n", + " device=0, # 'cuda'\n", + " n_epochs=10,\n", + " optimizers=optimizers,\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=evaluate_dataloader,\n", + " metrics={'acc': Accuracy()}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "31102e0f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[16:20:10] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[16:20:10]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=908530;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=864197;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n", + "\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n", + "\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.525,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 84.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.525\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m84.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.54375,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 87.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.54375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m87.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.55,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 88.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.55\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m88.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.625,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 100.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.65,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 104.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.65\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m104.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.69375,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 111.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.69375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m111.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.675,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 108.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.675\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m108.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.66875,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 107.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.66875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m107.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.675,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 108.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.675\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m108.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.68125,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 109.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.68125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m109.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "8bc4bfb2", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'acc#acc': 0.712222, 'total#acc': 900.0, 'correct#acc': 641.0}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.evaluator.run()" + ] + }, + { + "cell_type": "markdown", + "id": "07538876", + "metadata": {}, + "source": [ + " 注:此处使用`gc`模块删除相关变量,释放内存,为接下来新的模型训练预留存储空间,下同" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "1b52eafd", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "383" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import gc\n", + "\n", + "del model\n", + "del trainer\n", + "del dataset\n", + "del sst2data\n", + "\n", + "gc.collect()" + ] + }, + { + "cell_type": "markdown", + "id": "d9443213", + "metadata": {}, + "source": [ + "## 2. 
fastNLP 中 models 模块的介绍\n", + "\n", + "### 2.1 示例一:models 实现 CNN 分类\n", + "\n", + " 本示例使用`fastNLP 1.0`中预定义模型`models`中的`CNNText`模型,实现`SST-2`文本二分类任务\n", + "\n", + "数据使用方面,此处沿用在上个示例中展示的`SST-2`数据集,数据加载过程相同且已经执行过了,因此简略\n", + "\n", + "模型使用方面,如上所述,这里使用**基于卷积神经网络 CNN 的预定义文本分类模型 CNNText**,结构如下所示\n", + "\n", + " 首先是内置的`100`维嵌入层、`dropout`层、紧接着是三个一维卷积,将`100`维嵌入特征,分别通过\n", + "\n", + " **感受野为 1 、 3 、 5 的卷积算子变换至 30 维、 40 维、 50 维的卷积特征**,再将三者拼接\n", + "\n", + " 最终再次通过`dropout`层、线性变换层,映射至二元的输出值,对应两个分类结果上的几率`logits`\n", + "\n", + "```\n", + "CNNText(\n", + " (embed): Embedding(\n", + " (embed): Embedding(5194, 100)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (conv_pool): ConvMaxpool(\n", + " (convs): ModuleList(\n", + " (0): Conv1d(100, 30, kernel_size=(1,), stride=(1,), bias=False)\n", + " (1): Conv1d(100, 40, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n", + " (2): Conv1d(100, 50, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", + " )\n", + " )\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (fc): Linear(in_features=120, out_features=2, bias=True)\n", + ")\n", + "```\n", + "\n", + "对应到代码上,**从 fastNLP.models.torch 路径下导入 CNNText**,初始化`CNNText`和`optimizer`实例\n", + "\n", + " 注意:初始化`CNNText`时,**二元组参数 embed 、分类数量 num_classes 是必须传入的**,其中\n", + "\n", + " **embed 表示嵌入层的嵌入抽取矩阵大小**,因此第二个元素对应的是默认隐藏层维度 `100` 维" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "f6e76e2e", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP.models.torch import CNNText\n", + "\n", + "model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n", + "\n", + "from torch.optim import AdamW\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=5e-4)" + ] + }, + { + "cell_type": "markdown", + "id": "0cc5ca10", + "metadata": {}, + "source": [ + " 最后,使用`trainer`模块,集成`model`、`optimizer`、`dataloader`、`metric`训练" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "50a13ee5", + "metadata": {}, + 
"outputs": [], + "source": [ + "from fastNLP import Trainer, Accuracy\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver='torch',\n", + " device=0, # 'cuda'\n", + " n_epochs=10,\n", + " optimizers=optimizers,\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=evaluate_dataloader,\n", + " metrics={'acc': Accuracy()}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "28903a7d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[16:21:57] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[16:21:57]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=813103;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=271516;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.654444,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 589.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.654444\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m589.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.767778,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 691.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.767778\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m691.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.797778,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 718.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.797778\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m718.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.803333,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 723.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.803333\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m723.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.807778,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 727.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.807778\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m727.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.812222,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 731.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.812222\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m731.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.804444,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 724.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.804444\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m724.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.811111,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 730.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.811111\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m730.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.811111,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 730.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.811111\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m730.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.806667,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 726.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.806667\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m726.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "f47a6a35", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'acc#acc': 0.806667, 'total#acc': 900.0, 'correct#acc': 726.0}" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.evaluator.run()" + ] + }, + { + "cell_type": "markdown", + "id": "5b5c0446", + "metadata": {}, + "source": [ + "  注:此处使用`gc`模块删除相关变量,释放内存,为接下来新的模型训练预留存储空间,下同" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "e9e70f88", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "344" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import gc\n", + "\n", + "del model\n", + "del trainer\n", + "\n", + "gc.collect()" + ] + }, + { + "cell_type": "markdown", + "id": "6aec2a19", + "metadata": {}, + "source": [ + "### 2.2 示例二:models 实现 BiLSTM 标注\n", + "\n", + "  通过两个示例一的对比可以发现,得益于`models`对模型结构的封装,使用`models`明显更加便捷\n", + "\n", + "  针对更加复杂的模型时,编码更加轻松;本示例将使用`models`中的`BiLSTMCRF`模型\n", + "\n", + "  避免`CRF`和`Viterbi`算法代码书写的困难,轻松实现`CoNLL-2003`中的命名实体识别`NER`任务\n", + "\n", + "模型使用方面,如上所述,这里使用**基于双向 LSTM +条件随机场 CRF 的标注模型 BiLSTMCRF**,结构如下所示\n", + "\n", + "  其中,隐藏层维度默认`100`维,因此对应双向`LSTM`输出`200`维,`dropout`层丢弃概率、`LSTM`层数可调\n", + "\n", + "```\n", + "BiLSTMCRF(\n", + " (embed): Embedding(7590, 100)\n", + " (lstm): LSTM(\n", + " (lstm): LSTM(100, 100, batch_first=True, bidirectional=True)\n", 
+ " )\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (fc): Linear(in_features=200, out_features=9, bias=True)\n", + " (crf): ConditionalRandomField()\n", + ")\n", + "```\n", + "\n", + "数据使用方面,此处仍然**使用 datasets 模块中的 load_dataset 函数**,以如下形式,加载`CoNLL-2003`数据集\n", + "\n", + "  首次下载后会保存至`~/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/`目录下" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "03e66686", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset conll2003 (/remote-home/xrliu/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/63f4ebd1bcb7148b1644497336fd74643d4ce70123334431a3c053b7ee4e96ee)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "593bc03ed5914953ab94268ff2f01710", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from datasets import load_dataset\n", + "\n", + "ner2data = load_dataset('conll2003', 'conll2003')" + ] + }, + { + "cell_type": "markdown", + "id": "fc505631", + "metadata": {}, + "source": [ + "紧接着,使用`tutorial-1`和`tutorial-2`中的知识,将数据集转化为`fastNLP`中的`DataSet`格式\n", + "\n", + "  完成数据集格式调整、文本序列化等操作;此处**需要 'words' 、 'seq_len' 、 'target' 三个字段**\n", + "\n", + "此外,**需要定义 NER 标签到标签序号的映射**(**词汇表 label_vocab**),数据集中标签已经完成了序号映射\n", + "\n", + "  所以需要人工定义**9 个标签对应之前的 9 个分类目标**;数据集说明中规定,`'O'`表示其他标签\n", + "\n", + "    **后缀 '-PER' 、 '-ORG' 、 '-LOC' 、 '-MISC' 对应人名、组织名、地名、时间等其他命名**\n", + "\n", + "    **前缀 'B-' 表示起始标签、 'I-' 表示内部(延续)标签**;例如,`'B-PER'`表示人名实体的起始标签" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "1f88cad4", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/4000 [00:00, ?it/s]" + ] + }, + "metadata": 
{}, + "output_type": 
"display_data" + } + ], + "source": [ + "import sys\n", + "sys.path.append('..')\n", + "\n", + "from fastNLP import DataSet\n", + "\n", + "dataset = DataSet.from_pandas(ner2data['train'].to_pandas())[:4000]\n", + "\n", + "dataset.apply_more(lambda ins:{'words': ins['tokens'], 'seq_len': len(ins['tokens']), 'target': ins['ner_tags']}, \n", + " progress_bar=\"tqdm\")\n", + "dataset.delete_field('tokens')\n", + "dataset.delete_field('ner_tags')\n", + "dataset.delete_field('pos_tags')\n", + "dataset.delete_field('chunk_tags')\n", + "dataset.delete_field('id')\n", + "\n", + "from fastNLP import Vocabulary\n", + "\n", + "token_vocab = Vocabulary()\n", + "token_vocab.from_dataset(dataset, field_name='words')\n", + "token_vocab.index_dataset(dataset, field_name='words')\n", + "label_vocab = Vocabulary(padding=None, unknown=None)\n", + "label_vocab.add_word_lst(['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'])\n", + "\n", + "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)" + ] + }, + { + "cell_type": "markdown", + "id": "d9889427", + "metadata": {}, + "source": [ + "然后,同样使用`tutorial-3`中的知识,通过`prepare_torch_dataloader`处理数据集得到`dataloader`" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "7802a072", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import prepare_torch_dataloader\n", + "\n", + "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n", + "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)" + ] + }, + { + "cell_type": "markdown", + "id": "2bc7831b", + "metadata": {}, + "source": [ + "接着,**从 fastNLP.models.torch 路径下导入 BiLSTMCRF**,初始化`BiLSTMCRF`实例和优化器\n", + "\n", + "  注意:初始化`BiLSTMCRF`时,和`CNNText`相同,**参数 embed 、 num_classes 是必须传入的**\n", + "\n", + "    隐藏层维度`hidden_size`默认`100`维,调整`150`维;丢弃概率默认`0.1`,调整`0.2`" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "4e12c09f", + "metadata": {}, + "outputs": [], + 
"source": [ + "from fastNLP.models.torch import BiLSTMCRF\n", + "\n", + "model = BiLSTMCRF(embed=(len(token_vocab), 150), num_classes=len(label_vocab), \n", + " num_layers=1, hidden_size=150, dropout=0.2)\n", + "\n", + "from torch.optim import AdamW\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=1e-3)" + ] + }, + { + "cell_type": "markdown", + "id": "bf30608f", + "metadata": {}, + "source": [ + "最后,使用`trainer`模块,集成`model`、`optimizer`、`dataloader`、`metric`训练\n", + "\n", + " **使用 SpanFPreRecMetric 作为 NER 的评价标准**,详细请参考接下来的`tutorial-5`\n", + "\n", + " 同时,**初始化时需要添加 vocabulary 形式的标签与序号之间的映射 tag_vocab**" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "cbd6c205", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import Trainer, SpanFPreRecMetric\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver='torch',\n", + " device=0, # 'cuda'\n", + " n_epochs=10,\n", + " optimizers=optimizers,\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=evaluate_dataloader,\n", + " metrics={'F1': SpanFPreRecMetric(tag_vocab=label_vocab)}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "0f8eff34", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[16:23:41] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[16:23:41]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=565652;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=224849;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.169014,\n", + " \"pre#F1\": 0.170732,\n", + " \"rec#F1\": 0.167331\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.169014\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.170732\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.167331\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.361809,\n", + " \"pre#F1\": 0.312139,\n", + " \"rec#F1\": 0.430279\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.361809\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.312139\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.430279\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.525,\n", + " \"pre#F1\": 0.475728,\n", + " \"rec#F1\": 0.585657\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.525\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.475728\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.585657\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.627306,\n", + " \"pre#F1\": 0.584192,\n", + " \"rec#F1\": 0.677291\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.627306\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.584192\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.677291\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.710937,\n", + " \"pre#F1\": 0.697318,\n", + " \"rec#F1\": 0.7251\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.710937\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.697318\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.7251\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.739563,\n", + " \"pre#F1\": 0.738095,\n", + " \"rec#F1\": 0.741036\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.739563\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.738095\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.741036\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.748491,\n", + " \"pre#F1\": 0.756098,\n", + " \"rec#F1\": 0.741036\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.748491\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.756098\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.741036\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.716763,\n", + " \"pre#F1\": 0.69403,\n", + " \"rec#F1\": 0.741036\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.716763\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.69403\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.741036\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.768293,\n", + " \"pre#F1\": 0.784232,\n", + " \"rec#F1\": 0.752988\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.768293\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.784232\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.752988\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.757692,\n", + " \"pre#F1\": 0.732342,\n", + " \"rec#F1\": 0.784861\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.757692\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.732342\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.784861\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "37871d6b", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'f#F1': 0.766798, 'pre#F1': 0.741874, 'rec#F1': 0.793456}" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.evaluator.run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "96bae094", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/tutorials/fastnlp_tutorial_5.ipynb b/docs/source/tutorials/fastnlp_tutorial_5.ipynb new file mode 100644 index 00000000..ab759feb --- /dev/null +++ b/docs/source/tutorials/fastnlp_tutorial_5.ipynb @@ -0,0 +1,1242 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fdd7ff16", + "metadata": {}, + "source": [ + "# T5. 
trainer 和 evaluator 的深入介绍\n", + "\n", + " 1 fastNLP 中 driver 的补充介绍\n", + " \n", + " 1.1 trainer 和 driver 的构想 \n", + "\n", + " 1.2 device 与 多卡训练\n", + "\n", + " 2 fastNLP 中的更多 metric 类型\n", + "\n", + " 2.1 预定义的 metric 类型\n", + "\n", + " 2.2 自定义的 metric 类型\n", + "\n", + " 3 fastNLP 中 trainer 的补充介绍\n", + "\n", + " 3.1 trainer 的内部结构" + ] + }, + { + "cell_type": "markdown", + "id": "08752c5a", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "## 1. fastNLP 中 driver 的补充介绍\n", + "\n", + "### 1.1 trainer 和 driver 的构想\n", + "\n", + "在`fastNLP 1.0`中,模型训练最关键的模块便是**训练模块 trainer 、评测模块 evaluator 、驱动模块 driver**,\n", + "\n", + " 在`tutorial 0`中,已经简单介绍过上述三个模块:**driver 用来控制训练评测中的 model 的最终运行**\n", + "\n", + " **evaluator 封装评测的 metric**,**trainer 封装训练的 optimizer**,**也可以包括 evaluator**\n", + "\n", + "之所以做出上述的划分,其根本目的在于要**达成对于多个 python 学习框架**,**例如 pytorch 、 paddle 、 jittor 的兼容**\n", + "\n", + " 对于训练环节,其伪代码如下方左边紫色一栏所示,由于**不同框架对模型、损失、张量的定义各有不同**,所以将训练环节\n", + "\n", + " 划分为**框架无关的循环控制、批量分发部分**,**由 trainer 模块负责**实现,对应的伪代码如下方中间一栏所示\n", + "\n", + " 以及**随框架不同的模型调用、数值优化部分**,**由 driver 模块负责**实现,对应的伪代码如下方右边一栏所示\n", + "\n", + "|训练过程|框架无关 对应`Trainer`|框架相关 对应`Driver`\n", + "|----|----|----|\n", + "| try: | try: | |\n", + "| for epoch in 1:n_epochs: | for epoch in 1:n_epochs: | |\n", + "| for step in 1:total_steps: | for step in 1:total_steps: | |\n", + "| batch = fetch_batch() | batch = fetch_batch() | |\n", + "| loss = model.forward(batch) | | loss = model.forward(batch) |\n", + "| loss.backward() | | loss.backward() |\n", + "| model.clear_grad() | | model.clear_grad() |\n", + "| model.update() | | model.update() |\n", + "| if need_save: | if need_save: | |\n", + "| model.save() | | model.save() |\n", + "| except: | except: | |\n", + "| process_exception() | process_exception() | |" + ] + }, + { + "cell_type": "markdown", + "id": "3e55f07b", + "metadata": {}, + "source": [ + " 对于评测环节,其伪代码如下方左边紫色一栏所示,同样由于不同框架对模型、损失、张量的定义各有不同,所以将评测环节\n", + "\n", + " 
划分为**框架无关的循环控制、分发汇总部分**,**由 evaluator 模块负责**实现,对应的伪代码如下方中间一栏所示\n", + "\n", + " 以及**随框架不同的模型调用、评测计算部分**,同样**由 driver 模块负责**实现,对应的伪代码如下方右边一栏所示\n", + "\n", + "|评测过程|框架无关 对应`Evaluator`|框架相关 对应`Driver`\n", + "|----|----|----|\n", + "| try: | try: | |\n", + "| model.set_eval() | model.set_eval() | |\n", + "| for step in 1:total_steps: | for step in 1:total_steps: | |\n", + "| batch = fetch_batch() | batch = fetch_batch() | |\n", + "| outputs = model.evaluate(batch) | | outputs = model.evaluate(batch) |\n", + "| metric.compute(batch, outputs) | | metric.compute(batch, outputs) |\n", + "| results = metric.get_metric() | results = metric.get_metric() | |\n", + "| except: | except: | |\n", + "| process_exception() | process_exception() | |" + ] + }, + { + "cell_type": "markdown", + "id": "94ba11c6", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "source": [ + "由此,从程序员的角度,`fastNLP v1.0` **通过一个 driver 让基于 pytorch 、 paddle 、 jittor 、 oneflow 框架的模型**\n", + "\n", + " **都能在相同的 trainer 和 evaluator 上运行**,这也**是 fastNLP v1.0 相比于之前版本的一大亮点**\n", + "\n", + " 而从`driver`的角度,`fastNLP v1.0`通过定义一个`driver`基类,**将所有张量转化为 numpy.tensor**\n", + "\n", + " 并由此泛化出`torch_driver`、`paddle_driver`、`jittor_driver`三个子类,从而实现了\n", + "\n", + " 对`pytorch`、`paddle`、`jittor`的兼容,有关后两者的实践请参考接下来的`tutorial-6`" + ] + }, + { + "cell_type": "markdown", + "id": "ab1cea7d", + "metadata": {}, + "source": [ + "### 1.2 device 与 多卡训练\n", + "\n", + "**fastNLP v1.0 支持多卡训练**,实现方法则是**通过将 trainer 中的 device 设置为对应显卡的序号列表**\n", + "\n", + " 由单卡切换成多卡,无论是数据、模型还是评测都会面临一定的调整,`fastNLP v1.0`保证:\n", + "\n", + " 数据拆分时,不同卡之间相互协调,所有数据都可以被训练,且不会使用到相同的数据\n", + "\n", + " 模型训练时,模型之间需要交换梯度;评测计算时,每张卡先各自计算,再汇总结果\n", + "\n", + " 例如,在评测计算运行`get_metric`函数时,`fastNLP v1.0`将自动按照`self.right`和`self.total`\n", + "\n", + " 指定的 **aggregate_method 方法**,默认为`sum`,将每张卡上结果汇总起来,因此最终\n", + "\n", + " 在调用`get_metric`方法时,`Accuracy`类能够返回全部的统计结果,代码如下\n", + " \n", + "```python\n", + "trainer = Trainer(\n", + " model=model, # model 基于 pytorch 实现 \n", + " 
train_dataloader=train_dataloader,\n", + " optimizers=optimizer,\n", + " ...\n", + " driver='torch', # driver 使用 torch_driver \n", + " device=[0, 1], # gpu 选择 cuda:0 + cuda:1\n", + " ...\n", + " evaluate_dataloaders=evaluate_dataloader,\n", + " metrics={'acc': Accuracy()},\n", + " ...\n", + " )\n", + "\n", + "class Accuracy(Metric):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.register_element(name='total', value=0, aggregate_method='sum')\n", + " self.register_element(name='right', value=0, aggregate_method='sum')\n", + "```\n" + ] + }, + { + "cell_type": "markdown", + "id": "e2e0a210", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "source": [ + "注:`fastNLP v1.0`中要求`jupyter`不能多卡,仅能单卡,故在所有`tutorial`中均不作相关演示" + ] + }, + { + "cell_type": "markdown", + "id": "8d19220c", + "metadata": {}, + "source": [ + "## 2. fastNLP 中的更多 metric 类型\n", + "\n", + "### 2.1 预定义的 metric 类型\n", + "\n", + "在`fastNLP 1.0`中,除了前几篇`tutorial`中经常见到的**正确率 Accuracy**,还有其他**预定义的评测标准 metric**\n", + "\n", + " 包括**所有 metric 的基类 Metric**、适配`Transformers`中相关模型的正确率`TransformersAccuracy`\n", + "\n", + " **适用于分类语境下的 F1 值 ClassifyFPreRecMetric**(其中也包括精确率`Pre`、召回率`Rec`)\n", + "\n", + " **适用于抽取语境下的 F1 值 SpanFPreRecMetric**;相关基本信息内容见下表,之后是详细分析\n", + "\n", + "代码名称|简要介绍|代码路径\n", + "----|----|----|\n", + " `Metric` | 定义`metrics`时继承的基类 | `/core/metrics/metric.py` |\n", + " `Accuracy` | 正确率,最为常用 | `/core/metrics/accuracy.py` |\n", + " `TransformersAccuracy` | 正确率,为了兼容`Transformers`中相关模型 | `/core/metrics/accuracy.py` |\n", + " `ClassifyFPreRecMetric` | 召回率、精确率、F1值,适用于**分类问题** | `/core/metrics/classify_f1_pre_rec_metric.py` |\n", + " `SpanFPreRecMetric` | 召回率、精确率、F1值,适用于**抽取问题** | `/core/metrics/span_f1_pre_rec_metric.py` |" + ] + }, + { + "cell_type": "markdown", + "id": "fdc083a3", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "source": [ + " 如`tutorial-0`中所述,所有的`metric`都包含`get_metric`和`update`函数,其中\n", + "\n", + " **update 函数更新单个 batch 的统计量**,**get_metric 
函数返回最终结果**,并打印显示\n", + "\n", + "\n", + "### 2.1.1 Accuracy 与 TransformersAccuracy\n", + "\n", + "`Accuracy`,正确率,预测正确的数据`right_num`在总数据`total_num`中的占比(公式就不用列了)\n", + "\n", + " `get_metric`函数打印格式为 **{\"acc#xx\": float, 'total#xx': float, 'correct#xx': float}**\n", + "\n", + " 一般在初始化时不需要传参,`fastNLP`会根据`update`函数的传入参数确定对应后台框架`backend`\n", + "\n", + " **update 函数的参数包括 pred 、 target 、 seq_len**,**后者用来标记批次中每笔数据的长度**\n", + "\n", + "`TransformersAccuracy`,继承自`Accuracy`,只是为了兼容`Transformers`框架中相关模型\n", + "\n", + " 在`update`函数中,将`Transformers`框架输出的`attention_mask`参数转化为`seq_len`参数\n", + "\n", + "\n", + "### 2.1.2 ClassifyFPreRecMetric 与 SpanFPreRecMetric\n", + "\n", + "`ClassifyFPreRecMetric`,分类评价,`SpanFPreRecMetric`,抽取评价,后者在`tutorial-4`中已出现\n", + "\n", + " 两者的相同之处在于:**第一**,**都包括召回率/查全率 Rec**、**精确率/查准率 Pre**、**F1 值**这三个指标\n", + "\n", + " `get_metric`函数打印格式为 **{\"f#xx\": float, 'pre#xx': float, 'rec#xx': float}**\n", + "\n", + " 三者的计算公式如下,其中`beta`默认为`1`,即`F1`值是召回率`Rec`和精确率`Pre`的调和平均数\n", + "\n", + "$$\\text{召回率}\\ Rec=\\dfrac{\\text{正确预测为正例的数量}}{\\text{所有本来是正例的数量}}\\qquad \\text{精确率}\\ Pre=\\dfrac{\\text{正确预测为正例的数量}}{\\text{所有预测为正例的数量}}$$\n", + "\n", + "$$F_{beta} = \\frac{(1 + {beta}^{2})*(Pre*Rec)}{({beta}^{2}*Pre + Rec)}$$\n", + "\n", + " **第二**,可以通过参数`only_gross`为`False`,要求返回所有类别的`Rec-Pre-F1`,同时`F1`值根据参数`f_type`又分为\n", + "\n", + " **micro F1**(**直接统计所有类别的 Rec-Pre-F1**)、**macro F1**(**统计各类别的 Rec-Pre-F1 再算术平均**)\n", + "\n", + " **第三**,两者在初始化时还可以**传入基于 fastNLP.Vocabulary 的 tag_vocab 参数记录数据集中的标签序号**\n", + "\n", + " **与标签名称之间的映射**,通过字符串列表`ignore_labels`参数,指定若干标签不用于`Rec-Pre-F1`的计算\n", + "\n", + "两者的不同之处在于:`ClassifyFPreRecMetric`针对简单的分类问题,每个分类标签之间彼此独立,不构成标签对\n", + "\n", + " **SpanFPreRecMetric 针对更复杂的抽取问题**,**规定标签 B-xx 和 I-xx 或 B-xx 和 E-xx 构成标签对**\n", + "\n", + " 在计算`Rec-Pre-F1`时,`ClassifyFPreRecMetric`只需要考虑标签本身是否正确这就足够了,但是\n", + "\n", + " 对于`SpanFPreRecMetric`,需要保证**标签符合规则且覆盖的区间与正确结果重合才算正确**\n", + "\n", + " 因此回到`tutorial-4`中`CoNLL-2003`的`NER`任务,如果评测方法选择`ClassifyFPreRecMetric`\n", + 
"\n", + " 或者`Accuracy`,会发现评测结果虽然显示很高,但这只是因为选择的评测方法要求太低\n", + "\n", + " 最后通过`CoNLL-2003`的词性标注`POS`任务简单演示下`ClassifyFPreRecMetric`相关的使用\n", + "\n", + "```python\n", + "from fastNLP import Vocabulary\n", + "from fastNLP import ClassifyFPreRecMetric\n", + "\n", + "tag_vocab = Vocabulary(padding=None, unknown=None) # 记录序号与标签之间的映射\n", + "tag_vocab.add_word_lst(['\"', \"''\", '#', '$', '(', ')', ',', '.', ':', '``', \n", + " 'CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', \n", + " 'MD', 'NN', 'NNP', 'NNPS', 'NNS', 'NN|SYM', 'PDT', 'POS', 'PRP', 'PRP$', \n", + " 'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', \n", + " 'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP+', 'WRB', ]) # CoNLL-2003 中的 pos_tags\n", + "ignore_labels = ['\"', \"''\", '#', '$', '(', ')', ',', '.', ':', '``', ]\n", + "\n", + "FPreRec = ClassifyFPreRecMetric(tag_vocab=tag_vocab, \n", + " ignore_labels=ignore_labels, # 表示评测/优化中不考虑上述标签的正误/损失\n", + " only_gross=True, # 默认为 True 表示输出所有类别的综合统计结果\n", + " f_type='micro') # 默认为 'micro' 表示统计所有类别的 Rec-Pre-F1\n", + "metrics = {'F1': FPreRec}\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "8a22f522", + "metadata": {}, + "source": [ + "### 2.2 自定义的 metric 类型\n", + "\n", + "如上文所述,`Metric`作为所有`metric`的基类,`Accuracy`等都是其子类,同样地,对于**自定义的 metric 类型**\n", + "\n", + " 也**需要继承自 Metric 类**,同时**内部自定义好 __init__ 、 update 和 get_metric 函数**\n", + "\n", + " 在`__init__`函数中,根据需求定义评测时需要用到的变量,此处沿用`Accuracy`中的`total_num`和`right_num`\n", + "\n", + " 在`update`函数中,根据需求定义评测变量的更新方式,需要注意的是如`tutorial-0`中所述,**update 的参数名**\n", + "\n", + " **需要与待评估模型在 evaluate_step 中的输出名称一致**,由此**和数据集中对应字段名称一致**,即**参数匹配**\n", + "\n", + " 在`fastNLP v1.0`中,`update`函数的默认输入参数:`pred`,对应预测值;`target`,对应真实值\n", + "\n", + " 此处仍然沿用,因为接下来会需要使用`fastNLP`中的预定义模型,其输入参数格式即是如此\n", + "\n", + " 在`get_metric`函数中,根据需求定义评测指标最终的计算,此处直接计算准确率,该函数必须返回一个字典\n", + "\n", + " 其中,字串`'prefix'`表示该`metric`的名称,会对应显示到`trainer`的`progress bar`中\n", + "\n", + "根据上述要求,这里简单定义了一个名为`MyMetric`的评测模块,用于分类问题的评测,以此展开一个实例展示" + ] + 
}, + { + "cell_type": "code", + "execution_count": 1, + "id": "08a872e9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import sys\n", + "sys.path.append('..')\n", + "\n", + "from fastNLP import Metric\n", + "\n", + "class MyMetric(Metric):\n", + "\n", + " def __init__(self):\n", + " Metric.__init__(self)\n", + " self.total_num = 0\n", + " self.right_num = 0\n", + "\n", + " def update(self, pred, target):\n", + " self.total_num += target.size(0)\n", + " self.right_num += target.eq(pred).sum().item()\n", + "\n", + " def get_metric(self, reset=True):\n", + " acc = self.right_num / self.total_num\n", + " if reset:\n", + " self.total_num = 0\n", + " self.right_num = 0\n", + " return {'prefix': acc}" + ] + }, + { + "cell_type": "markdown", + "id": "0155f447", + "metadata": {}, + "source": [ + " 数据使用方面,此处仍然使用`datasets`模块中的`load_dataset`函数,加载`SST-2`二分类数据集" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "5ad81ac7", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ef923b90b19847f4916cccda5d33fc36", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from datasets import load_dataset\n", + "\n", + "sst2data = load_dataset('glue', 'sst2')" + ] + }, + { + "cell_type": "markdown", + "id": "e9d81760", + "metadata": {}, + "source": [ + " 在数据预处理中,需要注意的是,这里原本应该根据`metric`和`model`的输入参数格式,调整\n", + "\n", + " 数据集中表示预测目标的字段,调整为`target`,在后文中会揭晓为什么,以及如何补救" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "cfb28b1b", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { 
+ "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/6000 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from fastNLP import DataSet\n", + "\n", + "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n", + "\n", + "dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split()}, progress_bar=\"tqdm\")\n", + "dataset.delete_field('sentence')\n", + "dataset.delete_field('idx')\n", + "\n", + "from fastNLP import Vocabulary\n", + "\n", + "vocab = Vocabulary()\n", + "vocab.from_dataset(dataset, field_name='words')\n", + "vocab.index_dataset(dataset, field_name='words')\n", + "\n", + "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)\n", + "\n", + "from fastNLP import prepare_torch_dataloader\n", + "\n", + "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n", + "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)" + ] + }, + { + "cell_type": "markdown", + "id": "af3f8c63", + "metadata": {}, + "source": [ + " 模型使用方面,此处仍然使用`tutorial-4`中介绍过的预定义`CNNText`模型,实现`SST-2`二分类" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "2fd210c5", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP.models.torch import CNNText\n", + "\n", + "model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n", + "\n", + "from torch.optim import AdamW\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=5e-4)" + ] + }, + { + "cell_type": "markdown", + "id": "6e723b87", + "metadata": {}, + "source": [ + "## 3. 
fastNLP 中 trainer 的补充介绍\n", + "\n", + "### 3.1 trainer 的内部结构\n", + "\n", + "在`tutorial-0`中,我们已经介绍了`trainer`的基本使用,从`tutorial-1`到`tutorial-4`,我们也已经展示了\n", + "\n", + " 很多`trainer`的使用案例,这里通过表格,相对完整地介绍`trainer`模块的属性和初始化参数(标粗为必选参数\n", + "\n", + "\n", + "名称|参数|属性|功能|内容\n", + "----|----|----|----|----|\n", + "| **model** | √ | √ | 指定`trainer`控制的模型 | 视框架而定,如`torch.nn.Module` |\n", + "| `device` | √ | | 指定`trainer`运行的卡位 | 例如`'cpu'`、`'cuda'`、`0`、`[0, 1]`等 |\n", + "| | | √ | 记录`trainer`运行的卡位 | `Device`类型,在初始化阶段生成 |\n", + "| **driver** | √ | | 指定`trainer`驱动的框架 | 包括`'torch'`、`'paddle'`、`'jittor'` |\n", + "| | | √ | 记录`trainer`驱动的框架 | `Driver`类型,在初始化阶段生成 |\n", + "| `n_epochs` | √ | - | 指定`trainer`迭代的轮数 | 默认`20`,记录在`driver.n_epochs`中 |\n", + "| **optimizers** | √ | √ | 指定`trainer`优化的方法 | 视框架而定,如`torch.optim.Adam` |\n", + "| `metrics` | √ | √ | 指定`trainer`评测的方法 | 字典类型,如`{'acc': Metric()}` |\n", + "| `evaluator` | | √ | 内置的`trainer`评测模块 | `Evaluator`类型,在初始化阶段生成 |\n", + "| `input_mapping` | √ | √ | 调整`dataloader`的参数不匹配 | 函数类型,输出字典匹配`forward`输入参数 |\n", + "| `output_mapping` | √ | √ | 调整`forward`输出的参数不匹配 | 函数类型,输出字典匹配`xx_step`输入参数 |\n", + "| **train_dataloader** | √ | √ | 指定`trainer`训练的数据 | `DataLoader`类型,生成视框架而定 |\n", + "| `evaluate_dataloaders` | √ | √ | 指定`trainer`评测的数据 | `DataLoader`类型,生成视框架而定 |\n", + "| `train_fn` | √ | √ | 指定`trainer`获取某个批次的损失值 | 函数类型,默认为`model.train_step` |\n", + "| `evaluate_fn` | √ | √ | 指定`trainer`获取某个批次的评估量 | 函数类型,默认为`model.evaluate_step` |\n", + "| `batch_step_fn` | √ | √ | 指定`trainer`训练时前向传输一个批次的方式 | 函数类型,默认为`TrainBatchLoop.batch_step_fn` |\n", + "| `evaluate_batch_step_fn` | √ | √ | 指定`trainer`评测时前向传输一个批次的方式 | 函数类型,默认为`EvaluateBatchLoop.batch_step_fn` |\n", + "| `accumulation_steps` | √ | √ | 指定`trainer`训练时反向传播的频率 | 默认为`1`,即每个批次都反向传播 |\n", + "| `evaluate_every` | √ | √ | 指定`evaluator`评测时计算的频率 | 默认`-1`表示每个循环一次,相反`1`表示每个批次一次 |\n", + "| `progress_bar` | √ | √ | 指定`trainer`训练和评测时的进度条样式 | 包括`'auto'`、`'tqdm'`、`'raw'`、`'rich'` |\n", + "| `callbacks` | √ | | 
指定`trainer`训练时需要触发的函数 | `Callback`列表类型,详见`tutorial-7` |\n", + "| `callback_manager` | | √ | 记录与管理`callbacks`相关内容 | `CallbackManager`类型,详见`tutorial-7` |\n", + "| `monitor` | √ | √ | 辅助部分的`callbacks`相关内容 | 字符串/函数类型,详见`tutorial-7` |\n", + "| `marker` | √ | √ | 标记`trainer`实例,辅助`callbacks`相关内容 | 字符串型,详见`tutorial-7` |\n", + "| `trainer_state` | | √ | 记录`trainer`状态,辅助`callbacks`相关内容 | `TrainerState`类型,详见`tutorial-7` |\n", + "| `state` | | √ | 记录`trainer`状态,辅助`callbacks`相关内容 | `State`类型,详见`tutorial-7` |\n", + "| `fp16` | √ | √ | 指定`trainer`是否进行混合精度训练 | 布尔类型,默认`False` |\n", + "\n", + "其中,**input_mapping 和 output_mapping** 定义形式如下:输入字典形式的数据,根据参数匹配要求调整数据格式,这里就回应了前文未在数据集预处理时调整格式的问题,**总之参数匹配一定要求**" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "de96c1d1", + "metadata": {}, + "outputs": [], + "source": [ + "def input_mapping(data):\n", + " data['target'] = data['label']\n", + " return data" + ] + }, + { + "cell_type": "markdown", + "id": "2fc8b9f3", + "metadata": {}, + "source": [ + " 而`trainer`模块的基础方法列表如下,相关进阶操作,如`on`系列函数、`callback`控制,请参考后续的`tutorial-7`\n", + "\n", + "|名称|功能|主要参数|\n", + "|----|----|----|\n", + "| `run` | 控制`trainer`中模型的训练和评测 | 详见后文 |\n", + "| `train_step` | 实现`trainer`训练中一个批数据的前向传播过程 | 输入`batch` |\n", + "| `backward` | 实现`trainer`训练中一次损失的反向传播过程 | 输入`output` |\n", + "| `zero_grad` | 实现`trainer`训练中`optimizers`的梯度置零 | 无输入 |\n", + "| `step` | 实现`trainer`训练中`optimizers`的参数更新 | 无输入 |\n", + "| `epoch_evaluate` | 实现`trainer`训练中每个循环的评测,实际是否执行取决于评测频率 | 无输入 |\n", + "| `step_evaluate` | 实现`trainer`训练中每个批次的评测,实际是否执行取决于评测频率 | 无输入 |\n", + "| `save_model` | 保存`trainer`中的模型参数/状态字典至`fastnlp_model.pkl.tar` | `folder`指明路径,`only_state_dict`指明是否只保存状态字典,默认`False` |\n", + "| `load_model` | 加载`trainer`中的模型参数/状态字典自`fastnlp_model.pkl.tar` | `folder`指明路径,`only_state_dict`指明是否只加载状态字典,默认`True` |\n", + "| `save_checkpoint` | 保存`trainer`中模型参数/状态字典 以及 `callback`、`sampler` 和`optimizer`的状态至`fastnlp_model/checkpoint.pkl.tar` | 
`folder`指明路径,`only_state_dict`指明是否只保存状态字典,默认`True` |\n", + "| `load_checkpoint` | 加载`trainer`中模型参数/状态字典 以及 `callback`、`sampler` 和`optimizer`的状态自`fastnlp_model/checkpoint.pkl.tar` | `folder`指明路径,`only_state_dict`指明是否只加载状态字典,默认`True`;`resume_training`指明是否精确恢复到上次训练的批量,默认`True` |\n", + "| `add_callback_fn` | 在`trainer`初始化后添加`callback`函数 | 输入`event`指明回调时机,`fn`指明回调函数 |\n", + "| `on` | 函数修饰器,将一个函数转变为`callback`函数 | 详见`tutorial-7` |\n", + "\n", + "" + ] + }, + { + "cell_type": "markdown", + "id": "1e21df35", + "metadata": {}, + "source": [ + "紧接着,初始化`trainer`实例,继续完成`SST-2`分类,其中`metrics`输入的键值对,字串`'suffix'`和之前定义的\n", + "\n", + " 字串`'prefix'`将拼接在一起显示到`progress bar`中,故完整的输出形式为`{'prefix#suffix': float}`" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "926a9c50", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import Trainer\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver='torch',\n", + " device=0, # 'cuda'\n", + " n_epochs=10,\n", + " optimizers=optimizers,\n", + " input_mapping=input_mapping,\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=evaluate_dataloader,\n", + " metrics={'suffix': MyMetric()}\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "b1b2e8b7", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "source": [ + "最后就是`run`函数的使用,关于其参数,这里也以表格形式列出,由此就解答了`num_eval_batch_per_dl=10`的含义\n", + "\n", + "|名称|功能|默认值|\n", + "|----|----|----|\n", + "| `num_train_batch_per_epoch` | 指定`trainer`训练时,每个循环计算批量数目 | 整数类型,默认`-1`,表示训练时,每个循环计算所有批量 |\n", + "| `num_eval_batch_per_dl` | 指定`trainer`评测时,每个循环计算批量数目 | 整数类型,默认`-1`,表示评测时,每个循环计算所有批量 |\n", + "| `num_eval_sanity_batch` | 指定`trainer`训练开始前,试探性评测批量数目 | 整数类型,默认`2`,表示训练开始前评估两个批量 |\n", + "| `resume_from` | 指定`trainer`恢复状态的路径,需要是文件夹 | 字符串型,默认`None`,使用可参考`CheckpointCallback` |\n", + "| `resume_training` | 指定`trainer`恢复状态的程度 | 布尔类型,默认`True`恢复所有状态,`False`仅恢复`model`和`optimizers`状态 |" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + 
"id": "43be274f", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
[09:30:35] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[09:30:35]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=954293;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=366534;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n", + "\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n", + "\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.6875\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.6875\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.8125\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.80625\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.825\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.825\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.8125\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.80625\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.80625\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.8\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.80625\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.80625\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1abfa0a", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/tutorials/fastnlp_tutorial_6.ipynb b/docs/source/tutorials/fastnlp_tutorial_6.ipynb new file mode 100644 index 00000000..63f7481e --- /dev/null +++ b/docs/source/tutorials/fastnlp_tutorial_6.ipynb @@ -0,0 +1,1646 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fdd7ff16", + "metadata": {}, + "source": [ + "# T6. 
fastNLP 与 paddle 或 jittor 的结合\n", + "\n", + " 1 fastNLP 结合 paddle 训练模型\n", + " \n", + " 1.1 关于 paddle 的简单介绍\n", + "\n", + " 1.2 使用 paddle 搭建并训练模型\n", + "\n", + " 2 fastNLP 结合 jittor 训练模型\n", + "\n", + " 2.1 关于 jittor 的简单介绍\n", + "\n", + " 2.2 使用 jittor 搭建并训练模型\n", + "\n", + "" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "08752c5a", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6b13d42c39ba455eb370bf2caaa3a264", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from datasets import load_dataset\n", + "\n", + "sst2data = load_dataset('glue', 'sst2')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "7e8cc210", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[38;5;2m[i 0604 21:01:38.510813 72 log.cc:351] Load log_sync: 1\u001b[m\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/6000 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "
[21:03:08] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[21:03:08]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=894986;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=567751;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n", + "\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n", + "\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/paddle/tensor/creation.py:\n", + "125: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To \n", + "silence this warning, use `object` by itself. Doing this will not modify any behavior and is \n", + "safe. \n", + "Deprecated in NumPy 1.20; for more details and guidance: \n", + "https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", + " if data.dtype == np.object:\n", + "\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/paddle/tensor/creation.py:\n", + "125: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To \n", + "silence this warning, use `object` by itself. Doing this will not modify any behavior and is \n", + "safe. \n", + "Deprecated in NumPy 1.20; for more details and guidance: \n", + "https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", + " if data.dtype == np.object:\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.78125,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 125.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.78125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m125.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.7875,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 126.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m126.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.8,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 128.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.79375,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 127.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.79375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m127.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.81875,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 131.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.8,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 128.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.80625,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 129.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m129.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.79375,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 127.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.79375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m127.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.7875,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 126.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m126.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.8,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 128.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10) " + ] + }, + { + "cell_type": "markdown", + "id": "cb9a0b3c", + "metadata": {}, + "source": [ + "## 2. fastNLP 结合 jittor 训练模型" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "c600191d", + "metadata": {}, + "outputs": [], + "source": [ + "import jittor\n", + "import jittor.nn as nn\n", + "\n", + "from jittor import Module\n", + "\n", + "\n", + "class ClsByJittor(Module):\n", + " def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n", + " Module.__init__(self)\n", + " self.hidden_dim = hidden_dim\n", + "\n", + " self.embedding = nn.Embedding(num=vocab_size, dim=embedding_dim)\n", + " self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True, # 默认 batch_first=False\n", + " num_layers=num_layers, bidirectional=True, dropout=dropout)\n", + " self.mlp = nn.Sequential([nn.Dropout(p=dropout),\n", + " nn.Linear(hidden_dim * 2, hidden_dim * 2),\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_dim * 2, output_dim),\n", + " nn.Sigmoid(),])\n", + "\n", + " self.loss_fn = nn.MSELoss()\n", + "\n", + " def execute(self, words):\n", + " output = self.embedding(words)\n", + " output, (hidden, cell) = self.lstm(output)\n", + " output = self.mlp(jittor.concat((hidden[-1], hidden[-2]), dim=1))\n", + " return output\n", + " \n", + " def train_step(self, words, target):\n", + " pred = self(words)\n", + " target = jittor.stack((1 - target, target), dim=1)\n", + " return {'loss': self.loss_fn(pred, target)}\n", + "\n", + " def evaluate_step(self, words, target):\n", + " pred = self(words)\n", + " pred = jittor.argmax(pred, dim=-1)[0]\n", + " return {'pred': pred, 'target': target}" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "a94ed8c4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + 
"ClsByJittor(\n", + " embedding: Embedding(8458, 100)\n", + " lstm: LSTM(100, 64, 2, bias=True, batch_first=True, dropout=0.5, bidirectional=True, proj_size=0)\n", + " mlp: Sequential(\n", + " 0: Dropout(0.5, is_train=False)\n", + " 1: Linear(128, 128, float32[128,], None)\n", + " 2: relu()\n", + " 3: Linear(128, 2, float32[2,], None)\n", + " 4: Sigmoid()\n", + " )\n", + " loss_fn: MSELoss(mean)\n", + ")" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = ClsByJittor(vocab_size=len(vocab), embedding_dim=100, output_dim=2)\n", + "\n", + "model" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "6d15ebc1", + "metadata": {}, + "outputs": [], + "source": [ + "from jittor.optim import AdamW\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=5e-3)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "95d8d09e", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import prepare_jittor_dataloader\n", + "\n", + "train_dataloader = prepare_jittor_dataloader(train_dataset, batch_size=16, shuffle=True)\n", + "evaluate_dataloader = prepare_jittor_dataloader(evaluate_dataset, batch_size=16)\n", + "\n", + "# dl_bundle = prepare_jittor_dataloader(data_bundle, batch_size=16, shuffle=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "917eab81", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import Trainer, Accuracy\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver='jittor',\n", + " device='gpu', # 'cpu', 'gpu', 'cuda'\n", + " n_epochs=10,\n", + " optimizers=optimizers,\n", + " train_dataloader=train_dataloader, # dl_bundle['train'],\n", + " evaluate_dataloaders=evaluate_dataloader, # dl_bundle['dev'],\n", + " metrics={'acc': Accuracy()}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "f7c4ac5a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ 
+ "
[21:05:51] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[21:05:51]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=69759;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=202322;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "Compiling Operators(5/6) used: 8.31s eta: 1.66s 6/6) used: 9.33s eta: 0s \n", + "\n", + "Compiling Operators(31/31) used: 7.31s eta: 0s \n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.61875,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 99\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.61875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m99\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.7,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 112\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m112\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.725,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 116\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.725\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m116\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.74375,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 119\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.74375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m119\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.75625,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 121\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.75625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m121\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.75625,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 121\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.75625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m121\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.73125,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 117\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.73125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m117\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.7625,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 122\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m122\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.74375,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 119\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.74375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m119\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.7625,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 122\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m122\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3df5f425", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/tutorials/fastnlp_tutorial_e1.ipynb b/docs/source/tutorials/fastnlp_tutorial_e1.ipynb new file mode 100644 index 00000000..af8e60a0 --- /dev/null +++ b/docs/source/tutorials/fastnlp_tutorial_e1.ipynb @@ -0,0 +1,1280 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " 从这篇开始,我们将开启 **fastNLP v1.0 tutorial 的 example 系列**,在接下来的\n", + "\n", + " 每篇`tutorial`里,我们将会介绍`fastNLP v1.0`在自然语言处理任务上的应用实例" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4.18.0\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from torch.optim import AdamW\n", + "from torch.utils.data import DataLoader, Dataset\n", + "\n", + "import transformers\n", + "from transformers import AutoTokenizer\n", + "from transformers import AutoModelForSequenceClassification\n", + "\n", + "import sys\n", + "sys.path.append('..')\n", + "\n", + "import fastNLP\n", + "from fastNLP import Trainer\n", + "from fastNLP import Accuracy\n", + "\n", + "print(transformers.__version__)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. 基础介绍:GLUE 通用语言理解评估、SST-2 文本情感二分类数据集\n", + "\n", + " 本示例使用`GLUE`评估基准中的`SST-2`数据集,通过`fine-tuning`方式\n", + "\n", + " 调整`distilbert-base`分类模型,以下首先简单介绍下`GLUE`和`SST-2`\n", + "\n", + "**GLUE**,**全称 General Language Understanding Evaluation**,**通用语言理解评估**,\n", + "\n", + " 包含9个数据集,各语料的语言均为英语,涉及多个自然语言理解`NLU`任务,包括\n", + "\n", + " **CoLA**,文本分类任务,预测单句语法正误分类;**SST-2**,文本分类任务,预测单句情感二分类\n", + "\n", + " **MRPC**,句对分类任务,预测句对语义一致性;**STS-B**,相似度打分任务,预测句对语义相似度回归\n", + "\n", + " **QQP**,句对分类任务,预测问题对语义一致性;**MNLI**,文本推理任务,预测句对蕴含/矛盾/中立预测\n", + "\n", + " **QNLI / RTE / WNLI**,文本推理,预测是否蕴含二分类(其中,`QNLI`从`SQuAD`转化而来)\n", + "\n", + " 诸如`BERT`、`T5`等经典模型都会在此基准上验证效果,更多参考[GLUE论文](https://arxiv.org/pdf/1804.07461v3.pdf)\n", + "\n", + " 此处,我们使用`SST-2`来训练`bert`,实现文本分类,其他任务描述见下图" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "GLUE_TASKS = ['cola', 'mnli', 'mrpc', 'qnli', 'qqp', 'rte', 'sst2', 'stsb', 'wnli']\n", + "\n", + "task = 'sst2'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "\n", + "**SST**,**全称 Stanford Sentiment Treebank**,**斯坦福情感树库**,**单句情感分类**数据集\n", + "\n", + " 包含电影评论语句和对应的情感极性,1 对应`positive` 正面情感,0 对应`negative` 
负面情感\n", + 
"\n", + " 数据集包括三部分:训练集 67349 条,验证集 872 条,测试集 1821 条,更多参考[下载链接](https://gluebenchmark.com/tasks)\n", + "\n", + "对应到代码上,此处使用`datasets`模块中的`load_dataset`函数,指定`SST-2`数据集,自动加载\n", + "\n", + " 首次下载后会保存至`~/.cache/huggingface/datasets/glue/`目录下" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c5915debacf9443986b5b3b34870b303", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from datasets import load_dataset\n", + "\n", + "dataset = load_dataset('glue', task)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " 加载之后,根据`GLUE`中`SST-2`数据集的格式,尝试打印部分数据,检查加载结果" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sentence: hide new secretions from the parental units \n" + ] + } + ], + "source": [ + "task_to_keys = {\n", + " 'cola': ('sentence', None),\n", + " 'mnli': ('premise', 'hypothesis'),\n", + " 'mnli': ('premise', 'hypothesis'),\n", + " 'mrpc': ('sentence1', 'sentence2'),\n", + " 'qnli': ('question', 'sentence'),\n", + " 'qqp': ('question1', 'question2'),\n", + " 'rte': ('sentence1', 'sentence2'),\n", + " 'sst2': ('sentence', None),\n", + " 'stsb': ('sentence1', 'sentence2'),\n", + " 'wnli': ('sentence1', 'sentence2'),\n", + "}\n", + "\n", + "sentence1_key, sentence2_key = task_to_keys[task]\n", + "\n", + "if sentence2_key is None:\n", + " print(f\"Sentence: {dataset['train'][0][sentence1_key]}\")\n", + "else:\n", + " 
print(f\"Sentence 1: {dataset['train'][0][sentence1_key]}\")\n", + " print(f\"Sentence 2: {dataset['train'][0][sentence2_key]}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. 准备工作:加载 tokenizer、预处理 dataset、dataloader 使用\n", + "\n", + " 接下来进入模型训练的准备工作,分别需要使用`tokenizer`模块对数据集进行分词与标注\n", + "\n", + " 定义`SeqClsDataset`对应`dataloader`模块用来实现数据集在训练/测试时的加载\n", + "\n", + "此处的`tokenizer`和`SequenceClassificationModel`都是基于**distilbert-base-uncased 模型**\n", + "\n", + " 即使用较小的、不区分大小写的数据集,**对 bert-base 进行知识蒸馏后的版本**,结构上\n", + "\n", + " 包含**1个编码层**、**6个自注意力层**,**参数量 66M**,详解见本篇末尾,更多请参考[DistilBert论文](https://arxiv.org/pdf/1910.01108.pdf)\n", + "\n", + "首先,通过从`transformers`库中导入 **AutoTokenizer 模块**,**使用 from_pretrained 函数初始化**\n", + "\n", + " 此处的`use_fast`表示是否使用`tokenizer`的快速版本;尝试序列化示例数据,检查加载结果\n", + "\n", + " 需要注意的是,处理后返回的两个键值,**'input_ids'**表示原始文本对应的词素编号序列\n", + "\n", + " **'attention_mask'**表示自注意力运算时的掩码(标上`0`的部分对应`padding`的内容)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n" + ] + } + ], + "source": [ + "model_checkpoint = 'distilbert-base-uncased'\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n", + "\n", + "print(tokenizer(\"Hello, this one sentence!\", \"And this sentence goes with it.\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "接着,定义预处理函数,**通过 dataset.map 方法**,**将数据集中的文本**,**替换为词素编号序列**" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading cached processed dataset at 
/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-ca1fbe5e8eb059f3.arrow\n", + "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-03661263fbf302f5.arrow\n", + "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-fbe8e7a4e4f18f45.arrow\n" + ] + } + ], + "source": [ + "def preprocess_function(examples):\n", + " if sentence2_key is None:\n", + " return tokenizer(examples[sentence1_key], truncation=True)\n", + " return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True)\n", + "\n", + "encoded_dataset = dataset.map(preprocess_function, batched=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "然后,通过继承`torch`中的`Dataset`类,定义`SeqClsDataset`类,需要注意的是\n", + "\n", + " 其中,**\_\_getitem\_\_ 函数各返回值引用的键值**,**必须和原始数据集中的属性对应**\n", + "\n", + " 例如,`'label'`是`SST-2`数据集中原有的内容(包括`'sentence'`和`'label'`)\n", + "\n", + " `'input_ids'`和`'attention_mask'`则是`tokenizer`处理后添加的字段" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "class SeqClsDataset(Dataset):\n", + " def __init__(self, dataset):\n", + " Dataset.__init__(self)\n", + " self.dataset = dataset\n", + "\n", + " def __len__(self):\n", + " return len(self.dataset)\n", + "\n", + " def __getitem__(self, item):\n", + " item = self.dataset[item]\n", + " return item['input_ids'], item['attention_mask'], [item['label']] " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "再然后,**定义校对函数 collate_fn 对齐同个 batch 内的每笔数据**,需要注意的是该函数的\n", + "\n", + " **返回值必须是字典**,**键值必须同待训练模型的 train_step 和 evaluate_step 函数的参数**\n", + "\n", + " **相对应**;这也就是在`tutorial-0`中便被强调的,`fastNLP v1.0`的第一条**参数匹配**机制" + ] + 
}, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "def collate_fn(batch):\n", + " input_ids, atten_mask, labels = [], [], []\n", + " max_length = [0] * 3\n", + " for each_item in batch:\n", + " input_ids.append(each_item[0])\n", + " max_length[0] = max(max_length[0], len(each_item[0]))\n", + " atten_mask.append(each_item[1])\n", + " max_length[1] = max(max_length[1], len(each_item[1]))\n", + " labels.append(each_item[2])\n", + " max_length[2] = max(max_length[2], len(each_item[2]))\n", + "\n", + " for i in range(3):\n", + " each = (input_ids, atten_mask, labels)[i]\n", + " for item in each:\n", + " item.extend([0] * (max_length[i] - len(item)))\n", + " return {'input_ids': torch.cat([torch.tensor([item]) for item in input_ids], dim=0),\n", + " 'attention_mask': torch.cat([torch.tensor([item]) for item in atten_mask], dim=0),\n", + " 'labels': torch.cat([torch.tensor(item) for item in labels], dim=0)}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "最后,分别对`tokenizer`处理过的训练集数据、验证集数据,进行预处理和批量划分" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "dataset_train = SeqClsDataset(encoded_dataset['train'])\n", + "dataloader_train = DataLoader(dataset=dataset_train, \n", + " batch_size=32, shuffle=True, collate_fn=collate_fn)\n", + "dataset_valid = SeqClsDataset(encoded_dataset['validation'])\n", + "dataloader_valid = DataLoader(dataset=dataset_valid, \n", + " batch_size=32, shuffle=False, collate_fn=collate_fn)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. 
模型训练:加载 distilbert-base、fastNLP 参数匹配、fine-tuning\n", + "\n", + " 最后就是模型训练的部分,分别需要使用`distilbert-base-uncased`搭建分类模型\n", + "\n", + " 初始化优化器`optimizer`、训练模块`trainer`,通过`run`函数完成训练\n", + "\n", + "此处使用`nn.Module`模块搭建模型,与`tokenizer`类似,通过从`transformers`库中\n", + "\n", + " 导入`AutoModelForSequenceClassification`模块,基于`distilbert-base-uncased`模型初始化\n", + "\n", + "需要注意的是**AutoModelForSequenceClassification 模块的输入参数和输出结构**\n", + "\n", + " 一方面,可以**通过输入标签值 labels**,**使用模块内的损失函数计算损失 loss**\n", + "\n", + " 并且可以选择输入是词素编号序列`input_ids`,还是词素嵌入序列`inputs_embeds`\n", + "\n", + " 另一方面,该模块不会直接输出预测结果,而是会**输出各预测分类上的几率 logits**\n", + "\n", + " 基于上述描述,此处完成了模型中`train_step`和`evaluate_step`函数的定义\n", + "\n", + " 同样需要注意,函数的返回值体现了`fastNLP v1.0`的第二条**参数匹配**机制" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "class SeqClsModel(nn.Module):\n", + " def __init__(self, num_labels, model_checkpoint):\n", + " nn.Module.__init__(self)\n", + " self.num_labels = num_labels\n", + " self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n", + " num_labels=num_labels)\n", + "\n", + " def forward(self, input_ids, attention_mask, labels=None):\n", + " output = self.back_bone(input_ids=input_ids, \n", + " attention_mask=attention_mask, labels=labels)\n", + " return output\n", + "\n", + " def train_step(self, input_ids, attention_mask, labels):\n", + " loss = self(input_ids, attention_mask, labels).loss\n", + " return {'loss': loss}\n", + "\n", + " def evaluate_step(self, input_ids, attention_mask, labels):\n", + " pred = self(input_ids, attention_mask, labels).logits\n", + " pred = torch.max(pred, dim=-1)[1]\n", + " return {'pred': pred, 'target': labels}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "接着,通过确定分类数量初始化模型实例,同时调用`torch.optim.AdamW`模块初始化优化器" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + 
"text": [ + "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_layer_norm.bias']\n", + "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.weight', 'classifier.bias', 'pre_classifier.bias']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + } + ], + "source": [ + "num_labels = 3 if task == 'mnli' else 1 if task == 'stsb' else 2\n", + "\n", + "model = SeqClsModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=5e-5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "然后,使用之前完成的`dataloader_train`和`dataloader_valid`,定义训练模块`trainer`" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "trainer = Trainer(\n", + " model=model,\n", + " driver='torch',\n", + " device=0, # 'cuda'\n", + " n_epochs=10,\n", + " optimizers=optimizers,\n", + " train_dataloader=dataloader_train,\n", + " evaluate_dataloaders=dataloader_valid,\n", + " metrics={'acc': Accuracy()}\n", + ")" + ] + }, + { + 
"cell_type": "markdown", + "metadata": {}, + "source": [ + "最后,使用`trainer.run`方法,训练模型,`n_epochs`参数中已经指定需要迭代`10`轮\n", + "\n", + " `num_eval_batch_per_dl`参数则指定每次只对验证集中的`10`个`batch`进行评估" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[09:12:45] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[09:12:45]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=408427;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=303634;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.884375,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 283.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.878125,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 281.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.884375,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 283.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.9,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 288.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.9\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m288.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.8875,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 284.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.88125,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 282.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.88125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m282.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.875,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 280.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m280.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.865625,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 277.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.865625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m277.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.884375,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 283.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.878125,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 281.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'acc#acc': 0.884174, 'total#acc': 872.0, 'correct#acc': 771.0}" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.evaluator.run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 附:`DistilBertForSequenceClassification`模块结构\n", + "\n", + "```\n", + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4.18.0\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from torch.optim import AdamW\n", + "from torch.utils.data import DataLoader, Dataset\n", + "\n", + "import transformers\n", + "from transformers import AutoTokenizer\n", + "from transformers import AutoModelForSequenceClassification\n", + "\n", + "import sys\n", + "sys.path.append('..')\n", + "\n", + "import fastNLP\n", + "from fastNLP import Trainer\n", + "from fastNLP.core.metrics import Accuracy\n", + "\n", + "print(transformers.__version__)\n", + "\n", + "task = 'sst2'\n", + "model_checkpoint = 'distilbert-base-uncased' # 'bert-base-uncased'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. 准备工作:P-Tuning v2 原理概述、P-Tuning v2 模型搭建\n", + "\n", + " 本示例使用`P-Tuning v2`作为`prompt-based tuning`与`fastNLP v1.0`结合的案例\n", + "\n", + " 以下首先简述`P-Tuning v2`的论文原理,并由此引出`fastNLP v1.0`的代码实践\n", + "\n", + "**P-Tuning v2**出自论文[Prompt Tuning Can Be Comparable to Fine-tuning Universally Across Scales and Tasks](https://arxiv.org/pdf/2110.07602.pdf)\n", + "\n", + " 其主要贡献在于,**在 PrefixTuning 等深度提示学习基础上**,**提升了其在分类标注等 NLU 任务的表现**\n", + "\n", + " 并使之在中等规模模型,主要是**参数量在 100M-1B 区间的模型上**,**获得与全参数微调相同的效果**\n", + "\n", + " 其结构如图所示,通过**在输入序列的分类符 [CLS] 之前**,**加入前缀序列**(**序号对应嵌入是待训练的连续值向量**\n", + "\n", + " **刺激模型在新任务下**,从`[CLS]`对应位置,**输出符合微调任务的输出**,从而达到适应微调任务的目的\n", + "\n", + "\n", + "\n", + "本示例使用`bert-base-uncased`模型,作为`P-Tuning v2`的基础`backbone`,设置`requires_grad=False`\n", + "\n", + " 固定其参数不参与训练,**设置 pre_seq_len 长的 prefix_tokens 作为输入的提示前缀序列**\n", + "\n", + " **使用基于 nn.Embedding 的 prefix_encoder 为提示前缀嵌入**,通过`get_prompt`函数获取,再将之\n", + "\n", + " 拼接至批量内每笔数据前得到`inputs_embeds`,同时更新自注意力掩码`attention_mask`\n", + "\n", + " 将`inputs_embeds`、`attention_mask`和`labels`输入`backbone`,**得到输出包括 loss 和 logits**" + ] + }, + { + 
"cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "class SeqClsModel(nn.Module):\n", + " def __init__(self, model_checkpoint, num_labels, pre_seq_len):\n", + " nn.Module.__init__(self)\n", + " self.num_labels = num_labels\n", + " self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n", + " num_labels=num_labels)\n", + " self.embeddings = self.back_bone.get_input_embeddings()\n", + "\n", + " for param in self.back_bone.parameters():\n", + " param.requires_grad = False\n", + " \n", + " self.pre_seq_len = pre_seq_len\n", + " self.prefix_tokens = torch.arange(self.pre_seq_len).long()\n", + " self.prefix_encoder = nn.Embedding(self.pre_seq_len, self.embeddings.embedding_dim)\n", + " \n", + " def get_prompt(self, batch_size):\n", + " prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.back_bone.device)\n", + " prompts = self.prefix_encoder(prefix_tokens)\n", + " return prompts\n", + "\n", + " def forward(self, input_ids, attention_mask, labels=None):\n", + " \n", + " batch_size = input_ids.shape[0]\n", + " raw_embedding = self.embeddings(input_ids)\n", + " \n", + " prompts = self.get_prompt(batch_size=batch_size)\n", + " inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)\n", + " prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.back_bone.device)\n", + " attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)\n", + "\n", + " outputs = self.back_bone(inputs_embeds=inputs_embeds, \n", + " attention_mask=attention_mask, labels=labels)\n", + " return outputs\n", + "\n", + " def train_step(self, input_ids, attention_mask, labels):\n", + " loss = self(input_ids, attention_mask, labels).loss\n", + " return {'loss': loss}\n", + "\n", + " def evaluate_step(self, input_ids, attention_mask, labels):\n", + " pred = self(input_ids, attention_mask, labels).logits\n", + " pred = torch.max(pred, dim=-1)[1]\n", + " return {'pred': 
pred, 'target': labels}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "接着,通过确定分类数量初始化模型实例,同时调用`torch.optim.AdamW`模块初始化优化器\n", + "\n", + " 根据`P-Tuning v2`论文:`Generally, simple classification tasks prefer shorter prompts (less than 20)`\n", + "\n", + " 此处`pre_seq_len`参数设定为`20`,学习率相应做出调整,其他内容和`tutorial-E1`中的内容一致" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_projector.bias']\n", + "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'pre_classifier.bias', 'classifier.bias']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + } + ], + "source": [ + "model = SeqClsModel(model_checkpoint=model_checkpoint, num_labels=2, pre_seq_len=20)\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=1e-2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. 
模型训练:加载 tokenizer、预处理 dataset、模型训练与分析\n", + "\n", + " 本示例沿用`tutorial-E1`中的数据集,即使用`GLUE`评估基准中的`SST-2`数据集\n", + "\n", + " 以`bert-base-uncased`模型作为基准,基于`P-Tuning v2`方式微调\n", + "\n", + " 数据集加载相关代码流程见下,内容和`tutorial-E1`中的内容基本一致\n", + "\n", + "首先,使用`datasets.load_dataset`加载数据集,使用`transformers.AutoTokenizer`\n", + "\n", + " 构建`tokenizer`实例,通过`dataset.map`使用`tokenizer`将文本替换为词素序号序列" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "21cbd92c3397497d84dc10f017ec96f4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from datasets import load_dataset, load_metric\n", + "\n", + "dataset = load_dataset('glue', task)\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-294e481a713c5754.arrow\n", + "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-ed9d9258aaf0fb54.arrow\n", + "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-f44c5576b89f9e6b.arrow\n" + ] + } + ], + "source": [ + "def 
preprocess_function(examples):\n", + " return tokenizer(examples['sentence'], truncation=True)\n", + "\n", + "encoded_dataset = dataset.map(preprocess_function, batched=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "然后,定义`SeqClsDataset`类、定义校对函数`collate_fn`,这里沿用`tutorial-E1`中的内容\n", + "\n", + " 同样需要注意/强调的是,**\\_\\_getitem\\_\\_ 函数的返回值必须和原始数据集中的属性对应**\n", + "\n", + " **collate_fn 函数的返回值必须和 train_step 和 evaluate_step 函数的参数匹配**" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "class SeqClsDataset(Dataset):\n", + " def __init__(self, dataset):\n", + " Dataset.__init__(self)\n", + " self.dataset = dataset\n", + "\n", + " def __len__(self):\n", + " return len(self.dataset)\n", + "\n", + " def __getitem__(self, item):\n", + " item = self.dataset[item]\n", + " return item['input_ids'], item['attention_mask'], [item['label']] \n", + "\n", + "def collate_fn(batch):\n", + " input_ids, atten_mask, labels = [], [], []\n", + " max_length = [0] * 3\n", + " for each_item in batch:\n", + " input_ids.append(each_item[0])\n", + " max_length[0] = max(max_length[0], len(each_item[0]))\n", + " atten_mask.append(each_item[1])\n", + " max_length[1] = max(max_length[1], len(each_item[1]))\n", + " labels.append(each_item[2])\n", + " max_length[2] = max(max_length[2], len(each_item[2]))\n", + "\n", + " for i in range(3):\n", + " each = (input_ids, atten_mask, labels)[i]\n", + " for item in each:\n", + " item.extend([0] * (max_length[i] - len(item)))\n", + " return {'input_ids': torch.cat([torch.tensor([item]) for item in input_ids], dim=0),\n", + " 'attention_mask': torch.cat([torch.tensor([item]) for item in atten_mask], dim=0),\n", + " 'labels': torch.cat([torch.tensor(item) for item in labels], dim=0)}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "再然后,分别对`tokenizer`处理过的训练集数据、验证集数据,进行预处理和批量划分" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + 
"metadata": {}, + "outputs": [], + "source": [ + "dataset_train = SeqClsDataset(encoded_dataset['train'])\n", + "dataloader_train = DataLoader(dataset=dataset_train, \n", + " batch_size=32, shuffle=True, collate_fn=collate_fn)\n", + "dataset_valid = SeqClsDataset(encoded_dataset['validation'])\n", + "dataloader_valid = DataLoader(dataset=dataset_valid, \n", + " batch_size=32, shuffle=False, collate_fn=collate_fn)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "最后,使用之前完成的`dataloader_train`和`dataloader_valid`,定义训练模块`trainer`" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "trainer = Trainer(\n", + " model=model,\n", + " driver='torch',\n", + " device=1, # [0, 1],\n", + " n_epochs=10,\n", + " optimizers=optimizers,\n", + " train_dataloader=dataloader_train,\n", + " evaluate_dataloaders=dataloader_valid,\n", + " metrics={'acc': Accuracy()}\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " 使用`trainer.run`方法训练模型,同样每次只对验证集中的`10`个`batch`进行评估" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[22:53:00] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[22:53:00]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=406635;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=951504;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.540625,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 173.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.540625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m173.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.5,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 160.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.5\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.509375,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 163.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.509375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m163.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.634375,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 203.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.634375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m203.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.6125,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 196.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.6125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m196.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.675,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 216.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.675\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m216.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.64375,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 206.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.64375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m206.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.665625,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 213.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.665625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m213.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.659375,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 211.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.659375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m211.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.696875,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 223.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.696875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m223.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "可以发现,其效果远远逊色于`fine-tuning`,这是因为`P-Tuning v2`虽然能够适应参数量\n", + "\n", + " 在`100M-1B`区间的模型,但是,**distilbert-base 的参数量仅为 66M**,无法触及其下限\n", + "\n", + "另一方面,**fastNLP v1.0 不支持 jupyter 多卡**,所以无法在笔者的电脑/服务器上,完成\n", + "\n", + " 合适规模模型的学习,例如`110M`的`bert-base`模型,以及`340M`的`bert-large`模型" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'acc#acc': 0.737385, 'total#acc': 872.0, 'correct#acc': 643.0}" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.evaluator.run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/docs/source/tutorials/fastnlp_tutorial_paddle_e1.ipynb b/docs/source/tutorials/fastnlp_tutorial_paddle_e1.ipynb new file 
mode 100644 index 00000000..a5883416 --- /dev/null +++ b/docs/source/tutorials/fastnlp_tutorial_paddle_e1.ipynb @@ -0,0 +1,1086 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# E3. 使用 paddlenlp 和 fastNLP 实现中文文本情感分析\n", + "\n", + "本篇教程属于 **fastNLP v1.0 tutorial 的 paddle examples 系列**。在本篇教程中,我们将为您展示如何使用 `paddlenlp` 自然语言处理库和 `fastNLP` 来完成比较简单的情感分析任务。\n", + "\n", + "1. 基础介绍:飞桨自然语言处理库 ``paddlenlp`` 和语义理解框架 ``ERNIE``\n", + "\n", + "2. 准备工作:使用 ``tokenizer`` 处理数据并构造 ``dataloader``\n", + "\n", + "3. 模型训练:加载 ``ERNIE`` 预训练模型,使用 ``fastNLP`` 进行训练" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. 基础介绍:飞桨自然语言处理库 paddlenlp 和语义理解框架 ERNIE\n", + "\n", + "#### 1.1 飞桨自然语言处理库 paddlenlp\n", + "\n", + "``paddlenlp`` 是由百度以飞桨 ``PaddlePaddle`` 为核心开发的自然语言处理库,集成了多个数据集和 NLP 模型,包括百度自研的语义理解框架 ``ERNIE`` 。在本篇教程中,我们会以 ``paddlenlp`` 为基础,使用模型 ``ERNIE`` 完成中文情感分析任务。" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2.3.3\n" + ] + } + ], + "source": [ + "import sys\n", + "sys.path.append(\"../\")\n", + "\n", + "import paddle\n", + "import paddlenlp\n", + "from paddlenlp.transformers import AutoTokenizer\n", + "from paddlenlp.transformers import AutoModelForSequenceClassification\n", + "\n", + "print(paddlenlp.__version__)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 1.2 语义理解框架 ERNIE\n", + "\n", + "``ERNIE(Enhanced Representation from kNowledge IntEgration)`` 是百度提出的基于知识增强的持续学习语义理解框架,至今已有 ``ERNIE 2.0``、``ERNIE 3.0``、``ERNIE-M``、``ERNIE-tiny`` 等多种预训练模型。``ERNIE 1.0`` 采用``Transformer Encoder`` 作为其语义表示的骨架,并改进了两种 ``mask`` 策略,分别为基于**短语**和**实体**(人名、组织等)的策略。在 ``ERNIE`` 中,由多个字组成的短语或者实体将作为一个统一单元,在训练的时候被统一地 ``mask`` 掉,这样可以潜在地学习到知识的依赖以及更长的语义依赖来让模型更具泛化性。\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "``ERNIE 2.0`` 则提出了连续学习(``Continual 
Learning``)的概念,即首先用一个简单的任务来初始化模型,在更新时用前一个任务训练好的参数作为下一个任务模型初始化的参数。这样在训练新的任务时,模型便可以记住之前学习到的知识,使得模型在新任务上获得更好的表现。``ERNIE 2.0`` 分别构建了词法、语法、语义不同级别的预训练任务,并使用不同的 task id 来标示不同的任务,在共计16个中英文任务上都取得了SOTA效果。\n", + "\n", + "\n", + "\n", + "``ERNIE 3.0`` 将自回归和自编码网络融合在一起进行预训练,其中自编码网络采用 ``ERNIE 2.0`` 的多任务学习增量式构建预训练任务,持续进行语义理解学习。其中自编码网络增加了知识增强的预训练任务。自回归网络则基于 ``Transformer-XL`` 结构,支持长文本语言模型建模,并在多个自然语言处理任务中取得了SOTA的效果。\n", + "\n", + "\n", + "\n", + "接下来,我们将展示如何在 ``fastNLP`` 中使用基于 ``paddle`` 的 ``ERNIE 1.0`` 框架进行中文情感分析。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. 使用 tokenizer 处理数据并构造 dataloader\n", + "\n", + "#### 2.1 加载中文数据集 ChnSentiCorp\n", + "\n", + "``ChnSentiCorp`` 数据集是由中国科学院发布的中文句子级情感分析数据集,包含了从网络上获取的酒店、电影、书籍等多个领域的评论,每条评论都被划分为两个标签:消极(``0``)和积极(``1``),可以用于二分类的中文情感分析任务。通过 ``paddlenlp.datasets.load_dataset`` 函数,我们可以加载并查看 ``ChnSentiCorp`` 数据集的内容。" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "训练集大小: 9600\n", + "{'text': '选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般', 'label': 1, 'qid': ''}\n", + "{'text': '15.4寸笔记本的键盘确实爽,基本跟台式机差不多了,蛮喜欢数字小键盘,输数字特方便,样子也很美观,做工也相当不错', 'label': 1, 'qid': ''}\n", + "{'text': '房间太小。其他的都一般。。。。。。。。。', 'label': 0, 'qid': ''}\n" + ] + } + ], + "source": [ + "from paddlenlp.datasets import load_dataset\n", + "\n", + "train_dataset, val_dataset, test_dataset = load_dataset(\"chnsenticorp\", splits=[\"train\", \"dev\", \"test\"])\n", + "print(\"训练集大小:\", len(train_dataset))\n", + "for i in range(3):\n", + " print(train_dataset[i])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 2.2 处理数据\n", + "\n", + "可以看到,原本的数据集仅包含中文的文本和标签,这样的数据是无法被模型识别的。同英文文本分类任务一样,我们需要使用 ``tokenizer`` 对文本进行分词并转换为数字形式的结果。我们可以加载已经预训练好的中文分词模型 ``ernie-1.0-base-zh``,将分词的过程写在函数 ``_process`` 中,然后调用数据集的 ``map`` 函数对每一条数据进行分词。其中:\n", + "- 参数 ``max_length`` 
代表句子的最大长度;\n", + "- ``padding=\"max_length\"`` 表示将长度不足的结果 padding 至和最大长度相同;\n", + "- ``truncation=True`` 表示将长度过长的句子进行截断。\n", + "\n", + "至此,我们得到了每条数据长度均相同的数据集。" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m[2022-06-22 21:31:04,168] [ INFO]\u001b[0m - We are using
[21:31:16] INFO Running evaluator sanity check for 2 batches. trainer.py:631\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[21:31:16]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=4641;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=822054;file://../fastNLP/core/controllers/trainer.py#631\u001b\\\u001b[2m631\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:60 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m60\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#accuracy\": 0.895833,\n", + " \"total#accuracy\": 1200.0,\n", + " \"correct#accuracy\": 1075.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.895833\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1075.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:120 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m120\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#accuracy\": 0.8975,\n", + " \"total#accuracy\": 1200.0,\n", + " \"correct#accuracy\": 1077.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.8975\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1077.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:180 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m180\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#accuracy\": 0.911667,\n", + " \"total#accuracy\": 1200.0,\n", + " \"correct#accuracy\": 1094.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.911667\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1094.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:240 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m240\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#accuracy\": 0.9225,\n", + " \"total#accuracy\": 1200.0,\n", + " \"correct#accuracy\": 1107.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.9225\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1107.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:300 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m300\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#accuracy\": 0.9275,\n", + " \"total#accuracy\": 1200.0,\n", + " \"correct#accuracy\": 1113.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.9275\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1113.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:1, Batch:60 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m60\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#accuracy\": 0.930833,\n", + " \"total#accuracy\": 1200.0,\n", + " \"correct#accuracy\": 1117.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.930833\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1117.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:1, Batch:120 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m120\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#accuracy\": 0.935833,\n", + " \"total#accuracy\": 1200.0,\n", + " \"correct#accuracy\": 1123.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.935833\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1123.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:1, Batch:180 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m180\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#accuracy\": 0.935833,\n", + " \"total#accuracy\": 1200.0,\n", + " \"correct#accuracy\": 1123.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.935833\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1123.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:1, Batch:240 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m240\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#accuracy\": 0.9375,\n", + " \"total#accuracy\": 1200.0,\n", + " \"correct#accuracy\": 1125.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.9375\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1125.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:1, Batch:300 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m300\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#accuracy\": 0.941667,\n", + " \"total#accuracy\": 1200.0,\n", + " \"correct#accuracy\": 1130.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.941667\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1130.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
[21:34:28] INFO Loading best model from fnlp-ernie/2022-0 load_best_model_callback.py:111\n", + " 6-22-21_29_12_898095/best_so_far with \n", + " acc#accuracy: 0.941667... \n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[21:34:28]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Loading best model from fnlp-ernie/\u001b[1;36m2022\u001b[0m-\u001b[1;36m0\u001b[0m \u001b]8;id=340364;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=763898;file://../fastNLP/core/callbacks/load_best_model_callback.py#111\u001b\\\u001b[2m111\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m \u001b[1;36m6\u001b[0m-\u001b[1;36m22\u001b[0m-21_29_12_898095/best_so_far with \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m acc#accuracy: \u001b[1;36m0.941667\u001b[0m\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
[21:34:34] INFO Deleting fnlp-ernie/2022-06-22-21_29_12_8 load_best_model_callback.py:131\n", + " 98095/best_so_far... \n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[21:34:34]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Deleting fnlp-ernie/\u001b[1;36m2022\u001b[0m-\u001b[1;36m06\u001b[0m-\u001b[1;36m22\u001b[0m-21_29_12_8 \u001b]8;id=430330;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=508566;file://../fastNLP/core/callbacks/load_best_model_callback.py#131\u001b\\\u001b[2m131\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m 98095/best_so_far\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from fastNLP import LRSchedCallback, LoadBestModelCallback\n", + "from fastNLP import Trainer, Accuracy\n", + "from paddlenlp.transformers import LinearDecayWithWarmup\n", + "\n", + "n_epochs = 2\n", + "num_training_steps = len(train_dataloader) * n_epochs\n", + "lr_scheduler = LinearDecayWithWarmup(5e-5, num_training_steps, 0.1)\n", + "optimizer = paddle.optimizer.AdamW(\n", + " learning_rate=lr_scheduler,\n", + " parameters=model.parameters(),\n", + ")\n", + "callbacks = [\n", + " LRSchedCallback(lr_scheduler, step_on=\"batch\"),\n", + " LoadBestModelCallback(\"acc#accuracy\", larger_better=True, save_folder=\"fnlp-ernie\"),\n", + "]\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver=\"paddle\",\n", + " optimizers=optimizer,\n", + " device=0,\n", + " n_epochs=n_epochs,\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=val_dataloader,\n", + " evaluate_every=60,\n", + " metrics={\"accuracy\": Accuracy()},\n", + " callbacks=callbacks,\n", + ")\n", + "trainer.run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.3 测试和评估\n", + "\n", + "现在我们已经得到了一个表现良好的 ``ERNIE`` 模型,接下来可以在测试集上测试模型的效果了。``fastNLP.Evaluator`` 提供了定制函数的功能。我们以 ``test_dataloader`` 初始化一个 ``Evaluator``,然后将写好的测试函数 ``test_batch_step_fn`` 传给参数 ``evaluate_batch_step_fn``,``Evaluator`` 在对每个 batch 进行评估时就会调用我们自定义的 ``test_batch_step_fn`` 函数而不是 ``evaluate_step`` 函数。在这里,我们仅测试 5 条数据并输出文本和对应的标签。" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
text: ['这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般']\n", + "\n" + ], + "text/plain": [ + "text: ['这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
labels: 0\n", + "\n" + ], + "text/plain": [ + "labels: 0\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
text: ['怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片!开始\n", + "还怀疑是不是赠送的个别现象,可是后来发现每张DVD后面都有!真不知道生产商怎么想的,我想看的是猫\n", + "和老鼠,不是米老鼠!如果厂家是想赠送的话,那就全套米老鼠和唐老鸭都赠送,只在每张DVD后面添加一\n", + "集算什么??简直是画蛇添足!!']\n", + "\n" + ], + "text/plain": [ + "text: ['怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片!开始\n", + "还怀疑是不是赠送的个别现象,可是后来发现每张DVD后面都有!真不知道生产商怎么想的,我想看的是猫\n", + "和老鼠,不是米老鼠!如果厂家是想赠送的话,那就全套米老鼠和唐老鸭都赠送,只在每张DVD后面添加一\n", + "集算什么??简直是画蛇添足!!']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
labels: 0\n", + "\n" + ], + "text/plain": [ + "labels: 0\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
text: ['还稍微重了点,可能是硬盘大的原故,还要再轻半斤就好了。其他要进一步验证。贴的几种膜气\n", + "泡较多,用不了多久就要更换了,屏幕膜稍好点,但比没有要强多了。建议配赠几张膜让用用户自己贴。'\n", + "]\n", + "\n" + ], + "text/plain": [ + "text: ['还稍微重了点,可能是硬盘大的原故,还要再轻半斤就好了。其他要进一步验证。贴的几种膜气\n", + "泡较多,用不了多久就要更换了,屏幕膜稍好点,但比没有要强多了。建议配赠几张膜让用用户自己贴。'\n", + "]\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
labels: 0\n", + "\n" + ], + "text/plain": [ + "labels: 0\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
text: ['交通方便;环境很好;服务态度很好 房间较小']\n", + "\n" + ], + "text/plain": [ + "text: ['交通方便;环境很好;服务态度很好 房间较小']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
labels: 1\n", + "\n" + ], + "text/plain": [ + "labels: 1\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
text: ['不错,作者的观点很颠覆目前中国父母的教育方式,其实古人们对于教育已经有了很系统的体系\n", + "了,可是现在的父母以及祖父母们更多的娇惯纵容孩子,放眼看去自私的孩子是大多数,父母觉得自己的\n", + "孩子在外面只要不吃亏就是好事,完全把古人几千年总结的教育古训抛在的九霄云外。所以推荐准妈妈们\n", + "可以在等待宝宝降临的时候,好好学习一下,怎么把孩子教育成一个有爱心、有责任心、宽容、大度的人\n", + "。']\n", + "\n" + ], + "text/plain": [ + "text: ['不错,作者的观点很颠覆目前中国父母的教育方式,其实古人们对于教育已经有了很系统的体系\n", + "了,可是现在的父母以及祖父母们更多的娇惯纵容孩子,放眼看去自私的孩子是大多数,父母觉得自己的\n", + "孩子在外面只要不吃亏就是好事,完全把古人几千年总结的教育古训抛在的九霄云外。所以推荐准妈妈们\n", + "可以在等待宝宝降临的时候,好好学习一下,怎么把孩子教育成一个有爱心、有责任心、宽容、大度的人\n", + "。']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
labels: 1\n", + "\n" + ], + "text/plain": [ + "labels: 1\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{}" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from fastNLP import Evaluator\n", + "def test_batch_step_fn(evaluator, batch):\n", + " input_ids = batch[\"input_ids\"]\n", + " attention_mask = batch[\"attention_mask\"]\n", + " token_type_ids = batch[\"token_type_ids\"]\n", + " logits = model(input_ids, attention_mask, token_type_ids)\n", + " predict = logits.argmax().item()\n", + " print(\"text:\", batch['text'])\n", + " print(\"labels:\", predict)\n", + "\n", + "evaluator = Evaluator(\n", + " model=model,\n", + " dataloaders=test_dataloader,\n", + " driver=\"paddle\",\n", + " device=0,\n", + " evaluate_batch_step_fn=test_batch_step_fn,\n", + ")\n", + "evaluator.run(5) " + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.7.13 ('fnlp-paddle')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "31f2d9d3efc23c441973d7c4273acfea8b132b6a578f002629b6b44b8f65e720" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/source/tutorials/fastnlp_tutorial_paddle_e2.ipynb b/docs/source/tutorials/fastnlp_tutorial_paddle_e2.ipynb new file mode 100644 index 00000000..439d7f9f --- /dev/null +++ b/docs/source/tutorials/fastnlp_tutorial_paddle_e2.ipynb @@ -0,0 +1,1510 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# E4. 
使用 paddlenlp 和 fastNLP 训练中文阅读理解任务\n", + "\n", + "本篇教程属于 **fastNLP v1.0 tutorial 的 paddle examples 系列**。在本篇教程中,我们将为您展示如何在 `fastNLP` 中通过自定义 `Metric` 和损失函数来完成进阶的问答任务。\n", + "\n", + "1. 基础介绍:自然语言处理中的阅读理解任务\n", + "\n", + "2. 准备工作:加载 `DuReader-robust` 数据集,并使用 `tokenizer` 处理数据\n", + "\n", + "3. 模型训练:自己定义评测用的 `Metric` 实现更加自由的任务评测" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. 基础介绍:自然语言处理中的阅读理解任务\n", + "\n", + "阅读理解任务,顾名思义,就是给出一段文字,然后让模型理解这段文字所含的语义。大部分机器阅读理解任务都采用问答式测评,即设计与文章内容相关的自然语言式问题,让模型理解问题并根据文章作答。与文本分类任务不同的是,在阅读理解任务中我们有时需要输入“一对”句子,分别代表问题和上下文;答案的格式也分为多种:\n", + "\n", + "- 多项选择:让模型从多个答案选项中选出正确答案\n", + "- 区间答案:答案为上下文的一段子句,需要模型给出答案的起始位置\n", + "- 自由回答:不做限制,让模型自行生成答案\n", + "- 完形填空:在原文中挖空部分关键词,让模型补全;这类答案往往不需要问题\n", + "\n", + "如果您对 `transformers` 有所了解的话,其中的 `ModelForQuestionAnswering` 系列模型就可以用于这项任务。阅读理解模型的泛用性是衡量该技术能否在实际应用中大规模落地的重要指标之一,随着当前技术的进步,许多模型虽然能够在一些测试集上取得较好的性能,但在实际应用中,这些模型仍然难以让人满意。在本篇教程中,我们将会为您展示如何训练一个问答模型。\n", + "\n", + "在这一领域,`SQuAD` 数据集是一个影响深远的数据集。它的全称是斯坦福问答数据集(Stanford Question Answering Dataset),每条数据包含 `(问题,上下文,答案)` 三部分,规模大(约十万条,2.0又新增了五万条),在提出之后很快成为训练问答任务的经典数据集之一。`SQuAD` 数据集有两个指标来衡量模型的表现:`EM`(Exact Match,精确匹配)和 `F1`(模糊匹配)。前者反映了模型给出的答案中有多少和正确答案完全一致,后者则反映了模型给出的答案中与正确答案重叠的部分,均为越高越好。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. 准备工作:加载 DuReader-robust 数据集,并使用 tokenizer 处理数据" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/remote-home/shxing/anaconda3/envs/fnlp-paddle/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2.3.3\n" + ] + } + ], + "source": [ + "import sys\n", + "sys.path.append(\"../\")\n", + "import paddle\n", + "import paddlenlp\n", + "\n", + "print(paddlenlp.__version__)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "在数据集方面,我们选用 `DuReader-robust` 中文数据集作为训练数据。它是一种抽取式问答数据集,采用 `SQuAD` 数据格式,能够评估真实应用场景下模型的泛用性。" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset dureader_robust (/remote-home/shxing/.cache/huggingface/datasets/dureader_robust/plain_text/1.0.0/d462ecadc8c010cee20f57632f1413f272867cd802a91a602df48c7d34eb0c27)\n", + "Reusing dataset dureader_robust (/remote-home/shxing/.cache/huggingface/datasets/dureader_robust/plain_text/1.0.0/d462ecadc8c010cee20f57632f1413f272867cd802a91a602df48c7d34eb0c27)\n", + "\u001b[32m[2022-06-27 19:22:46,998] [ INFO]\u001b[0m - Already cached /remote-home/shxing/.paddlenlp/models/ernie-1.0-base-zh/vocab.txt\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'id': '0a25cb4bc1ab6f474c699884e04601e4', 'title': '', 'context': '第35集雪见缓缓张开眼睛,景天又惊又喜之际,长卿和紫萱的仙船驶至,见众人无恙,也十分高兴。众人登船,用尽合力把自身的真气和水分输给她。雪见终于醒过来了,但却一脸木然,全无反应。众人向常胤求助,却发现人世界竟没有雪见的身世纪录。长卿询问清微的身世,清微语带双关说一切上了天界便有答案。长卿驾驶仙船,众人决定立马动身,往天界而去。众人来到一荒山,长卿指出,魔界和天界相连。由魔界进入通过神魔之井,便可登天。众人至魔界入口,仿若一黑色的蝙蝠洞,但始终无法进入。后来花楹发现只要有翅膀便能飞入。于是景天等人打下许多乌鸦,模仿重楼的翅膀,制作数对翅膀状巨物。刚佩戴在身,便被吸入洞口。众人摔落在地,抬头发现魔界守卫。景天和众魔套交情,自称和魔尊重楼相熟,众魔不理,打了起来。', 'question': '仙剑奇侠传3第几集上天界', 'answers': {'text': ['第35集'], 'answer_start': [0]}}\n", + "{'id': '7de192d6adf7d60ba73ba25cf590cc1e', 'title': '', 'context': '选择燃气热水器时,一定要关注这几个问题:1、出水稳定性要好,不能出现忽热忽冷的现象2、快速到达设定的需求水温3、操作要智能、方便4、安全性要好,要装有安全报警装置 
市场上燃气热水器品牌众多,购买时还需多加对比和仔细鉴别。方太今年主打的磁化恒温热水器在使用体验方面做了全面升级:9秒速热,可快速进入洗浴模式;水温持久稳定,不会出现忽热忽冷的现象,并通过水量伺服技术将出水温度精确控制在±0.5℃,可满足家里宝贝敏感肌肤洗护需求;配备CO和CH4双气体报警装置更安全(市场上一般多为CO单气体报警)。另外,这款热水器还有智能WIFI互联功能,只需下载个手机APP即可用手机远程操作热水器,实现精准调节水温,满足家人多样化的洗浴需求。当然方太的磁化恒温系列主要的是增加磁化功能,可以有效吸附水中的铁锈、铁屑等微小杂质,防止细菌滋生,使沐浴水质更洁净,长期使用磁化水沐浴更利于身体健康。', 'question': '燃气热水器哪个牌子好', 'answers': {'text': ['方太'], 'answer_start': [110]}}\n", + "{'id': 'b9e74d4b9228399b03701d1fe6d52940', 'title': '', 'context': '迈克尔.乔丹在NBA打了15个赛季。他在84年进入nba,期间在1993年10月6日第一次退役改打棒球,95年3月18日重新回归,在99年1月13日第二次退役,后于2001年10月31日复出,在03年最终退役。迈克尔·乔丹(Michael Jordan),1963年2月17日生于纽约布鲁克林,美国著名篮球运动员,司职得分后卫,历史上最伟大的篮球运动员。1984年的NBA选秀大会,乔丹在首轮第3顺位被芝加哥公牛队选中。 1986-87赛季,乔丹场均得到37.1分,首次获得分王称号。1990-91赛季,乔丹连夺常规赛MVP和总决赛MVP称号,率领芝加哥公牛首次夺得NBA总冠军。 1997-98赛季,乔丹获得个人职业生涯第10个得分王,并率领公牛队第六次夺得总冠军。2009年9月11日,乔丹正式入选NBA名人堂。', 'question': '乔丹打了多少个赛季', 'answers': {'text': ['15个'], 'answer_start': [12]}}\n", + "训练集大小: 14520\n", + "验证集大小: 1417\n" + ] + } + ], + "source": [ + "from paddlenlp.datasets import load_dataset\n", + "train_dataset = load_dataset(\"PaddlePaddle/dureader_robust\", splits=\"train\")\n", + "val_dataset = load_dataset(\"PaddlePaddle/dureader_robust\", splits=\"validation\")\n", + "for i in range(3):\n", + " print(train_dataset[i])\n", + "print(\"训练集大小:\", len(train_dataset))\n", + "print(\"验证集大小:\", len(val_dataset))\n", + "\n", + "MODEL_NAME = \"ernie-1.0-base-zh\"\n", + "from paddlenlp.transformers import ErnieTokenizer\n", + "tokenizer =ErnieTokenizer.from_pretrained(MODEL_NAME)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 2.1 处理训练集\n", + "\n", + "对于阅读理解任务,数据处理的方式较为麻烦。接下来我们会为您详细讲解处理函数 `_process_train` 的功能,同时也将通过实践展示关于 `tokenizer` 的更多功能,让您更加深入地了解自然语言处理任务。首先让我们向 `tokenizer` 输入一条数据(以列表的形式):" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2\n", + "dict_keys(['offset_mapping', 'input_ids', 'token_type_ids', 
'overflow_to_sample'])\n" + ] + } + ], + "source": [ + "result = tokenizer(\n", + " [train_dataset[0][\"question\"]],\n", + " [train_dataset[0][\"context\"]],\n", + " stride=128,\n", + " max_length=256,\n", + " padding=\"max_length\",\n", + " return_dict=False\n", + ")\n", + "\n", + "print(len(result))\n", + "print(result[0].keys())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "首先不难理解的是,模型必须要同时接受问题(`question`)和上下文(`context`)才能够进行阅读理解,因此我们需要将二者同时进行分词(`tokenize`)。所幸,`Tokenizer` 提供了这一功能,当我们调用 `tokenizer` 的时候,其第一个参数名为 `text`,第二个参数名为 `text_pair`,这使得我们可以同时对一对文本进行分词。同时,`tokenizer` 还需要标记出一条数据中哪些属于问题,哪些属于上下文,这一功能则由 `token_type_ids` 完成。`token_type_ids` 会将输入的第一个文本(问题)标记为 `0`,第二个文本(上下文)标记为 `1`,这样模型在训练时便可以将问题和上下文区分开来:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[1, 1034, 1189, 734, 2003, 241, 284, 131, 553, 271, 28, 125, 280, 2, 131, 1773, 271, 1097, 373, 1427, 1427, 501, 88, 662, 1906, 4, 561, 125, 311, 1168, 311, 692, 46, 430, 4, 84, 2073, 14, 1264, 3967, 5, 1034, 1020, 1829, 268, 4, 373, 539, 8, 154, 5210, 4, 105, 167, 59, 69, 685, 12043, 539, 8, 883, 1020, 4, 29, 720, 95, 90, 427, 67, 262, 5, 384, 266, 14, 101, 59, 789, 416, 237, 12043, 1097, 373, 616, 37, 1519, 93, 61, 15, 4, 255, 535, 7, 1529, 619, 187, 4, 62, 154, 451, 149, 12043, 539, 8, 253, 223, 3679, 323, 523, 4, 535, 34, 87, 8, 203, 280, 1186, 340, 9, 1097, 373, 5, 262, 203, 623, 704, 12043, 84, 2073, 1137, 358, 334, 702, 5, 262, 203, 4, 334, 702, 405, 360, 653, 129, 178, 7, 568, 28, 15, 125, 280, 518, 9, 1179, 487, 12043, 84, 2073, 1621, 1829, 1034, 1020, 4, 539, 8, 448, 91, 202, 466, 70, 262, 4, 638, 125, 280, 83, 299, 12043, 539, 8, 61, 45, 7, 1537, 176, 4, 84, 2073, 288, 39, 4, 889, 280, 14, 125, 280, 156, 538, 12043, 190, 889, 280, 71, 109, 124, 93, 292, 889, 46, 1248, 4, 518, 48, 883, 125, 12043, 539, 8, 268, 889, 280, 109, 270, 4, 1586, 845, 7, 669, 199, 5, 3964, 
3740, 1084, 4, 255, 440, 616, 154, 72, 71, 109, 12043, 49, 61, 283, 3591, 34, 87, 297, 41, 9, 1993, 2602, 518, 52, 706, 109, 2]\n", + "['[CLS]', '仙', '剑', '奇', '侠', '传', '3', '第', '几', '集', '上', '天', '界', '[SEP]', '第', '35', '集', '雪', '见', '缓', '缓', '张', '开', '眼', '睛', ',', '景', '天', '又', '惊', '又', '喜', '之', '际', ',', '长', '卿', '和', '紫', '萱', '的', '仙', '船', '驶', '至', ',', '见', '众', '人', '无', '恙', ',', '也', '十', '分', '高', '兴', '。', '众', '人', '登', '船', ',', '用', '尽', '合', '力', '把', '自', '身', '的', '真', '气', '和', '水', '分', '输', '给', '她', '。', '雪', '见', '终', '于', '醒', '过', '来', '了', ',', '但', '却', '一', '脸', '木', '然', ',', '全', '无', '反', '应', '。', '众', '人', '向', '常', '胤', '求', '助', ',', '却', '发', '现', '人', '世', '界', '竟', '没', '有', '雪', '见', '的', '身', '世', '纪', '录', '。', '长', '卿', '询', '问', '清', '微', '的', '身', '世', ',', '清', '微', '语', '带', '双', '关', '说', '一', '切', '上', '了', '天', '界', '便', '有', '答', '案', '。', '长', '卿', '驾', '驶', '仙', '船', ',', '众', '人', '决', '定', '立', '马', '动', '身', ',', '往', '天', '界', '而', '去', '。', '众', '人', '来', '到', '一', '荒', '山', ',', '长', '卿', '指', '出', ',', '魔', '界', '和', '天', '界', '相', '连', '。', '由', '魔', '界', '进', '入', '通', '过', '神', '魔', '之', '井', ',', '便', '可', '登', '天', '。', '众', '人', '至', '魔', '界', '入', '口', ',', '仿', '若', '一', '黑', '色', '的', '蝙', '蝠', '洞', ',', '但', '始', '终', '无', '法', '进', '入', '。', '后', '来', '花', '楹', '发', '现', '只', '要', '有', '翅', '膀', '便', '能', '飞', '入', '[SEP]']\n", + "[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n" + ] + } + ], + "source": [ + "print(result[0][\"input_ids\"])\n", + "print(tokenizer.convert_ids_to_tokens(result[0][\"input_ids\"]))\n", + "print(result[0][\"token_type_ids\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "根据上面的输出我们可以看出,`tokenizer` 会将数据开头用 `[CLS]` 标记,用 `[SEP]` 来分割句子。同时,根据 `token_type_ids` 得到的 0、1 串,我们也很容易将问题和上下文区分开。顺带一提,如果一条数据进行了 `padding`,那么这部分会被标记为 `0` 。\n", + "\n", + "在输出的 `keys` 中还有一项名为 `offset_mapping` 的键。该项数据能够表示分词后的每个 `token` 在原文中对应文字或词语的位置。比如我们可以像下面这样将数据打印出来:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[(0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (0, 0), (0, 1), (1, 3), (3, 4), (4, 5), (5, 6), (6, 7)]\n", + "[1, 1034, 1189, 734, 2003, 241, 284, 131, 553, 271, 28, 125, 280, 2, 131, 1773, 271, 1097, 373, 1427]\n", + "['[CLS]', '仙', '剑', '奇', '侠', '传', '3', '第', '几', '集', '上', '天', '界', '[SEP]', '第', '35', '集', '雪', '见', '缓']\n" + ] + } + ], + "source": [ + "print(result[0][\"offset_mapping\"][:20])\n", + "print(result[0][\"input_ids\"][:20])\n", + "print(tokenizer.convert_ids_to_tokens(result[0][\"input_ids\"])[:20])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`[CLS]` 由于是 `tokenizer` 自己添加进去用于标记数据的 `token`,因此它在原文中找不到任何对应的词语,所以给出的位置范围就是 `(0, 0)`;第二个 `token` 对应第一个 `“仙”` 字,因此映射的位置就是 `(0, 1)`;同理,后面的 `[SEP]` 也不对应任何文字,映射的位置为 `(0, 0)`;而接下来的 `token` 对应 **上下文** 中的第一个字 `“第”`,映射出的位置为 `(0, 1)`;再后面的 `token` 对应原文中的两个字符 `35`,因此其位置映射为 `(1, 3)` 。通过这种手段,我们可以更方便地获取 `token` 与原文的对应关系。\n", + "\n", + "最后,您也许会注意到我们获取的 `result` 长度为 2 。这是文本在分词后长度超过了 `max_length` 256 ,`tokenizer` 
将数据分成了两部分所致。在阅读理解任务中,我们不可能像文本分类那样轻易地将一条数据截断,因为答案很可能就出现在后面被丢弃的那部分数据中,因此,我们需要保留所有的数据(当然,您也可以直接丢弃这些超长的数据)。`overflow_to_sample` 则可以标识当前数据在原数据的索引:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[CLS]仙剑奇侠传3第几集上天界[SEP]第35集雪见缓缓张开眼睛,景天又惊又喜之际,长卿和紫萱的仙船驶至,见众人无恙,也十分高兴。众人登船,用尽合力把自身的真气和水分输给她。雪见终于醒过来了,但却一脸木然,全无反应。众人向常胤求助,却发现人世界竟没有雪见的身世纪录。长卿询问清微的身世,清微语带双关说一切上了天界便有答案。长卿驾驶仙船,众人决定立马动身,往天界而去。众人来到一荒山,长卿指出,魔界和天界相连。由魔界进入通过神魔之井,便可登天。众人至魔界入口,仿若一黑色的蝙蝠洞,但始终无法进入。后来花楹发现只要有翅膀便能飞入[SEP]\n", + "overflow_to_sample: 0\n", + "[CLS]仙剑奇侠传3第几集上天界[SEP]说一切上了天界便有答案。长卿驾驶仙船,众人决定立马动身,往天界而去。众人来到一荒山,长卿指出,魔界和天界相连。由魔界进入通过神魔之井,便可登天。众人至魔界入口,仿若一黑色的蝙蝠洞,但始终无法进入。后来花楹发现只要有翅膀便能飞入。于是景天等人打下许多乌鸦,模仿重楼的翅膀,制作数对翅膀状巨物。刚佩戴在身,便被吸入洞口。众人摔落在地,抬头发现魔界守卫。景天和众魔套交情,自称和魔尊重楼相熟,众魔不理,打了起来。[SEP][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]\n", + "overflow_to_sample: 0\n" + ] + } + ], + "source": [ + "for res in result:\n", + " tokens = tokenizer.convert_ids_to_tokens(res[\"input_ids\"])\n", + " print(\"\".join(tokens))\n", + " print(\"overflow_to_sample: \", res[\"overflow_to_sample\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "将两条数据均输出之后可以看到,它们都出自我们传入的数据,并且存在一部分重合。`tokenizer` 的 `stride` 参数可以设置重合部分的长度,这也可以帮助模型识别被分割开的两条数据;`overflow_to_sample` 的 `0` 则代表它们来自于第 `0` 条数据。\n", + "\n", + "基于以上信息,我们处理训练集的思路如下:\n", + "\n", + "1. 通过 `overflow_to_sample` 来获取原来的数据\n", + "2. 通过原数据的 `answers` 找到答案的起始位置\n", + "3. 
通过 `offset_mapping` 给出的映射关系在分词处理后的数据中找到答案的起始位置,分别记录在 `start_pos` 和 `end_pos` 中;如果没有找到答案(比如答案被截断了),那么答案的起始位置就被标记为 `[CLS]` 的位置。\n", + "\n", + "这样 `_process_train` 函数就呼之欲出了,我们调用 `train_dataset.map` 函数,并将 `batched` 参数设置为 `True` ,将所有数据批量地进行更新。有一点需要注意的是,**在处理过后数据量会增加**。" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'offset_mapping': [(0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (0, 0), (0, 1), (1, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 14), (14, 15), (15, 16), (16, 17), (17, 18), (18, 19), (19, 20), (20, 21), (21, 22), (22, 23), (23, 24), (24, 25), (25, 26), (26, 27), (27, 28), (28, 29), (29, 30), (30, 31), (31, 32), (32, 33), (33, 34), (34, 35), (35, 36), (36, 37), (37, 38), (38, 39), (39, 40), (40, 41), (41, 42), (42, 43), (43, 44), (44, 45), (45, 46), (46, 47), (47, 48), (48, 49), (49, 50), (50, 51), (51, 52), (52, 53), (53, 54), (54, 55), (55, 56), (56, 57), (57, 58), (58, 59), (59, 60), (60, 61), (61, 62), (62, 63), (63, 64), (64, 65), (65, 66), (66, 67), (67, 68), (68, 69), (69, 70), (70, 71), (71, 72), (72, 73), (73, 74), (74, 75), (75, 76), (76, 77), (77, 78), (78, 79), (79, 80), (80, 81), (81, 82), (82, 83), (83, 84), (84, 85), (85, 86), (86, 87), (87, 88), (88, 89), (89, 90), (90, 91), (91, 92), (92, 93), (93, 94), (94, 95), (95, 96), (96, 97), (97, 98), (98, 99), (99, 100), (100, 101), (101, 102), (102, 103), (103, 104), (104, 105), (105, 106), (106, 107), (107, 108), (108, 109), (109, 110), (110, 111), (111, 112), (112, 113), (113, 114), (114, 115), (115, 116), (116, 117), (117, 118), (118, 119), (119, 120), (120, 121), (121, 122), (122, 123), (123, 124), (124, 125), (125, 126), (126, 127), (127, 128), (128, 129), (129, 130), (130, 131), (131, 132), (132, 133), (133, 134), (134, 135), (135, 136), (136, 137), (137, 138), (138, 
139), (139, 140), (140, 141), (141, 142), (142, 143), (143, 144), (144, 145), (145, 146), (146, 147), (147, 148), (148, 149), (149, 150), (150, 151), (151, 152), (152, 153), (153, 154), (154, 155), (155, 156), (156, 157), (157, 158), (158, 159), (159, 160), (160, 161), (161, 162), (162, 163), (163, 164), (164, 165), (165, 166), (166, 167), (167, 168), (168, 169), (169, 170), (170, 171), (171, 172), (172, 173), (173, 174), (174, 175), (175, 176), (176, 177), (177, 178), (178, 179), (179, 180), (180, 181), (181, 182), (182, 183), (183, 184), (184, 185), (185, 186), (186, 187), (187, 188), (188, 189), (189, 190), (190, 191), (191, 192), (192, 193), (193, 194), (194, 195), (195, 196), (196, 197), (197, 198), (198, 199), (199, 200), (200, 201), (201, 202), (202, 203), (203, 204), (204, 205), (205, 206), (206, 207), (207, 208), (208, 209), (209, 210), (210, 211), (211, 212), (212, 213), (213, 214), (214, 215), (215, 216), (216, 217), (217, 218), (218, 219), (219, 220), (220, 221), (221, 222), (222, 223), (223, 224), (224, 225), (225, 226), (226, 227), (227, 228), (228, 229), (229, 230), (230, 231), (231, 232), (232, 233), (233, 234), (234, 235), (235, 236), (236, 237), (237, 238), (238, 239), (239, 240), (240, 241), (241, 242), (0, 0)], 'input_ids': [1, 1034, 1189, 734, 2003, 241, 284, 131, 553, 271, 28, 125, 280, 2, 131, 1773, 271, 1097, 373, 1427, 1427, 501, 88, 662, 1906, 4, 561, 125, 311, 1168, 311, 692, 46, 430, 4, 84, 2073, 14, 1264, 3967, 5, 1034, 1020, 1829, 268, 4, 373, 539, 8, 154, 5210, 4, 105, 167, 59, 69, 685, 12043, 539, 8, 883, 1020, 4, 29, 720, 95, 90, 427, 67, 262, 5, 384, 266, 14, 101, 59, 789, 416, 237, 12043, 1097, 373, 616, 37, 1519, 93, 61, 15, 4, 255, 535, 7, 1529, 619, 187, 4, 62, 154, 451, 149, 12043, 539, 8, 253, 223, 3679, 323, 523, 4, 535, 34, 87, 8, 203, 280, 1186, 340, 9, 1097, 373, 5, 262, 203, 623, 704, 12043, 84, 2073, 1137, 358, 334, 702, 5, 262, 203, 4, 334, 702, 405, 360, 653, 129, 178, 7, 568, 28, 15, 125, 280, 518, 9, 1179, 487, 
12043, 84, 2073, 1621, 1829, 1034, 1020, 4, 539, 8, 448, 91, 202, 466, 70, 262, 4, 638, 125, 280, 83, 299, 12043, 539, 8, 61, 45, 7, 1537, 176, 4, 84, 2073, 288, 39, 4, 889, 280, 14, 125, 280, 156, 538, 12043, 190, 889, 280, 71, 109, 124, 93, 292, 889, 46, 1248, 4, 518, 48, 883, 125, 12043, 539, 8, 268, 889, 280, 109, 270, 4, 1586, 845, 7, 669, 199, 5, 3964, 3740, 1084, 4, 255, 440, 616, 154, 72, 71, 109, 12043, 49, 61, 283, 3591, 34, 87, 297, 41, 9, 1993, 2602, 518, 52, 706, 109, 2], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'overflow_to_sample': 0, 'start_pos': 14, 'end_pos': 16}\n", + "处理后的训练集大小: 26198\n" + ] + } + ], + "source": [ + "max_length = 256\n", + "doc_stride = 128\n", + "def _process_train(data):\n", + "\n", + " contexts = [data[i][\"context\"] for i in range(len(data))]\n", + " questions = [data[i][\"question\"] for i in range(len(data))]\n", + "\n", + " tokenized_data_list = tokenizer(\n", + " questions,\n", + " contexts,\n", + " stride=doc_stride,\n", + " max_length=max_length,\n", + " padding=\"max_length\",\n", + " return_dict=False\n", + " )\n", + "\n", + " for i, tokenized_data in enumerate(tokenized_data_list):\n", + " # 获取 [CLS] 对应的位置\n", + " input_ids = tokenized_data[\"input_ids\"]\n", + " cls_index = 
input_ids.index(tokenizer.cls_token_id)\n", + "\n", + " # 在 tokenize 的过程中,汉字和 token 在位置上并非一一对应的\n", + " # 而 offset mapping 记录了每个 token 在原文中对应的起始位置\n", + " offsets = tokenized_data[\"offset_mapping\"]\n", + " # token_type_ids 记录了一条数据中哪些是问题,哪些是上下文\n", + " token_type_ids = tokenized_data[\"token_type_ids\"]\n", + "\n", + " # 一条数据可能因为长度过长而在 tokenized_data 中存在多个结果\n", + " # overflow_to_sample 表示了当前 tokenize_example 属于 data 中的哪一条数据\n", + " sample_index = tokenized_data[\"overflow_to_sample\"]\n", + " answers = data[sample_index][\"answers\"]\n", + "\n", + " # answers 和 answer_starts 均为长度为 1 的 list\n", + " # 我们可以计算出答案的结束位置\n", + " start_char = answers[\"answer_start\"][0]\n", + " end_char = start_char + len(answers[\"text\"][0])\n", + "\n", + " token_start_index = 0\n", + " while token_type_ids[token_start_index] != 1:\n", + " token_start_index += 1\n", + "\n", + " token_end_index = len(input_ids) - 1\n", + " while token_type_ids[token_end_index] != 1:\n", + " token_end_index -= 1\n", + " # 分词后一条数据的结尾一定是 [SEP],因此还需要减一\n", + " token_end_index -= 1\n", + "\n", + " if not (offsets[token_start_index][0] <= start_char and\n", + " offsets[token_end_index][1] >= end_char):\n", + " # 如果答案不在这条数据中,则将答案位置标记为 [CLS] 的位置\n", + " tokenized_data_list[i][\"start_pos\"] = cls_index\n", + " tokenized_data_list[i][\"end_pos\"] = cls_index\n", + " else:\n", + " # 否则,我们可以找到答案对应的 token 的起始位置,记录在 start_pos 和 end_pos 中\n", + " while token_start_index < len(offsets) and offsets[\n", + " token_start_index][0] <= start_char:\n", + " token_start_index += 1\n", + " tokenized_data_list[i][\"start_pos\"] = token_start_index - 1\n", + " while offsets[token_end_index][1] >= end_char:\n", + " token_end_index -= 1\n", + " tokenized_data_list[i][\"end_pos\"] = token_end_index + 1\n", + "\n", + " return tokenized_data_list\n", + "\n", + "train_dataset.map(_process_train, batched=True, num_workers=5)\n", + "print(train_dataset[0])\n", + "print(\"处理后的训练集大小:\", len(train_dataset))" + ] + }, + { + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "#### 2.2 处理验证集\n", + "\n", + "对于验证集的处理则简单得多,我们只需要保存原数据的 `id` 并将 `offset_mapping` 中不属于上下文的部分设置为 `None` 即可。" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from fastNLP.core import PaddleDataLoader\n", + "\n", + "train_dataloader = PaddleDataLoader(train_dataset, batch_size=32, shuffle=True)\n", + "val_dataloader = PaddleDataLoader(val_dataset, batch_size=16)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. 模型训练:自己定义评测用的 Metric 实现更加自由的任务评测\n", + "\n", + "#### 3.1 损失函数\n", + "\n", + "对于阅读理解任务,我们使用的是 `ErnieForQuestionAnswering` 模型。该模型在接受输入后会返回两个值:`start_logits` 和 `end_logits` ,大小均为 `(batch_size, sequence_length)`,反映了每条数据每个词语为答案起始位置的可能性,因此我们需要自定义一个损失函数来计算 `loss`。 `CrossEntropyLossForSquad` 会分别对答案起始位置的预测值和真实值计算交叉熵,最后返回其平均值作为最终的损失。" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "class CrossEntropyLossForSquad(paddle.nn.Layer):\n", + " def __init__(self):\n", + " super(CrossEntropyLossForSquad, self).__init__()\n", + "\n", + " def forward(self, start_logits, end_logits, start_pos, end_pos):\n", + " start_pos = paddle.unsqueeze(start_pos, axis=-1)\n", + " end_pos = paddle.unsqueeze(end_pos, axis=-1)\n", + " start_loss = paddle.nn.functional.softmax_with_cross_entropy(\n", + " logits=start_logits, label=start_pos)\n", + " start_loss = paddle.mean(start_loss)\n", + " end_loss = paddle.nn.functional.softmax_with_cross_entropy(\n", + " logits=end_logits, label=end_pos)\n", + " end_loss = paddle.mean(end_loss)\n", + "\n", + " loss = (start_loss + end_loss) / 2\n", + " return loss" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.2 定义模型\n", + "\n", + "模型的核心则是 `ErnieForQuestionAnswering` 的 `ernie-1.0-base-zh` 预训练模型,同时按照 `fastNLP` 的规定定义 `train_step` 和 `evaluate_step` 函数。这里 `evaluate_step` 函数并没有像文本分类那样直接返回该批次数据的评测结果,这一点我们将在下面为您讲解。" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + 
"text": [ + "\u001b[32m[2022-06-27 19:00:15,825] [ INFO]\u001b[0m - Already cached /remote-home/shxing/.paddlenlp/models/ernie-1.0-base-zh/ernie_v1_chn_base.pdparams\u001b[0m\n", + "W0627 19:00:15.831080 21543 gpu_context.cc:278] Please NOTE: device: 0, GPU Compute Capability: 7.5, Driver API Version: 11.2, Runtime API Version: 11.2\n", + "W0627 19:00:15.843276 21543 gpu_context.cc:306] device: 0, cuDNN Version: 8.1.\n" + ] + } + ], + "source": [ + "from paddlenlp.transformers import ErnieForQuestionAnswering\n", + "\n", + "class QAModel(paddle.nn.Layer):\n", + " def __init__(self, model_checkpoint):\n", + " super(QAModel, self).__init__()\n", + " self.model = ErnieForQuestionAnswering.from_pretrained(model_checkpoint)\n", + " self.loss_func = CrossEntropyLossForSquad()\n", + "\n", + " def forward(self, input_ids, token_type_ids):\n", + " start_logits, end_logits = self.model(input_ids, token_type_ids)\n", + " return start_logits, end_logits\n", + "\n", + " def train_step(self, input_ids, token_type_ids, start_pos, end_pos):\n", + " start_logits, end_logits = self(input_ids, token_type_ids)\n", + " loss = self.loss_func(start_logits, end_logits, start_pos, end_pos)\n", + " return {\"loss\": loss}\n", + "\n", + " def evaluate_step(self, input_ids, token_type_ids):\n", + " start_logits, end_logits = self(input_ids, token_type_ids)\n", + " return {\"start_logits\": start_logits, \"end_logits\": end_logits}\n", + "\n", + "model = QAModel(MODEL_NAME)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.3 自定义 Metric 进行数据的评估\n", + "\n", + "`paddlenlp` 为我们提供了评测 `SQuAD` 格式数据集的函数 `compute_prediction` 和 `squad_evaluate`:\n", + "- `compute_prediction` 函数要求传入原数据 `examples` 、处理后的数据 `features` 和 `features` 对应的结果 `predictions`(一个包含所有数据 `start_logits` 和 `end_logits` 的元组)\n", + "- `squad_evaluate` 要求传入原数据 `examples` 和预测结果 `all_predictions`(通常来自于 `compute_prediction`)\n", + "\n", + "在使用这两个函数的时候,我们需要向其中传入数据集,但显然根据 `fastNLP` 的设计,我们无法在 `evaluate_step` 
里实现这一过程,并且 `fastNLP` 也并没有提供计算 `F1` 和 `EM` 的 `Metric`,故我们需要自己定义用于评测的 `Metric`。\n", + "\n", + "在初始化之外,一个 `Metric` 还需要实现三个函数:\n", + "\n", + "1. `reset` - 该函数会在验证数据集的迭代之前被调用,用于清空数据;在我们自定义的 `Metric` 中,我们需要将 `all_start_logits` 和 `all_end_logits` 清空,重新收集每个 `batch` 的结果。\n", + "2. `update` - 该函数会在每个 `batch` 得到结果后被调用,用于更新 `Metric` 的状态;它的参数即为 `evaluate_step` 返回的内容。我们在这里将得到的 `start_logits` 和 `end_logits` 收集起来。\n", + "3. `get_metric` - 该函数会在数据集被迭代完毕后调用,用于计算评测的结果。现在我们有了整个验证集的 `all_start_logits` 和 `all_end_logits` ,将它们传入 `compute_prediction` 函数得到预测的结果,并继续使用 `squad_evaluate` 函数得到评测的结果。\n", + " - 注:`squad_evaluate` 函数会自己输出评测结果,为了不让其干扰 `fastNLP` 输出,这里我们使用 `contextlib.redirect_stdout(None)` 将函数的标准输出屏蔽掉。\n", + "\n", + "综上,`SquadEvaluateMetric` 实现的评估过程是:将验证集中所有数据的 `logits` 收集起来,然后统一传入 `compute_prediction` 和 `squad_evaluate` 中进行评估。值得一提的是,`paddlenlp.datasets.load_dataset` 返回的结果是一个 `MapDataset` 类型,其 `data` 成员为加载时的数据,`new_data` 为经过 `map` 函数处理后更新的数据,因此可以分别作为 `examples` 和 `features` 传入。" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP.core import Metric\n", + "from paddlenlp.metrics.squad import squad_evaluate, compute_prediction\n", + "import contextlib\n", + "\n", + "class SquadEvaluateMetric(Metric):\n", + " def __init__(self, examples, features, testing=False):\n", + " super(SquadEvaluateMetric, self).__init__(\"paddle\", False)\n", + " self.examples = examples\n", + " self.features = features\n", + " self.all_start_logits = []\n", + " self.all_end_logits = []\n", + " self.testing = testing\n", + "\n", + " def reset(self):\n", + " self.all_start_logits = []\n", + " self.all_end_logits = []\n", + "\n", + " def update(self, start_logits, end_logits):\n", + " for start, end in zip(start_logits, end_logits):\n", + " self.all_start_logits.append(start.numpy())\n", + " self.all_end_logits.append(end.numpy())\n", + "\n", + " def get_metric(self):\n", + " all_predictions, _, _ = compute_prediction(\n", + " 
self.examples, self.features[:len(self.all_start_logits)],\n", + " (self.all_start_logits, self.all_end_logits),\n", + " False, 20, 30\n", + " )\n", + " with contextlib.redirect_stdout(None):\n", + " result = squad_evaluate(\n", + " examples=self.examples,\n", + " preds=all_predictions,\n", + " is_whitespace_splited=False\n", + " )\n", + "\n", + " if self.testing:\n", + " self.print_predictions(all_predictions)\n", + " return result\n", + "\n", + " def print_predictions(self, preds):\n", + " for i, data in enumerate(self.examples):\n", + " if i >= 5:\n", + " break\n", + " print()\n", + " print(\"原文:\", data[\"context\"])\n", + " print(\"问题:\", data[\"question\"], \\\n", + " \"答案:\", preds[data[\"id\"]], \\\n", + " \"正确答案:\", data[\"answers\"][\"text\"])\n", + "\n", + "metric = SquadEvaluateMetric(\n", + " val_dataloader.dataset.data,\n", + " val_dataloader.dataset.new_data,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.4 训练\n", + "\n", + "至此所有的准备工作已经完成,可以使用 `Trainer` 进行训练了。学习率我们依旧采用线性预热策略 `LinearDecayWithWarmup`,优化器为 `AdamW`;回调模块我们选择 `LRSchedCallback` 更新学习率和 `LoadBestModelCallback` 监视评测结果的 `f1` 分数。初始化好 `Trainer` 之后,就将训练的过程交给 `fastNLP` 吧。" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[19:04:54] INFO Running evaluator sanity check for 2 batches. trainer.py:631\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[19:04:54]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=367046;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=96810;file://../fastNLP/core/controllers/trainer.py#631\u001b\\\u001b[2m631\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:100 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m100\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"exact#squad\": 49.25899788285109,\n", + " \"f1#squad\": 66.55559127349602,\n", + " \"total#squad\": 1417,\n", + " \"HasAns_exact#squad\": 49.25899788285109,\n", + " \"HasAns_f1#squad\": 66.55559127349602,\n", + " \"HasAns_total#squad\": 1417\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m49.25899788285109\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m66.55559127349602\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m49.25899788285109\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m66.55559127349602\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:200 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m200\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"exact#squad\": 57.37473535638673,\n", + " \"f1#squad\": 70.93036525200617,\n", + " \"total#squad\": 1417,\n", + " \"HasAns_exact#squad\": 57.37473535638673,\n", + " \"HasAns_f1#squad\": 70.93036525200617,\n", + " \"HasAns_total#squad\": 1417\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m57.37473535638673\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m70.93036525200617\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m57.37473535638673\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m70.93036525200617\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:300 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m300\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"exact#squad\": 63.86732533521524,\n", + " \"f1#squad\": 78.62546663568186,\n", + " \"total#squad\": 1417,\n", + " \"HasAns_exact#squad\": 63.86732533521524,\n", + " \"HasAns_f1#squad\": 78.62546663568186,\n", + " \"HasAns_total#squad\": 1417\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m63.86732533521524\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m78.62546663568186\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m63.86732533521524\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m78.62546663568186\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:400 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m400\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"exact#squad\": 64.92589978828511,\n", + " \"f1#squad\": 79.36746074079691,\n", + " \"total#squad\": 1417,\n", + " \"HasAns_exact#squad\": 64.92589978828511,\n", + " \"HasAns_f1#squad\": 79.36746074079691,\n", + " \"HasAns_total#squad\": 1417\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m64.92589978828511\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m79.36746074079691\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m64.92589978828511\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m79.36746074079691\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:500 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m500\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"exact#squad\": 65.70218772053634,\n", + " \"f1#squad\": 80.33295482054824,\n", + " \"total#squad\": 1417,\n", + " \"HasAns_exact#squad\": 65.70218772053634,\n", + " \"HasAns_f1#squad\": 80.33295482054824,\n", + " \"HasAns_total#squad\": 1417\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:600 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m600\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"exact#squad\": 65.41990119971771,\n", + " \"f1#squad\": 79.7483487059053,\n", + " \"total#squad\": 1417,\n", + " \"HasAns_exact#squad\": 65.41990119971771,\n", + " \"HasAns_f1#squad\": 79.7483487059053,\n", + " \"HasAns_total#squad\": 1417\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m65.41990119971771\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m79.7483487059053\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m65.41990119971771\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m79.7483487059053\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:700 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m700\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"exact#squad\": 66.61961891319689,\n", + " \"f1#squad\": 80.32432238994133,\n", + " \"total#squad\": 1417,\n", + " \"HasAns_exact#squad\": 66.61961891319689,\n", + " \"HasAns_f1#squad\": 80.32432238994133,\n", + " \"HasAns_total#squad\": 1417\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m66.61961891319689\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m80.32432238994133\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m66.61961891319689\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m80.32432238994133\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:800 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m800\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"exact#squad\": 65.84333098094567,\n", + " \"f1#squad\": 79.23169801265415,\n", + " \"total#squad\": 1417,\n", + " \"HasAns_exact#squad\": 65.84333098094567,\n", + " \"HasAns_f1#squad\": 79.23169801265415,\n", + " \"HasAns_total#squad\": 1417\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m65.84333098094567\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m79.23169801265415\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m65.84333098094567\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m79.23169801265415\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
[19:20:28] INFO Loading best model from fnlp-ernie-squad/ load_best_model_callback.py:111\n", + " 2022-06-27-19_00_15_388554/best_so_far \n", + " with f1#squad: 80.33295482054824... \n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[19:20:28]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Loading best model from fnlp-ernie-squad/ \u001b]8;id=163935;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=31503;file://../fastNLP/core/callbacks/load_best_model_callback.py#111\u001b\\\u001b[2m111\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m \u001b[1;36m2022\u001b[0m-\u001b[1;36m06\u001b[0m-\u001b[1;36m27\u001b[0m-19_00_15_388554/best_so_far \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m with f1#squad: \u001b[1;36m80.33295482054824\u001b[0m\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
INFO Deleting fnlp-ernie-squad/2022-06-27-19_0 load_best_model_callback.py:131\n", + " 0_15_388554/best_so_far... \n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Deleting fnlp-ernie-squad/\u001b[1;36m2022\u001b[0m-\u001b[1;36m06\u001b[0m-\u001b[1;36m27\u001b[0m-19_0 \u001b]8;id=560859;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=573263;file://../fastNLP/core/callbacks/load_best_model_callback.py#131\u001b\\\u001b[2m131\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m 0_15_388554/best_so_far\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from fastNLP import Trainer, LRSchedCallback, LoadBestModelCallback\n", + "from paddlenlp.transformers import LinearDecayWithWarmup\n", + "\n", + "n_epochs = 1\n", + "num_training_steps = len(train_dataloader) * n_epochs\n", + "lr_scheduler = LinearDecayWithWarmup(3e-5, num_training_steps, 0.1)\n", + "optimizer = paddle.optimizer.AdamW(\n", + " learning_rate=lr_scheduler,\n", + " parameters=model.parameters(),\n", + ")\n", + "callbacks=[\n", + " LRSchedCallback(lr_scheduler, step_on=\"batch\"),\n", + " LoadBestModelCallback(\"f1#squad\", larger_better=True, save_folder=\"fnlp-ernie-squad\")\n", + "]\n", + "trainer = Trainer(\n", + " model=model,\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=val_dataloader,\n", + " device=1,\n", + " optimizers=optimizer,\n", + " n_epochs=n_epochs,\n", + " callbacks=callbacks,\n", + " evaluate_every=100,\n", + " metrics={\"squad\": metric},\n", + ")\n", + "trainer.run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.5 测试\n", + "\n", + "最后,我们可以使用 `Evaluator` 查看我们训练的结果。我们在之前为 `SquadEvaluateMetric` 设置了 `testing` 参数来在测试阶段进行输出,可以看到,训练的结果还是比较不错的。" + ] + }, + { + "cell_type": "code", + 
"execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
原文: 爬行垫根据中间材料的不同可以分为:XPE爬行垫、EPE爬行垫、EVA爬行垫、PVC爬行垫;其中XPE爬\n", + "行垫、EPE爬行垫都属于PE材料加保鲜膜复合而成,都是无异味的环保材料,但是XPE爬行垫是品质较好的爬\n", + "行垫,韩国进口爬行垫都是这种爬行垫,而EPE爬行垫是国内厂家为了减低成本,使用EPE(珍珠棉)作为原料生\n", + "产的一款爬行垫,该材料弹性差,易碎,开孔发泡防水性弱。EVA爬行垫、PVC爬行垫是用EVA或PVC作为原材料\n", + "与保鲜膜复合的而成的爬行垫,或者把图案转印在原材料上,这两款爬行垫通常有异味,如果是图案转印的爬\n", + "行垫,油墨外露容易脱落。 \n", + "当时我儿子爬的时候,我们也买了垫子,但是始终有味。最后就没用了,铺的就的薄毯子让他爬。\n", + "\n" + ], + "text/plain": [ + "原文: 爬行垫根据中间材料的不同可以分为:XPE爬行垫、EPE爬行垫、EVA爬行垫、PVC爬行垫;其中XPE爬\n", + "行垫、EPE爬行垫都属于PE材料加保鲜膜复合而成,都是无异味的环保材料,但是XPE爬行垫是品质较好的爬\n", + "行垫,韩国进口爬行垫都是这种爬行垫,而EPE爬行垫是国内厂家为了减低成本,使用EPE(珍珠棉)作为原料生\n", + "产的一款爬行垫,该材料弹性差,易碎,开孔发泡防水性弱。EVA爬行垫、PVC爬行垫是用EVA或PVC作为原材料\n", + "与保鲜膜复合的而成的爬行垫,或者把图案转印在原材料上,这两款爬行垫通常有异味,如果是图案转印的爬\n", + "行垫,油墨外露容易脱落。 \n", + "当时我儿子爬的时候,我们也买了垫子,但是始终有味。最后就没用了,铺的就的薄毯子让他爬。\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
问题: 爬行垫什么材质的好 答案: EPE(珍珠棉 正确答案: ['XPE']\n", + "\n" + ], + "text/plain": [ + "问题: 爬行垫什么材质的好 答案: EPE(珍珠棉 正确答案: ['XPE']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
原文: 真实情况是160-162。她平时谎报的168是因为不离脚穿高水台恨天高(15厘米) 图1她穿着高水台恨\n", + "天高和刘亦菲一样高,(刘亦菲对外报身高172)范冰冰礼服下厚厚的高水台暴露了她的心机,对比一下两者的\n", + "鞋子吧 图2 穿着高水台恨天高才和刘德华谢霆锋持平,如果她真的有168,那么加上鞋高,刘和谢都要有180?\n", + "明显是不可能的。所以刘德华对外报的身高174减去10-15厘米才是范冰冰的真实身高 图3,范冰冰有一次脱\n", + "鞋上场,这个最说明问题了,看看她的身体比例吧。还有目测一下她手上鞋子的鞋跟有多高多厚吧,至少超过\n", + "10厘米。\n", + "\n" + ], + "text/plain": [ + "原文: 真实情况是160-162。她平时谎报的168是因为不离脚穿高水台恨天高(15厘米) 图1她穿着高水台恨\n", + "天高和刘亦菲一样高,(刘亦菲对外报身高172)范冰冰礼服下厚厚的高水台暴露了她的心机,对比一下两者的\n", + "鞋子吧 图2 穿着高水台恨天高才和刘德华谢霆锋持平,如果她真的有168,那么加上鞋高,刘和谢都要有180?\n", + "明显是不可能的。所以刘德华对外报的身高174减去10-15厘米才是范冰冰的真实身高 图3,范冰冰有一次脱\n", + "鞋上场,这个最说明问题了,看看她的身体比例吧。还有目测一下她手上鞋子的鞋跟有多高多厚吧,至少超过\n", + "10厘米。\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
问题: 范冰冰多高真实身高 答案: 160-162 正确答案: ['160-162']\n", + "\n" + ], + "text/plain": [ + "问题: 范冰冰多高真实身高 答案: 160-162 正确答案: ['160-162']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
原文: 防水作为目前高端手机的标配,特别是苹果也支持防水之后,国产大多数高端旗舰手机都已经支持防\n", + "水。虽然我们真的不会故意把手机放入水中,但是有了防水之后,用户心里会多一重安全感。那么近日最为\n", + "火热的小米6防水吗?小米6的防水级别又是多少呢? 小编查询了很多资料发现,小米6确实是防水的,但是为\n", + "了保持低调,同时为了不被别人说防水等级不够,很多资料都没有标注小米是否防水。根据评测资料显示,小\n", + "米6是支持IP68级的防水,是绝对能够满足日常生活中的防水需求的。\n", + "\n" + ], + "text/plain": [ + "原文: 防水作为目前高端手机的标配,特别是苹果也支持防水之后,国产大多数高端旗舰手机都已经支持防\n", + "水。虽然我们真的不会故意把手机放入水中,但是有了防水之后,用户心里会多一重安全感。那么近日最为\n", + "火热的小米6防水吗?小米6的防水级别又是多少呢? 小编查询了很多资料发现,小米6确实是防水的,但是为\n", + "了保持低调,同时为了不被别人说防水等级不够,很多资料都没有标注小米是否防水。根据评测资料显示,小\n", + "米6是支持IP68级的防水,是绝对能够满足日常生活中的防水需求的。\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
问题: 小米6防水等级 答案: IP68级 正确答案: ['IP68级']\n", + "\n" + ], + "text/plain": [ + "问题: 小米6防水等级 答案: IP68级 正确答案: ['IP68级']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
原文: 这位朋友你好,女性出现妊娠反应一般是从6-12周左右,也就是女性怀孕1个多月就会开始出现反应,\n", + "第3个月的时候,妊辰反应基本结束。 而大部分女性怀孕初期都会出现恶心、呕吐的感觉,这些症状都是因\n", + "人而异的,除非恶心、呕吐的非常厉害,才需要就医,否则这些都是刚怀孕的的正常症状。1-3个月的时候可\n", + "以观察一下自己的皮肤,一般女性怀孕初期可能会产生皮肤色素沉淀或是腹壁产生妊娠纹,特别是在怀孕的\n", + "后期更加明显。 还有很多女性怀孕初期会出现疲倦、嗜睡的情况。怀孕三个月的时候,膀胱会受到日益胀\n", + "大的子宫的压迫,容量会变小,所以怀孕期间也会有尿频的现象出现。月经停止也是刚怀孕最容易出现的症\n", + "状,只要是平时月经正常的女性,在性行为后超过正常经期两周,就有可能是怀孕了。 如果你想判断自己是\n", + "否怀孕,可以看看自己有没有这些反应。当然这也只是多数人的怀孕表现,也有部分女性怀孕表现并不完全\n", + "是这样,如果你无法确定自己是否怀孕,最好去医院检查一下。\n", + "\n" + ], + "text/plain": [ + "原文: 这位朋友你好,女性出现妊娠反应一般是从6-12周左右,也就是女性怀孕1个多月就会开始出现反应,\n", + "第3个月的时候,妊辰反应基本结束。 而大部分女性怀孕初期都会出现恶心、呕吐的感觉,这些症状都是因\n", + "人而异的,除非恶心、呕吐的非常厉害,才需要就医,否则这些都是刚怀孕的的正常症状。1-3个月的时候可\n", + "以观察一下自己的皮肤,一般女性怀孕初期可能会产生皮肤色素沉淀或是腹壁产生妊娠纹,特别是在怀孕的\n", + "后期更加明显。 还有很多女性怀孕初期会出现疲倦、嗜睡的情况。怀孕三个月的时候,膀胱会受到日益胀\n", + "大的子宫的压迫,容量会变小,所以怀孕期间也会有尿频的现象出现。月经停止也是刚怀孕最容易出现的症\n", + "状,只要是平时月经正常的女性,在性行为后超过正常经期两周,就有可能是怀孕了。 如果你想判断自己是\n", + "否怀孕,可以看看自己有没有这些反应。当然这也只是多数人的怀孕表现,也有部分女性怀孕表现并不完全\n", + "是这样,如果你无法确定自己是否怀孕,最好去医院检查一下。\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
问题: 怀孕多久会有反应 答案: 6-12周左右 正确答案: ['6-12周左右', '6-12周', '1个多月']\n", + "\n" + ], + "text/plain": [ + "问题: 怀孕多久会有反应 答案: 6-12周左右 正确答案: ['6-12周左右', '6-12周', '1个多月']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
原文: 【东奥会计在线——中级会计职称频道推荐】根据《关于提高科技型中小企业研究开发费用税前加计\n", + "扣除比例的通知》的规定,研发费加计扣除比例提高到75%。|财政部、国家税务总局、科技部发布《关于提\n", + "高科技型中小企业研究开发费用税前加计扣除比例的通知》。|通知称,为进一步激励中小企业加大研发投\n", + "入,支持科技创新,就提高科技型中小企业研究开发费用(以下简称研发费用)税前加计扣除比例有关问题发\n", + "布通知。|通知明确,科技型中小企业开展研发活动中实际发生的研发费用,未形成无形资产计入当期损益的\n", + ",在按规定据实扣除的基础上,在2017年1月1日至2019年12月31日期间,再按照实际发生额的75%在税前加计\n", + "扣除;形成无形资产的,在上述期间按照无形资产成本的175%在税前摊销。|科技型中小企业享受研发费用税\n", + "前加计扣除政策的其他政策口径按照《财政部国家税务总局科技部关于完善研究开发费用税前加计扣除政\n", + "策的通知》(财税〔2015〕119号)规定执行。|科技型中小企业条件和管理办法由科技部、财政部和国家税\n", + "务总局另行发布。科技、财政和税务部门应建立信息共享机制,及时共享科技型中小企业的相关信息,加强\n", + "协调配合,保障优惠政策落实到位。|上一篇文章:关于2016年度企业研究开发费用税前加计扣除政策企业所\n", + "得税纳税申报问题的公告 下一篇文章:关于提高科技型中小企业研究开发费用税前加计扣除比例的通知\n", + "\n" + ], + "text/plain": [ + "原文: 【东奥会计在线——中级会计职称频道推荐】根据《关于提高科技型中小企业研究开发费用税前加计\n", + "扣除比例的通知》的规定,研发费加计扣除比例提高到75%。|财政部、国家税务总局、科技部发布《关于提\n", + "高科技型中小企业研究开发费用税前加计扣除比例的通知》。|通知称,为进一步激励中小企业加大研发投\n", + "入,支持科技创新,就提高科技型中小企业研究开发费用(以下简称研发费用)税前加计扣除比例有关问题发\n", + "布通知。|通知明确,科技型中小企业开展研发活动中实际发生的研发费用,未形成无形资产计入当期损益的\n", + ",在按规定据实扣除的基础上,在2017年1月1日至2019年12月31日期间,再按照实际发生额的75%在税前加计\n", + "扣除;形成无形资产的,在上述期间按照无形资产成本的175%在税前摊销。|科技型中小企业享受研发费用税\n", + "前加计扣除政策的其他政策口径按照《财政部国家税务总局科技部关于完善研究开发费用税前加计扣除政\n", + "策的通知》(财税〔2015〕119号)规定执行。|科技型中小企业条件和管理办法由科技部、财政部和国家税\n", + "务总局另行发布。科技、财政和税务部门应建立信息共享机制,及时共享科技型中小企业的相关信息,加强\n", + "协调配合,保障优惠政策落实到位。|上一篇文章:关于2016年度企业研究开发费用税前加计扣除政策企业所\n", + "得税纳税申报问题的公告 下一篇文章:关于提高科技型中小企业研究开发费用税前加计扣除比例的通知\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
问题: 研发费用加计扣除比例 答案: 75% 正确答案: ['75%']\n", + "\n" + ], + "text/plain": [ + "问题: 研发费用加计扣除比例 答案: 75% 正确答案: ['75%']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " 'exact#squad': 65.70218772053634,\n", + " 'f1#squad': 80.33295482054824,\n", + " 'total#squad': 1417,\n", + " 'HasAns_exact#squad': 65.70218772053634,\n", + " 'HasAns_f1#squad': 80.33295482054824,\n", + " 'HasAns_total#squad': 1417\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[32m'exact#squad'\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n", + " \u001b[32m'f1#squad'\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n", + " \u001b[32m'total#squad'\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[32m'HasAns_exact#squad'\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n", + " \u001b[32m'HasAns_f1#squad'\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n", + " \u001b[32m'HasAns_total#squad'\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from fastNLP import Evaluator\n", + "evaluator = Evaluator(\n", + " model=model,\n", + " dataloaders=val_dataloader,\n", + " device=1,\n", + " metrics={\n", + " \"squad\": SquadEvaluateMetric(\n", + " val_dataloader.dataset.data,\n", + " val_dataloader.dataset.new_data,\n", + " testing=True,\n", + " ),\n", + " },\n", + ")\n", + "result = evaluator.run()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.7.13 ('fnlp-paddle')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "31f2d9d3efc23c441973d7c4273acfea8b132b6a578f002629b6b44b8f65e720" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/source/tutorials/figures/E1-fig-glue-benchmark.png b/docs/source/tutorials/figures/E1-fig-glue-benchmark.png new file mode 100644 
index 0000000000000000000000000000000000000000..515db700127d0d2f4dc7b139a29bf1749419f2c3 GIT binary patch literal 158817 zcmcG#WpEuyuqJ9TGc#Ds%*@PWF<2H^Y%w$Ih?$wmVzii
{%+)wbF~--jxD7_0H_*5lDgQCt!RL>=MhC4NCnZk6_L|-%ru< z!~1{pa&BQJ_=@Z&ZfOl_@a{`O()2wF{^9S^LT4}ln-8?wj}m{c+_iIEYmSmc@%^@p zTcXnoYCO5ENc>^?IrR3>?Jly-_Xfh+6`=Sh8UrsHSPo(+AQvYht@(9APJ<&gOfwX0 z%$>>h-8w;a4JWzbHspS_iejBsLPL(<*oenUug?b}CS~>6h<9@tb3nGcE|30`MpJjw zjrSqw)e6=?M#?Pw8vifFiV~NR3Fqqn7g8;fw4x%CgoMQZDXC|+SP=yc?W@ASBV=+aN~7D%^z-M>zS8!) zYTW-K?%bPh38UluOCS30(|G>v>HjD9=Ys@7A40Huc1mlTQQ4SFBtHQtRz?QC$rCo( zSqr@NA!7!&LCPXk&Uoc=z=fpyh#0(qvbj`5rThd|JDCv(4&3~dzciQ|-jgXWxV1{u z)Ge=$iy35msiF*dx7sM4c(~K7Y{jYIZ9vFkZ{Oe1Qle@LvKEfeFQ~|oV*+v&0T9Ls z{avnUv>JB9<+CRsVoBGZf3x}LS(LN(fCXNQ10CtgadAWIh3f!;ai*?!%gdpa>#kD` zx~qLb9PU&il<`mCzu0(Bkv~{iYOH{i`FCWEM+|gzzwtc#kTnc zy!YePa}IGOdarFA`H~O3bzaCKOs3>pCs&1~Mo6M=##8MO-IV7VDV=kpxXMrMY+oe0 z2ctiAx#1F*CW2z J07^f7`R{-iKO7|1*l-y;GbF2kt H9N|IN1|(wq^0y*d$Zr 4JprY6V$JAKxZnFh(6Z1q+a#0WV zFv8GR$__%34SI<8k3Hct>_Lr)x(b0i`QL|5eo1nsXQ%@UF*Bztt2(*vNTi6Lv`{{k zfjwIMAF|2m)xdv1kc*H%Lu1PE25!l5cI_cfq^!+@j&MrsN=r1iL4mTWJFDBjH}4RF z%IooRgKlUl04vG`jW@KPG)O6?5U>DkTi3mYeeDP154$DB&5rL}3Y)j$-|M(oLpB8A z9qRNjfUg&S_{H$P45Mr`J6zN0zM0e;k%d33gl fm zWAJL0bNeMJV9e_gE<*fyd*LYGuI2TMU#+bU6OOXt#n$W9Z1WKwwVY}GA5E A){KfX#)5Kg4cBOiD$;;VkzOsNZvs+ zp;SjKK9L>7syA+c0y~|Nx*5-e`S}v<)$XH;=XyCajh33wOQGhL*Om?&wOJG{Jx`94 z9@HFu4Mb-u9+dw2Kv}2L^!2Kj>*Fm$V`Nzvu>Wl8vRnY0$CylFat$W?6c^o?zsAv$ z=$5bxd=-)4&$V!a(nvdpgzbZHG;iPM{9~(It-6MZW_-i#z}55>54zWK@5h7cwDxyc zV4unpwVHD(2?NpbY?wTxU8{YW%y*5;0i_PGvwZp`m(GE9Qq`s5LmL`2I8k~y>%o`w zlQ$4-6e?GP+HsDsgFs* ~{r$Gyl znX*#MIl#CaEDDumqExBERLzgO+JN`rs}|bcir;~sZrJlaC>2A|b5p%hV!jDOdw%GQ zARRop`E|N%kW@DUGh8N@lBqHkV+4WU*VR6(f~?au5Y2Jr&qAZX^Y~`4>Y^8N6?fo+ z4$ mA>LoXlY-aHzfRrrJcaHH(T$)CaK3sH+ZAP;YF1ffsr}$#@cmZTE zs=qMh3U=k2M6OhUFS|i(n3xDe#j#Q~PTNuUD47gQmyjv}T#T4TumnC}o}mIo_yfL) z!fP}(fky|YGaFKS2>Ke9ZL4X${^F4zR@1Ue*6r~t_p#q?`n$+vpId>eNV><8@@&6N z0@Z{Hq%B3Uu}U2FLhjRpN?*(Js2eu4L%kjZH7c0-xF!WzEIM&HPlx%yIn*r+bx;qZ z9Qhz9YFr{ttn+EXWxi?D9kjcXxnGD#MrlzFWx7X<-faJFQYZ7ushZr9rf3=K2cbn{ z(&(k8L`ETX!&vE6J04k<2+3 {Btzhg&)T$x*~@vQd?0ym)8z@`*K*gH2HTe 
zGl4zQPzd|_-^D1>ftU1iUTsW?(rYld*5qG1%2 DQbY;fysYuOUN|MDVdD5MDb_QYd8RmN_aIF=$l z{2h|Y$SfjH#ec*b+uL9XZ$A83RKJ}$_^t$`g$3Tl=CKG#4BH7l{T|Y{g65BcCqhB} zqbByuI4p}PGAUf62mp~?>6L)nXsx9%#bpF@LX@^wD1(yy!rvwe4g}_*)@3sQAIUT2 zhH2=?*)#?x!8g-dSxOq>#i+}~=)#h!j8XDxpm!=i7=NUOw|P5#oNyqscgC&Q@PrRN zah5ZfVI4Og8P(3`LVyohBev~QG5F-#_VOhnI)%NEP>$_pZu=t8w$mp3dhg88x`0rK z81W!!qL*673|@YdYpm^8z`fFgakl#ok<5%$ADMKH?f1ZX&H&sb328JKU^@}8wwYSp zmy>W%toNjQr(T}|l0C#wNe?Hk&J#2M+3n89F3W*PLFe8?BSHm7e=zms?k~UO3GHvk z?!^|}L8!X0c>9wV^uvV?O0?v1#0!t2-XesELBiBk-+(_Smn{NchIA|e_D&(JQ{Z;C zGx^V0Ma?B6iGYBBHUCHnQ@a%LmO47=>;NX!mq8x&TqJU!F-tVcWf9pL*P%agBIUxJ zRe(mNc35`S61qN0bo9JY8g{rYe}7g!)6lpxxd>M&ln**hzeq{}R2a?}EF6gVl}O5W z#{^Eq9y&NvF3d3TvIpDMwW<8Xlg!Rtw`5Y!Gh5k6oRmhDo$ITm (puC~0H- z07*h#rh*KmeuY-Q00v =QXx8y64YXFA_a?N$S1ddO<(-i)6Vtv!x*u;V4-Z t z*yXX)Lcz-5-&1{VIOs2p&AP|9#$E~g(vx-|Gn48x_@Mz1G7YfP)!|2tDr4(zMsPt+ zoIg@~GJXIeG<#i;IU9J9aDbY|rT=4J*#Fp9D60KODBLRxk{T7e=z6uw_G8@w6)8DY z7GYmi1kt6!Z*9M%$L%KO68$X@?+ 95WSI4Gj$(!6b4+GEKAJEeHV_HV?g0Q#SB^@6~lxuboipx_qc%G06C# zQ)7dbv?L(XJsMv@->!)8%ZPCl?|q^uoCdet8!l?D7rC)F)=(ifTnX~wHTV~{MiS*y zmd5aBYOEd|zpKs>)ggGNW|zdQ7@Zz3n}0uyi25306y>4;(A_oudxsYo_-BVN%jp#~ z>=uj#CNOxZ;d8*#N6qV TLPPIaX~Q+h!{B zwBajhqZdZt!8%Q(bGehTMT|7Sb~C)#k;uS~5=PXTOcI_>zx~V(xc{+2?uhZJ?kcXQ zYANOVU-sCr7$@nH0AjWX)f`yo?oM4BIqeGgF*xbmW*6{~tRK_ M@=Nl(&SBQ2r|nf8weZY|jXCVSRjG-epqMh$0iQ`0UIM(Aj8H zBL`Qzde?e!2FsOLO9UX&lbpn^k{KR1+z_k2Ym+sd{Pp2-P5XE5!9B`RuI@K(#IZM+ zgWqKkH0VlYo6 5lW&gz!t8gsQI+J4<2KOV 6=v zzCnbEA-TIB>zqkTsR0m3Xy$6TkI~86d6M|o;gCSo30V%cO#E~{iOprt= 5>d+ s1eir z0eJu-Yez2kVDpR3=JlC!V&FLUHD($6EQi&m;~GQ1RI(n8#--Znxl#2eKAV$)IIK$j zZ}7mOA #STN#sA #kA3r_9bFyUhSSLkbou~H6ItOG1|N)BQdZs zv|u$>QeW%wnXzaWt9mcsD05V2YcEe(Fgp4I#_S(uW%DL0Y?T&MkewY^(iSX3NH)g@ z k{T`|vsrS3I%-z+w{Bt?ixd}OpDaVDdL)jx z8HwXW=ER_ncNPRfl5mM0Kv74^!-+B||FK ME#G*%6`Wln$ui;pWwZ6eC2Yg?F=#KmUh320-m2{Y11ozrp{tEh9oHUY#IsRu9l ztY5gb{&FGjsi!)}V?Vdi7p^!PxngY$ZKW=O%2JKQmAR@(=rkrHE+<<<>0k-d(($>s z-DENbIaqfGXZkm`3dB_DAnlyKnA5h&qD{resd&e7bKM=_@dHx)Tr=+7i#^`DJC3wV 
zlVUOAf_s5Cw2($CguK4vw-*7QAfhu|^@#8x@?&lN0l+7-E)jC(#|1j6DV5&G3nlJy z4mixbOWfhIQpkl*VSr6FyGO!dL*@V{3nkbYmj_vFYK&;p@~ -QZ(_bHhVidCFUj5@^> z2AsVtI+B~K#28luI#E*bLXF5{W5Yx_)4Ya~>t(5^h?}+*Wnj7avw2oAC{p*a$;cRw zr9?IK3hl-pN>PWcUR*7|+vI!;apz7A19M9E3Uwgtt3JA4aJ!Z7T2a$2;)e{Ypqf8s z58JP3lQSw^jp&qe#7ib_sS!kNVP39Jk T=Qamh9g+JWur4B7|SNMIf}lQZYhcpXe1=M$peyz*4qH&V7x zz{9C5WA>U)cwhYd#|h;jJN5~TO!`~Bz_%?-CHVQd^{qr^SLY9gc7J_g{Sid<@}*q0 zT>MRg34%Sx=2tE_tTN94l164|7>}@_6jw3|a#9i=aQy@4Ev>NZ?Z!qgL_!NC;^R^% z!j71{giQVl`OP~UKe~1o?AOt=A0*6cc!YkaTI#iuS8`BQq3*B%Nk(Ff`_YY++e}-f zA!c=ak@gPc0vZ@zDA| ZN6!9WZ>$1+`k}`)kpkFi=aykU}ch z?R;MiEe&7_;b3ROUHkHhVP%YwPj{kn$XFA~)!w5nM{1Ee{$7Cqc7J8i)lMqGGssy5 z3oCZ^c65L#sb!>hmseuncQ9*IJBfhUWB#?$V|%m#eyxUl-i(TVB7HGHx}>@8$1@=? z|J=N %$xvqnJ{UdXvN)^x z1Sh0y$ERD7- gnt3jviTeQC!oJ-b^T2Sf8YF7{P#j>l|FZv!KR~0-Q(l(8ODEg`7gmsA!;4& zPI9R%;u#}4A7vohiv~VHZEbD;bA+(;Gw1K|`J}r4e+}?4|5v#0zlZ97liW9wi~t?3 z*?oC~&=;}v-9B2Iv_~{Q2fg@O%@dpy*w?iTl^Y4yKWkaY0)5 zJ_@7QiJY|W3?z#pJU*{rONYyC8SC4df9*rDSp5YUUpO!!nEBqw&bxmpBvug@IS$#? z0-@;E0%j1_)KImY|M~X9A97D7uRss%7vVeqzUL~~c<@hAbB0k=<|a;)g`wz@9YZ8u zUf!kU)!y25CC{!~=+}+%y0Sl+pFs`FcLbyA
dzV%6gA;S$4?P{kGrU3q1GzP2-<)#KqwfO4%Vj z{C_K6qCa5fs_r0~Z`@kp&$6LYp}0C@6d_{FD5b~Ml(E8u^!%zK?B}ODp2H?c ^Oo7+?q3$QU^_IL>ZB!7r+4is4gOf|aW zq-Y)oIBmdQ!F}#ASX^Dkr@guDNr&>hg9{X_sTneC8B%z*fl6_Y7%)CT>l-5RxR{iw z;023YKD+tIT_cH9D+jva1wGp`@@+Fk$c0=V{C)Vvu^KFKsYnH2%0Rw*!UXjCe8z8S z961nGFhD-xF^n{g-k9+Ew@ku=L*38wy$kEo&{`Y}Y8A3K)w#IV_dBsi4SDI21jD~# z387%mbH@N=BSclzc2A*~4^DsW72J5W=7BkB&>YEmpJrKVAd5EawU$=ME!$=QJdX`> z6-`n#HAVt<9RN~i1ls-1!X6P-(R7LdSGyK`+4BysB@9KVR7pu@2=Q1|MrO)Zri*X3 z#XUi|-pX8@#^?s)E$JC72|zvUqb^HR(y~DQeZ~j@qMemlP32D-;&;*_1eO!!U!)qZ z$N9=CD!Zo7ooR%^Hj1iMg~`yLA7vf(pEd~<-^7os=C{OI#0vT0gf-_X5S^(TmU|~7 zZ19TPV?b_+)MHz$TT@d?9fJ%2cVebkj5IXp24&f1v?ztGVZ81PsT5iAb6O%Q2xQqs zeqY<3bgA1+?m@##x{x%VC&X6D_vKPGR3plVoD&&)vR*n#4`3jbZn`=sa+PZ#j1CPA zy`3EWf{&}G;W#l*U|n#HTB2K kt=l6*9VoTn8 zQr)8UK$WalRKS9bULAT(?wGY9P5*3|Z-$*km>yW>V4o!i=~hc)$218%j@8bYlu!)h z=MR=H6_mQkdKniWrItxzgktsbI+TX4o_zIVcX1-Z7!$if##`7Bw_i10d7T|Wbmkd$ zvLW+$3u-wb?0CX9uETmyYePJm7K8QrJyP>~h?0N|$0R$%WnVM0c5SE^Upwg1YnJKg zVo-`&2}$+ik+31EubPE+H 8*+{oZkYd2U_E2kJq zAFwCQ$yIYXq9cIi3kV6#l%vn1*>P@}N|~YqOI-YKD|(gf{@hKw<}pL~nn?lH7^w|C z;rHty!$mhX$c8G=^zf)H^X_rlC*`sECv2`A@Scq($_J>hdsHdMmMou+&C{KjQRN(M z&R> @X~>yu*zrYaJKI31s`;QradA7{9z-xAd_4~AZ>Nl! zb)6-bB=J*-j}{ w{zb^AfX%e{er<-3RL z(4k2bekWuH+xs1{RBOu0VXN@=vp<&? 
EJMK zzEN3=l9f1(AeP7X%3TdF3-5J&IN`dPB*E|43Z6P8JW-SAeprFQ=Y9Axxw(Ry9%<)- z1iqgqk&J(ih?Gh6y<-!$PZIn@*VYf(n=sW`U(Fr`S|O7Ng>k}^D9GC>l5iD0h)gCB zUn_|I8v=9zZ1+Z2bhl+6_S_8GfvI#*^TBzu_dXFmk2{D>>Q|$7EDl7MftfaYsc0`h z;=3m=6v|f`LVdiDUbiE?KJr1Jc#Pe7fI<5ng)pkf`ovDC9lgPQwC8VpVwq_7aNYyJ z&aJ$G*s;Fr@HF(i^hrRkK1{8yU&^K26*HfPt_w?^Xu~E%h}co&Lypj`s%8eSH~hi* z<5!8)(7w|uZoT2y%Qh!mNj-)gVZ&3MS!wm3U(PcH!$m{%?3VdY&vxX=f_i#+<+K2> z(*jubr@LO6kh~;geC@e}bz&H20TF`=K!K+HTE%Y;*lTN9VT%DVw(hz;_Q_V49)%v{ z @yHfQlN>;hc2i1N{^&MFwsGm@~dUo!T-f%&`pR--+{ z3&-Tk#Q&hmTF)99F&@jlnI=1Xf`kACjEkEM?x)nBalF1=0Jqv1N(_%e{b;L6t^*8a ztL==>gwcuE;MrxJx(@|RNi-o%HnZE1$Brb=hfkR-@9mb3#CLkes;5l5-2y2>w)*P3 zwNU2atQsc+O`d<#eNz!=gM1*9O-Ucgjf5EpVD;JY$=AtIsT5w?jSYa60>bg}Fhg1Y zv{vrg^!&-|L?EzDR10y;O2*6(Ogfx=c~fG9&Hl!y>&*~US}LsXY>g^28r$63j)k8t z_vd{OCWgWtSiON^v}plx{@oqT@3 hfAi^&zv2oXGUvME|eDo9_<`!H-{bm+#RAewBg9mh+@!JqI+MyRhl6 zz2F%%Q6X8;S}irEhk-R^kv9Af#?CS*uD0vaNpN>}cXx-z9U6BF?(Ptzad&rj0tAQP z?(V_eHR$lX^U3>7%~Z|b?o+3K^nFh4z1OwZIxF=!hDr4_$oFwIf4z9VsZJJqAdxcS zaUk_d8xwZaxH`OC+O#lK@v!Y*^`2n13p>{#dIO5<58UX-Sg;xi5mHITAN|Gh(rA<{ z#c {A=k zp2{Uo#>#P4uZqt5_^wAE*QY#Cg$&1)OY8`JJB+CC Vf _=?n%5`p3Jc8)uS<^{L1O&~qPLBqj6EQlRm&gD;#xJ49HE5;_%Bqo{MmM0L(4Ry zc^vcgHj1m5H5=LyYq6?21^`rF$&AMzP$r&2(tDk)M9Exn(AUDT5$UP q=5|oW zK(V<*y7nV~vdqNf{5nP{soHhq)kj@-mpXK(Y&EqoFy0_)N{`vWfQhoPUT5?h{UB#Q z383BFIIUMk!0aYxSi#oMK`r=#G%-QzHl@EAg_f}f@7QPb7yTAL6ONS^(wpaENUSQ+ zFfeC9lX$aG)2jciDWPZcGWpQ)q)W5uF7fCoTz_Y~ud|kk+(6}+>jnOsrnG^EDeG|| zz2qL7Kk7v+O@e2|Zy|Wzd7l*6!88+6AJZ#Xakx~P3!Fm-5iwZPXJK{nHvSf+N*K(} zn^9+-t>jI#D ^o*MTp{McN>4Za23WYgAX2p)4|B zg*$1M%?_cZxw&R&bF-o3P%;pO%(gH&>8U8@FCroU-5-6fcSrfiHP;XnrX#>he*-%y zdI@SiZzeY0`a;WqUz)Sq(HMZ+lyG!#ccQAoZO?DBIh5$(1;4_nEhh*|BLAAtdn_$H zYb 3go3e?YrNMgwln{7jc1TV)~S#&K}3GJX^6_-|C zGE?acL`}$WE44-0DTttIYz-yg&^x0%t4v|_VB^6+t5NLZWjFBMhCvq`D+@<%E-nt8 ziEr!;NO=9Qkpxk&((=9Qu3mB!xC4n^-Ww$w2q-`>mQwMvqQ_FJY2D#2-fq55qYnAs zMQ}vSOXo>%2XBugyuOhJC*nv`JnwN|imZ0*?$Rb09d9;@QeDbizhpjZj)9nMrQqyf 
z_-djkyW2G^4v!+pLzCpJUGu|bW2TFJZ_C~Lp4rc0=UglgEh?xrkZ|=P+vIcsl8ont ze>?GjLG>>*eU@ftXmJS%V(~+g%33&5>K%~xIZcTdpi1X(O{Z*~R9NVxTJdwtdB+{Q zn+4r9$Ga =aKTU14q%2}2eYzJCHw3S|wgiMa4}C?1#dDW#qht9WKXEK{Jq=Cx3@f5ml7y zeVvlWyKXs$c0p^G8Ky%TxJnu9?OGoySuh`HXP|e#GZOM(dl)Z65H0$%e{0Cke}}i2 z^{})VF!4#sjdGp@Sz1r^ouAOUG{=ixyYQL4`yklk6$Fyj?)XotIBZ$MJxo~z0AHUz zyrn`ee{CN%)Z_c{b>Vi7n$DT-2_@X|LQ13hV&fFYKA35?k;{=gUim#~M(K<==)EKO z`Oly P7o=vPEmIySJqKI_SCdK#6VtOAvp)Y~D;#m)#$zP=Udi-b j!1G8qpUGFsCYC;|Rt~+R%t(A4J;&8dyhJha)n3h}{1X=yLLW2mU zmz0q#nGIPz-KIXGd(TZX>&Y6THWRedqyFluOtmlPHn6!XakQ_iVot30a6PQ-JnzxS z5)Z&|gzvV$l96>++lIOxif6@U`9-)X=wgMcKSVB4U>|%ZJAqs*n*r^(jEA%p7JMlJ z#IE*RkNB&1ZbX^d+qfxpS6d3DR>gPWVWsle;w;N@nZ%G4s;;o6QSOQ5GN cE>W)=txscpcwGu(WGC@ zd%I_MUf&9`jORP@$!J1<&Q~p2H@z5pKmUN3hLSveuW#QsXL`o?In7tp>bagfeTCr7<>wQERmn4m#_>xd|Y8q7~ z<+7KsI3hkD89>=)0~zgDGPJiI9Gl+-!dl>o0lL_L-8twl?9*mjxu@!t-rr=}-*~Ym z1>5(399MS%|EVCeW4SAgeU0+vxuK_<;iCUCVUmm bF(o;eh<%Q7UYvgMz`9 z%>EA*R35|3_CSkBgyd^Jm}3&IQRn6j>lJH5L9RhsPVpRb76AlOe0zKi@60FqE^Cm+ z-BOJcMlxh&GPh!jZO{x&KGR(PGghf*h0<+zk|Hr Wk2I+4g1J%jk5M z8C*XtfOa*g|G%~5LLMmsR26iFl6Dpoc6>Hu*I+H!N7 (lebJg<>&H!PG j_) zcYMh5-3OiY{H4VqXt<6aIIB`&SKgrxuRYet10)C0Qcr~B5K7BpREepW;@6~Jab-(H zmHPgucK9wLaieIb&rQ (zgVvet9mNeW57s06Lfn;kjq(A}1JRq$| z?|m6OZ z%2!_B$Nz81;igifjX*d+;)SL7T3Gp555*^2ePJMj-2PH`SP+=O=8gbXd zN=uX$tt&pSF1lI?ab{A6Y5^XTfyM@G%_oVI?V*btAXtb4r6Ukm>6Q{QKaLJs_*>rf z72c(wt-l}q1WR^ijJXJFs$sHr?-z&C0zC|uzfDleF4v!54?Ba)$3aSfPLF>bWsw_t z^LU*V;ahd){6DC7^wb(t?UX{-3sLKeG6R{(a```$XSniiH0(=x&bqMTYfR*R2VG zLk)6ByglE~37%saETPM(&QM)1Nig}DacLx7=?Q%y=<*K9IrS6I#RR*08L(a!tbmNd z@LT4vgZ8`j=x#8dw}5_DO%cVo{>HS$O+)*1p&y0Ne7U`MQMp6xM&FF??hjas=i5$A z3Y4^#*lrg+3Fx?i!- k3g;jc`-Oj5THb%H k&laZMsJ~s)-jv zzZ8TYi S)e%0s&FOBiUu9Jn j{DSa0%XI E T}h{D_W!rhf) zTb4$#&h*v8Mv |sOZmDij*LkDwMX ww2MB9Q- z*-4S#d0MYv*$3kMl8OC&5!W7JF<#PCG+pz+2S0v(0VB9ac6EY61X1?5j*(zL68wN| zTgn@U6 O8P|=dv-UUsKKG1$Q?uz*!QD@VN=K6V_&0a z2A>Exv0gxifV;uyT@IxPQj6tiaWX}t >}0u$C2rKK1J|o=r~#LbOb%#O 
zBGy>v3lvpskn&FOaoBa!oJ^(D#vMQaRd4Y}kZB1_7Qu#$b-~&DZu!xwJ;8~^!53RC z1fBDJDm|C4!JA+o!b)opnzIYfZx<2N2D@rius1jNw|`L3aqvSFg;R}#^L#7o{D$<` z%NG=VHnV(H0V0}DOr50>jpSZ_=k$B!!AYT&+x+{G8Bs)GKQC>5qmpm+D0V{a#OD3d ztSh41bSn4Noh_iQ?!0oilTweRuRuQgiG(%F$j@Vii2H*V5~iii1d7o{FS>oATf$xz z>dx8jTVI;vTkKH=vyZw*RDMpxL+Av+w%-^WlarE&+A~of%+yb8?h>-70nhZ2F{Gi4g&@O3A1ooHtgudl=XP#2 zX(>5{Y9KU5cv8|CULSmg;Y|j}LPa$@stnoD3{!%>v($H!yOx@8dzpocTngSum%B4h z-tjsN3sI4x2rQ$Jei{8dGhMhq!JD#Xun+#y3byB(XUuWpH~rxlhE>>sE9_^p`}>uq zog-GWIC5~pbZa-Pm&Ep}xQb@m{}|S1`^URzh80wB;*p4=qOYU|w-8MEp=5yMP)rYl z$T!j5-s9*Z^iYp;5kPX=5g&E#W^Z;Jz=G9E^O;U%fkXD3=on+%UB-a%ywpnT!ktpQ zefOn)A^+nU{y?rX{P}II>)y+hDCR&r3~MK1(voJ{0$+lQ9y**BmnrgopruxBz 77F|mV*!^S0h1CsR*48eWCjundo!jZA=*adzhprXz0L}rMk)kaf> z4`a*EQnB{a1-S8YZ2k|pDZAX;|3BF#UaXv|KsHDYE P~-FMhZICpy=Z_BU4jp8C5Kgf)8Fg(oUFsT4l&eXt+yxpF5c9A z6 4o%IY2UX3Ck7 zG5BYkNoZSa^|qZhI9oH~9_gILZ$Xz`7VnH4qDN$P^XnZg{&u=SdR~O+)p@7paw!Gm zIkV_oia5vDM2}M}`enRO-iedqSeKmiM()+-S~B!NXc#WPp#(CNjPp7hbgRgCQCfhP zJ)C;ZA_5yk$Kx9qS3ZE#M# uwMbomc_5=q4a*SLmrfw5J32dsk%= z_=Az)^-3H3{I?iXj=8;1d56`QsDvVgtdOzp2mhY71Y4598NAf;%K`g9Z={clR5sd( zcHWHq#F0IvW{^S=Nq3_Ge5=M4!x)p~rXFuypk`ZNiRCW+9^nKRtye)oVov* _yJqEpbDe{N8{ zrO_}2?9;h+#6l|nwxcUjicdRc4ImsuO4K*80k(7jQKpfHQ-^Ygs~msg!6b7xm_oF; z?HOvSDX^( 6R;w%|I*ZDF-L1cY%VpcwJOY@nx2IYV{_&SrO$|LPAfGF$*O@%T@xKMPZsq6#lZa z|F*pj1! 
f4P24SF-hqR2gd46f+MCWeM}k zrbmC@MZv1vM4_q5E;W>b;7R)0&RnD{8 a--G5#ePMKZaw~OdV;ELo2FkNvLYLp-bqGbL zcp2~Ia*Y--qGD;aWN5^@Avj>t3dRxDF4$lKu^78>OeC Py P6r33WUQ)X5`0p5oB;S zp)&5`r-02H#~J}PGRh1(NxB45OPf`&Ut0Q?cn+#9+*jF3kro >Lt6e54k zw^ FSU}gxk z39Mku!k?1+kQbS&M$ywVi~ygK>ZwJ{KzC9F1m011R^$2JG&FaHr{GqW<%+0M<(k<6 z?9P2^5yP+V-v`)u>zybOYBY4;1mz~$Oes!grcFy}(zcqiI*W!H&7GWo(i=+>P%#t; zExw5eMO%s|SGeSRmd;<6DBfovA~g}Gi^bXuX3b6l`AeEM=-DrsL+8Urn_Ws!M7>qG zO7qz-r%h!1Gi#sqI+rZRUT5aV0gPFSFgR+*vbX9)6tQWCF22!hD`QQEHHIsg>^75r z*BS=jApU+d#0?@GD@5DEmP3(Z)uP0SnaLp9KF FQ#-d!Tlmme2p z@8LNHq_|W|4bIMQ!7inO(5fFNo-RH ^wu(n&WEsHJ$r-BVc-_vZ;M$h^hRxlNbU)0RjvdK{s@4t*Bv(^HP~ z1^Mqum!Y5?)jfMcT%(2P$!0Z7(-BX?AcA`VCC8~!?B7yqYL)4MtX!FeXwaSPvb#xp zSL3Fvrq-Q!###$F%sp4aaP&yYfY|CY@s9^2cB}{GF{ HfK??|Hk@R|Jy*Uv*F#KX>7yi9f-?`5 z_1;I|=kb-mo7Z%7K~7e2lMq57Z-xg6iRjA(`^zKR9v$U^*2N>9v4wXm=^ppv*K3_d zp_--UU(-m1-J~zQy?;?BrNepOj@mJ(hJ`q~C^qua&n2d}8 ^LZR3B6Q_@n13Ku%8L)K8CyFth zb4t{?Fl(jxPTDtRpE|YSvvakxqXs-|mfQtAgWG0*5w$6UzMa_0luKoA-*xxppg-PV zz6q>_pTRfX$`gc6->|)BG?9DRmR?7Xj|e;aR?IP>IUR9kBAT{bx%B50uspaP56NBH z#qzPm&~oU=Y}%6>(0oqe$$l~#WUoeP3;Ww{iDzOW6k|?=&a}PM@37jJ!t==A@~tTw zk?69&P%+8Gh^o&lM8Sj+XXq|P7f`BGl=(!8u$xZ8<1RG8aNyNNF$RZ?(FPs&o?Mn^ z260ZLpt}&4rgzWetTBO-WTbnr)aM>I@XO+O>B%8RdMqj5ejWaL;QyYS%9>IO<2jeU zz~11jCGV+NU@x0(UPYQZ6prrTC@^uMF5 !M)KXyFbM-AKB7H{U-Q1~swWguZy zUQaqyvAu(%Tc%YY6IUy>^u$(TCb0Y0T#bCBg+=56yUm8LOm~~JCZdOLT8K8ZDS`kV zt~0Gd*2waqWWl@S@bpc%Sg^->P_sXFL9zt3krpBbU9u?3Mmvf9rx|@KR|lAJzCpz| z+Lg3@uXV7L<^(d|KVD;SB-gCy^))f89M#0CQfxA?0c)L7NfmlH6tfC-$Oh|ne`*ZC zfw1V}IH4J3T8rdwcC;_X<;FqV`&f~vh!50^o8yYGw}IxtG-q_llOFEM7@^4aOyK1* zy^}qtl=ia1PHkDFu!;hN+n`4Dv$JZW44$`$Es#FN%dTPk4EJ`!x8%2VWcmgXwNO;c z%law$|AaiHg$TEaQO@>5mr)Vx6osTbJSBe5ccIFdw}CRAc3^yRL1YTu;G-J!p#_(6 z8bVR@k2)H&YI2s$cn97JyK6T-`+i=*uUyM;e)SZ83=MRnx}~9SZ^VMj1!#{Ydixzb zyuks+Y7_!pR^+7XYhVGWbgI%iLofdWKi&E~aph3_2qIr)BkwfWD$9GfiFBbCQ*aKf zsCRY)0wGAkVM1cvkrT#8^MRk1JhiZnB%4>T!2F_YF#4^GAJD w}FF?%{259wd1rXY&0%_qO}5qIK>UZ4{EG1MX4a$zxP 
z@Ix#W=V)~eM<7EZz;)ogW!ZwojAOe8?Q_GKWl#{~r{}k8`%l=DC>r&vx%&1{tBwPk z1SDAyw8s{bux1kB=pr>-B2ylZhW0wT^rjtjie~=xmBIC6!X)!y{(t;34nO3-M^VHw z|9?eRB>Uqcqt_z8?$H2DFm~XHS@_YFWgJQ2KA}#J8+P<#A!#uW3-rzFL?kL6*noGY zxH8nzm7(GY@vFW{ZqFyy9P-H9P0{}ZIt3LgmNmjWw>-+DM33gop(DMQ)X-xOk=C!$ zp>{))(W7*znEJTXv&{aXTFpslo-h^k9|lTu2K*U^dIY=p?S+-ZH+n*@Ck)l*A@93R zqRbvl4Rvbq1}ERYit=%ryv>bc;9wM7*!+GnQ>V64K{9Z)gM4BARkmbHZ}!2APLQ=E z`ho^Q_`Z&Uaeii)Iu|R-_j;-C1ty_xOlUJ0%!x>Cqz%7lDqG;7tiz7igjqAQo^ZuN z{+v5tU$~a%*5{|`afWkeb8c8u`=)^L6CtQ_#oU*1f#{_IyPglV;)+g_&z5 3cg(TwLu`|in3})-r@2nl3d2xZVi)w4Tc~24(BxEYbl&+`{TjoUDN)vc zuF)tkUt?KjAPP$C`g-N|C{p*~IbKT;d_Lyh$MA=zQAnAD_lf0|G$Q|PK6%+~7@~AC z3pCdJ;BoPS-PPlkcIJd>uF(2tI%XklW#U_H4!?xb-P&abrJ!bv-MRT6uY8Fk0UvCy zraH;|w9=MFJNnhfXRw@0J8#3*_*ZbK(q{4ElOV^&z3*q84&eR@L_bp0WB@OV9u$Q3 zFL2(SeJD&bK>-0GA+hv=pB)TkojZF#qnK#alO@Zu78U9HdFkUE#%Z*P|9jZ=o3~+q zpg_w9t2>G^Mh9BoT1@Z2DxHJR2XB?;10H9)?qJ9P5=%oQ+xwjwR0cW)Zp?4R$d?{o zmiO-C{o57f-Ws5M5|bS$W1*TD>yY*HF*`$}VhE7o7%Oa1lUHqi4t9Q;+UH*O7k$GW z>PYY~!`B-7u4*<9BZGce4CLPcY%q+MOo-cF>H+VRv}sX zCrFhJsTXn5y#?FP4xFD!OQ6OTDY>uKv|E{esXkYYm%tzE(2w7td7G)HMA*P8YUmR4 z@-k#sV (%P8op+dC=7?HS1`!sk25dkzX {5>PEKy*+ j@~5YZbCG3P+dts-gF3YJCX4gvw+BELOZ?PEsOj~|oBC|f<5 zz0QcQiM%6t-LrVAYoMc0&0$D1S?#k&!+JeG1MVZ<@UAtmcGp~E^O!9|j~mu{<+5_q z8&&CC?Dg5lpz+-=h!rH%#Sd#RHa9$q!bCJ3h`)iU2=ze6zRmI;jww$r!m{qG*kpFY zBFz$dtL$9on>ZwDh?XimT%NhrhQWwBXk)|X{&CpzWu`@bo5wheA?j=8dk<5lBxM {QzeAiP5{l03itnAzRQ;C|LV@kDT+ENhNa7D zrPxssb?{LNLQvloEr+PuASwTS`=2oQzs622eEi=U?3@(;jl(>!BEBLI_VOT;TaKY! 
zg~E!>?tiO`TJu $d_fys zwlt!cRHQSZk|a~T)nf(}xbz8?nsTKa_UAx*lI;*9Fp{Y(JxCiHiW2y{ a?q_jB=-Sa~H6Y8LFh+DtfFMTbflv$0rTY36_b?L-~uc1J4%lp1L~ zRKIDh4@!|EkSrYZkr$%h5^n~m^zUy{q{P6pQGpE+9UFxVs0_QRFGR&M_dm@P--{0G zE5?rFMdyj1KF(5fBcamVNbG3g+LP%jwt58zNd3LRj&9;24i|$C<-9%Cl|mJ8{~L3X zp^;Cb!zKVo*pX?dF!Ko}8N5WTCNfs^@;aEe=qJ_genj23WTy0-TSSpB-w)fe=pIvE z?W#jsU}uu|Wz)N2x!^l`+0-lFIvQcO^x*r+wpdsYQ)rRd{2Yd{=WYWdf3gxbntr5Q zAK*P~^1(H~T!4%-OaU(a_2+qeI$`9-iwdD#QMf{kDo`r3Xut}1*!5Pnm6v ip!yI* k0xoAfa4eQWLw3V zEZJL*Tz1c9#WR#wIXYJCzvAgdPp4QW88ut AWEW+ujr)!H22t z&Ov24!pVME$5%`==%_uvAfwrZBQJC!Ded@iz@2n=GsrYQ0?^C-OpDmPw20L~r!nxJ z4p=!K2I!e3WB@x^Dvw&4ISZ8~T4vQrqM{Kh=ftq;<#PId;*AmA{Q4?Fa6CC#1SwwB zD^8+YhMK>l*^0%$y(4QM%{ZFs;${X$OG=KI5duX5RKIBY5%Y9FALx4p2@c|7urm}6 zY4R_6;VIwx*&Mkd`4%a|0izU8Yg>L9c4@IVF3nyS4VFIoxfr+@e4raPP04PL<{cc8 zeyUR4YL|5F@_pt<+#hK?D9%DgxG80+64({0V*xzm`GgKGmF^qJjB)AOG@j&g!K7cF zW@XA2rvF^t8V|{N3V@peeaE&Zw}(zYP^Y2=dQlTdBb&pltevmik4RIAs`9VM37loj z8aL0=>q_@zW|yVK)nG5CD?SDf7b=l|u9DOai&5vH*Hwsinw^zEF(pz{ynd5+r58s1 zEIJ(d1)}CnB-j_hju*6q3Cq9cM->e<*UYexOt7NQ5(Lddf|CfAkc-b=E1OC~;$YOy zWSX6KNN%4jPHSpp=xB8M!BuP&kjbwzl!Q7W9sA2iOOe)!&DoW3WghHl1_$N$|KPYw z4i6@Ws47#a#c6!$k(H5wDTPxZDE}cQVJd-NI$t<)qq$7Q^luR6>=5;~7Y#i}{v;l7 zC`1|9ze1xr*1 l9CLdVa4^K97sE2RX&@tmDo#iwZfZ)bb&Wgk zy{OPuIzj)r|Dh-UZmQf4;hNQw*{Bx4?e)}M0?;y0NG{osYac0t7gE@z#RVR3@*R0< zi?D)T-!DWLooLe8b%mROrx4*>mS!9R`7zfAPi8dk{&}me2NpN#rh6uBuIJbjHipzm zx$vy&9T uU^ H!1ji%O6SHqQptv&E&5nAs4a5fY*8crH+m@ z?DDF_SG0 -OnrJT$1rY7D3Yi}|s`HHyPlQrY**K +tdc=$iTFRLHkMFq{+)M7!t)-RXp7Y2qt!db7jMfF{q`Ta9SW# z-1!@>szBX%)ciNGcQ!>W?y=w*pgUWb;e1g#kn672hBg&Vg=Zvi{FzPx>viLdHsMt< z>9i0Yy7N0FE^)(M!^KOpyje#PEmWrpe)PT`8cGyx4CY$28{sY8kGbedU+jM2MbBA5 zn~a)^*e8ed@vb!C1~6|p)7SWv@=3&Mjqa%$-B8%-$^x;n;NteBQX3&DqSIvAYds+< z=Izn`mhXxwq<(hl{?gZb=2iuhi4~It^Yroz``fX5)Z?jZpZeY`o2-17x?)=GmN5l{ z< omnXzHZ^kx=rXt0{3Na zms9=xnhf!i(aj`UB@ymt pZ8Kpkcm{u)jO8VCk*p3>wLI-VcXdHXx~)|u_ }6$qF2*ABY5Y)c)x}EuYmp44tVnt3w1`-U+5K>#3P (nS$ND%{gyUF*+JOa3@01Xktdb@mnO7xQu@(-yl=n 
zVnOEhZaw%aBF3T0E}~!QT+|)4932>b<^Q>PR$82bnwE}%`A*S4G-U4Kc4f a~uZF6Z8eNHZ?CaGEN-^;`JR9W^Zt8}8#XLl9)uoEd3?;s$6C*fz||^<$1p zOV6)sqfqBgdHSF||L)mawWQ9E2-gq7QDZ;O`5(An;~%Q6RgiBE>b!-f!LX$(Hb>9; zZS0?V6X;Qx8IIDi&k0C(@uVmR7&n}!X~_EQxN4%m^c{j>UgC(1oSb3R 7Y|eD9V=GmJPX+Eeh@NGDcK;6efr-`3nfY=T~9;YAzfF zMsY$g2OWxw=CWB7H rSbQVEvB8- z2jw!Osvlc5j@UPmLz!gPl|J>}*O2R7?x^jxUBk*0&)vq0Z~mGB1+4JBwMO^D5@oaB z@j2JSW!5Cs?#P}(kmc6KL&sJcxHBJkVk^1q`|ckoUe{-&bb0AMM1H38yhk?Lqmoc3 zsIU*+UdB=wn0`AeOzdfqGU{k+0E459?n&yt=xwa%RQjU9MPSi+bM>Peh!hf%4!e@R z=fd+j7V_B4(6Z_x5oZO?r<>lOJymS6i_1`NQcr(6vJ?o++xMu03yQA7l}D`_CN+TV z)gRV-w#SyN<2ZZa*)3uRGY&M2@PM9&$+nQ=Eb+$OXC^YIc3u}UTh6`kQvEKd=Y2Ls z8JBG5&Cs(5_QC|Zan^j 6&Iotha`d^j4RswhynTi^2B>dZh;i`8&q zV|K~J2m%oVL-Po5GW@H*;Oh0D3Q82T1lm7~IfI*;vosGPN!{1YGH98Jlbut&q53LU z6e;O|QYiQBf@QH(j}Pe2Z=WE>PyC|whEnrgQOe^q`3HLRum&!jM863n0_%SM9DS;% z|4Qb0U??+K7@c_KJo~sT*o_lB@e@hd&J+Qz6})0;M(d7rMsx%?z|JsLH1Ek48q)#? z<+vjb;c}vVPS6QkREO7&+kl7D>%J=mue19G>Mmh~DNUX1o55 4e}9z3jVKv*Ut$>?;14J;6inhPnK;r z6H!q7bC^;Wt4J+&`OyY*64M%EwN>3;>l{+pYWg*sv<24jI-FQW*ul_Z5upk9Uj@<) zSb5uRpwbU`afhqSeT?*a7+4rn&->`AKuc3EUv5OAh-Zj|?dM0#7E1A*BL#bQf>~ZD zj1O>fau^y@G2&MtkAti+*Hb>@r|D be|j^w$|NG;PEh7hLmlOhOV2BEh|jEDGU-{PU@|Nr6G6v$N>`;U9#vNl_O3 z(KV9hR*R?3bp8XrGa~p>Uwp{Pdc`{-#Ud=^!KsXtsE6;Cju&)`i-m`jfZhAkCY>G^ zddKnl4@~fyAE<`oXDgYR5`hXQgXqu>)Amg#+kvhph2J#H8saaeKxs=V{v?}3K^?qxZ_@vw9g`$-acl~FUWW>Kp+5r|G_!iZjk(c->`{Aql| z8n@>R*4!7@x6jS>J!&5eGFiE`L&-_?X?yZK)jnx1nWYxFG@-XjTfPiSmV)ju3guir z|7Il|!8=&S(E|#;CNoH1;~&s$BFphE>}l8tk`}!EtPk`!lI7fKhcn6(E9|X;;p1k9 ze8S<4vT4NZV2nbJWCiBJX>6kNGnzx(Z1O;_ZB!DC3Vs1P?7i`aN8!SXK?FoiWl zVMXT+Sk;HmsYUmm9$270Pf{|`@AceCbFCIN>UsBcis?#KxP^g3u)EP@(CcsAR~NxC z1L1tt*wV64nWl=*dB0e%7YEi4voF#URk}(fcTem-&J`#rwMU3!@bc!rD^F$X9F}hB zFJl_;<1|FK1Y@Z4T4K?Lz}P5NE1>!BW5LteMN1bccjO4Ju^j%0agX{fAT&9}QIGx~ z)K`x^Wn^xF&8|L>dfnc0Tihz>M
2xmNBVwGJvud29IJ2{!e*{#57ssG;XC(V+)d+yhwEkkzi6)U77T# z1}{g}X2#rNEvgf)h+@W9%w3Zhx}B0qydaT{wo92%X3NcsPzM?zb4f(R3xsxCYUn;V zjx)oRy8S{1%l%zBYeHtwWbCc2n4;Z8%pH!{0*jJKJ~+{NfZnd(H*7a9Kk$~Aol7R* zw4z5d{wNA28V;7vFkFAqR*8F?6w}Wg!4=9K5aQjyL(^XGLve|V{9V1SDdG1UmpQzq z6NFKlBiS>g-`g#{@%^aD@Y#$5GTjL)osOpT`gP{*7y4Gzl>D{tchLsdhbx*#$wYxG zg#`Vr=`<+#=z-|ly?IGW;GK@CFM6Lgo6L9{kanT9m+DdH%ZTTz#r~^}4q{C?OGaD8 zK7))r^jP1^*zHgT)qp!UwQ!E?`~o1wht-JlC3q$OQC;}qX{&OKT0b9>4tid TQmlpfGFOLaPR8d&uBV@vJOj)+%C5XAJ<0*k|+bo z?#Z}G3UZ6))C77#m4wea77@~MHE1S>4@F#JC)sN&0WH4lUHhuW#3+}#oJ9_9j70g# zXlZt#W$!qk(mUnPT2l=m3VG+wgsa6{w=bwDC+%Kd4_gIWLBTFU`#j`N`(N2Qi$UZ@ zn@2d 6`v)MJSogjK}r(oiU|t%%iiI`}HVN%?{ZCLD|1&?~#mz{4fqU$C$~hLX=OH z9crjWvx)?>*x+;0WWqIS=YytUJs(dscD#z~+;Rd*F!q!n&wAq|7+8!iWUBNY`7jlZ zf=f&UW=ySW YlwpU6&-e%6hDyArFPrSh8VUEKB7Sm9o$#6C^0O?2@*tLrGr?ad zq2HleusTB9a3p`aSGW*_DGaEEaa~4{SO!s@4A*@2R#?}z@Seosc5=1i7}|G*g4Nc_ z$*bpOQ(-XdMjq}~hL%zirfbvlZ+^VB%xL5rdB;t(N|AoPWv 8m!@E^DZ 4+uMPsjg QX z$uhb@U}Rw_poloxmL=-o>~XQ`Z+9x7*_?~?^k$)aq7340=|Qj2az@Es|B)z4qm2 ztJuw|U=D6Np2Zy4_Uo7&e%acY=w4;GC@d^@bHpGAV~w90f^>AHayL=)Om%9!i~Ebq ziD7l=-|Q+ExQvbiR**$6b;zV*wZh=ZQvYc5NX@0!O@ueVeCb{6GuhX2k&LCW^rFGP z6iqPpX#6C7zO~9r0Cl$cT#ik>b$Yk|OIwQ{r)y>229xD=2)o-~5-)5FfCkrOft?!! z``8%6+RS{i 6QXN+_W`yb#S^nfZ&KJYEObUQTyxVFihK zE^{iuK~jmjUp|mW!)`uEeX?4FgP`dqGCtH$kT)d6RT=?krRk%yhKQgtw>($uD~m~Q z^CSLmVyY{bWyYE>+WvdZhNLTw@Mj{&hS^;H{5X%S3~X0jM%U$mFlfFTZBz|Wf$g|i zpjo jf^IfVB|Ubhk0|tcA@;%5zX%FvQw2xTd{C{xOty1C!}Mg-awU2n z>z~Yr@55- hg2<;yAbPkC$E~N6GlbXB7nzdNdP1C~*e*rWn2qRFt>x(W2a* z30` Ms8sUE`CVo@?%yB3&r%s5bSMKa0tBT%qb4x_t`TIr0Nr!tC zW1stqQJTb|9V;BzL{}(w@NWI-(Baq#Q5)wm9HCk0b)Hej`hEvcK_Nmyr-zo7WA^C? 
ziu lLCt^f5sU>SL8xyIu|y>F_iPpm_Vy-LG6vB;*smN6-TYHuTBz3d zT-`*9!$5hXV<}nz)3Er- *ga4GG_EjF70uW^B5PG} z(ssvRF%K^L%VNCXZ*G1KTVMjgCfgXM=>0^;`|=UYS~vO71fhj8+Mp z-ZG!$b!xs?{^C!bR6qHC4yQFb@Jkg18lL8hY-E%Guj=+YG~;)KF)lu@y?@?(x;dQ0 z|BJSFj?V1q_I*3HZQHh;bdpXwHaoU$+qP}nHaoU$=cfDj?!DiA&mQNTamTp-WUVCi zRIOT7PtB^DpZd=K1WV55y4lSBA}s&kCC~mh;$r`3Q~seL^urCqu>M`f3<9FR%Wh&} z0fkPh9lq! XIByG;3y7Avr1GB`>8JraWHQY8*7E;qT%eudY9 z_a7bZzXjCL0Y&-z*Kg)Fe`M+YCBRlY@KEoY|6fvR^|A2or^K|cgQ0VHD)RJ}8SUZa%2~niO<>Nc hYhJ zR?}DcdY*dt=Xqd(a6gU@BwND#7} cBe1AcA`FE{tJ+y(D3^LWEd=?b-&t` z?yu0EM=o+E)lA(v`dMOBkHqFT=eAe; BA+0TuSizJ6|`HZeq3px z;$8tG#@;|_caH|? bat_4{vxj8|j{sGML`0dWO6_F5fGSP-Mu3~DFu+r7%_AHObuGXKg-b;=dVh$g3D zWrncvAoX!uA2&l)K6i}LTWj0;vsI>ZLoInPmxJ9&!3MR7=~Ut>{}|v2dc9|QI-Hi% zu!fprDhp;p`j)v8q6@a!44N(H;eT4w%w}br}%Rmr-+Pi5;u2jZaPw-$SPdkA)D< zzkN9`dv!Zpe_paf2%pZ$N2rzlN=w<#<1}>7^tm-X1NA3UU{J+$zKiA{<36f?4`1P> zjDIu3lFn@dlBvt_J1RXk*LsGK9{z(cR7}I&bZz*zje!%atvxP6RCmyZa}LLYuz!no z_lOiOxZ4=%90|evId>%ZY3CVEibGT^T&)+RFrULf%kGf%MWl6Pe2DnCY0n-xcZVJB zknE TO3IDfCTlyZpNYlS`!NB&9|U~DY({6PGkmMlWWae2v{+5weAB?kMQ@b zj$7}rou{3rOGtchc?AV<5xH)yYASL^{een3{RmB@fn;OzVg}Q<)f$nA!HozIHz*}x z5P+I&Uwq$LuC?{osSO;((Hzpo=L5UIjEszw442!&vvDlgpj*RkGI(}u*u%K~Z|9)) zel7mdsLc@XsZM^YBRU3fhQ!6oX4I~wUCyLI&U&xDGe;UjfQc}ky{h-QATG-; Bia_s;H{<6u#SIw?8A>e4WicbNI%6;2tsB0H$|eu@qqG_CxqfR(b>08r$nz zlzhK9j7ku%;`y{oN)soHEjLFEdYmhp7|Jc>QS~yC4;WcrO)Sa@yOVHQ;6w52gcA28 zN(f*~lBZA )FpW<|IbHLD0Q|TV{q{p#`Ty5SHy*yw9 ziLAoyZq4=-D{nyGi=B7-J$khExJf^z<#)OU$JMR$4i+-Jw0#} sVop6{YH0GN5%p6&Wq)v~^8 dbfDNnxGsLYy*$#-D{>dEGIg@oxu_jtvk_7Wqh%G{ z2bG>~mFC%l>7<|CL5e1w4ea!UXn@LA&%uTkkpORh1ZS)vpj@yBtej-&9t;;}G7lK3 z%-CMGBYn_lfqc|*to8_VmTq;A(q*Z~zx^HPkYwvVD%};}L3TIerlYz4iuN6I2Wb|6 zc2Uf6) zl7~YQU(X;O+D_rK&iB_sy+#Ic3B$WHwMQD6q#yW;kKnA9%iaxCo&&OdT-lDhm)68N zkghbGzUxS(HS{dbQMsepVjKp@ijiln!U13Iz269qE9TARn}C%UOfzLZEaT_c4CeO) z)R*$91IGdPW^@Nyxpb5Pi*`*t6=^u(1$yW#m>~cW_1WTgeop)^Zc 8Su9c z_}O;>5WI|9p&z-T(W=Ag1lqY6sQ+?KyZW*pLANA5|1uG8o4x(cB!bpqj$GLc{+f%MEe7k8Kl 
zbah1VHQixz#psV4fno?#HrI!kqX4zkH?Xi-R|4D!v51Jv^kNYzy}c+o@eMGehe!f8 zy8(_o;h3ASfuTmr4YqW>3supM9&|Ek4+37RxzOtvd5?`^!Oortd`mwHv16QIcl~9# zgJYwkqZ4cuSK?KRVb05iLv_qGLf@stWn`3KNB5#0bcXmw=*w5>rywxf18z0lx9 zZm76hE~C1QX1bLplj|Rt-!flxv&`(`6$?dfAIZz1nQ6sPeHrQO?d2E`xk72C&2-$7 zgZIH7TSJF9z_xP{$qkeBImI+E4W;Yx5oeR0c<0Chq%%_KW!ijNEYx+^Na?^I#7Ghb zw~BfSC7&hRJi&2uAc4)#q8cmI^BP{3%%q4*IoY!t 9xBk!b)Df7$}YqmKGMcn{*pu3TG8W zUvWih8g@Y} PTL&7ISSy7&zd4IkTzezx>mz zdPWw#rf_F;BdmD23XYnEA)b!GTcU*B`7%8yBu}&Eq7lJ#B|#5J{r+9hrS4d|%?%5H z5dj)H0&@6vhcHlL7eS@j97Kuw6{l+=Baq=;t!IOHYOG_<&EV!@zgRhsld>jXPLNtJ z44y*nzy=T#tV;DI-0cnY5n&+*$pz(tC@m*&SH=7$QqIdk8UJBh-#-pvrPXC?t;Ozl z7nHJf-zjQ4?PsLSi{#X0$Yaa}H#x5URIl_^s~x0UW&7xoY)$X(9~*}Tmc_sIW%6XB z=i9`#uLlJ_^w8PT{nyD7vm-*1BQCRcrZ1#U-0%pFWvz8$i7+EO1J!;i={=Wk8=J`O z@W5n;W#==IMw*5r9y?tRjs6rDZ}sm`F3Zi5a*4nkydwjnAE_Y_)h-L5pi0tTyh^P8 zk#+pZO8n*L$N&$ywn>c&+XTAY=h+bA2)(<@9>aylNX5Uz9^%R~c_B_bYS}=_R42_2 zNI^$XUee3|sSo|TlJO!rHg*tnOU@ZL{na4WE8;Q!LL=CMVbFs#^_X|?;LF9zdj*3x z02TE$7_)0*yZ y{7+*FLUAvP ~-JY_Dcf!SNp+NNTLw%uwy~d?y>2_2#}KT=LeYHo0hGk$3I)7 zg}x_noz4gj!)gId)dsP9uAxy7YNE%FZUjk7jEMh=qCyUJpH22^ejcLoO-xL@F|$xJ zmHd!RyznW}Hx=Ugp#Jsl3={!Y@8|QPq&tj@J7lwo9BdC0onbF1Z_D_q+%ucW)cAIz zhVF6TNHdd9l8PRTf4d$Kk%Tl~2?#m7Uz$Lg|KXJrou1HfB1-@VmFIjJBZ1r0L^;$x zjNtwoDM~jAKqx)CRyO%vLCZ#76OFjfMT12q3y=UzS$7G@;Q&KPY?DN++A|tMm_Zid zPBJo9T0PeXm*65FOXCsv@B !xt%Xx~zw zM>Wf=)n}vt(gOjtl)?rnH*8Sfw!6d?1n}3*GjNQ`O1J!85DHLN58)IzEASZU!;5V> zPZD^=s1Utfe8)-xL6lIs4ti89JYJH%^+;u$zZ?V>qM<=?dOZE(7cW+i7`6PA&XA8U zVl4(bJEuqsc+<>Y>>DLExfO7j0D4IdGBWCnt_Or(;}au}>c1YgZo|=yr~E+My9d*x zLJ2+-<-UFBh2kW*7hRN3n9)KqG+2keN&M+NKyIN6>QMtN@R{D{w GPA|Z~TrJ5K+bbj`p8&Co1e>bo2KnGPP$oKZ733D`zr+e;wk?g_zt!maVxU&uC zwrNAxg;4CM;QZ;fn*|uHQBVLkGO5ZR9witON+^ xlwE{e&Rv;m@5sF*Xu~M$Mm}{UUZM=hq1^`c@#bAG`Jn$uYh-Kx_H6fFy^gp9s zQrHQz*x&Y-&(|!~PNAmBC(j#!x z)qaJ>V{|pwC)@#l?VJgw);J4@p00T6l}~Z1sC0rT7${>EApLD~9mL}mxu}$~ P(S z#84YZ-e(m{CZ<#Bx9#x 8?zJp3TJx#9-*AsU^etG z061%1O?wRS{pA%n8dJrRw3s#}+KQ7pgZm58 n!5 
zP;%~cYrG!0UcbMTv$tl8Zqc4r3>y;Tn&lX7Ci-jCm0-BuANss)ZoIFYwxk1Dn%l)9 zO^ENVO4d-7mugH>KK0~2Q@oD#LkQu-zbngrjN)cR$rfwDcRXRgT$bf!g5XjABmQa! zvjWdpOn%?(NKWKXm-*B?>(_Jl-<2k6%oUtj``;;~L`*pTCVYH}A)L@~%+(pK WB9n+7`UMzC@B9X=;R0UP9N`5K{wTQP0wurf{@p{qDB<)chj1c` zP7;QI%xq