diff --git a/tutorials/fastnlp_tutorial_0.ipynb b/tutorials/fastnlp_tutorial_0.ipynb index 8312353b..2e315d73 100644 --- a/tutorials/fastnlp_tutorial_0.ipynb +++ b/tutorials/fastnlp_tutorial_0.ipynb @@ -50,24 +50,24 @@ "\n", "```python\n", "trainer = Trainer(\n", - " model=model, # 模型基于 torch.nn.Module\n", - " train_dataloader=train_dataloader, # 加载模块基于 torch.utils.data.DataLoader \n", - " optimizers=optimizer, # 优化模块基于 torch.optim.*\n", - "\t...\n", - "\tdriver=\"torch\", # 使用 pytorch 模块进行训练 \n", - "\tdevice='cuda', # 使用 GPU:0 显卡执行训练\n", - "\t...\n", - ")\n", + " model=model, # 模型基于 torch.nn.Module\n", + " train_dataloader=train_dataloader, # 加载模块基于 torch.utils.data.DataLoader \n", + " optimizers=optimizer, # 优化模块基于 torch.optim.*\n", + " ...\n", + " driver=\"torch\", # 使用 pytorch 模块进行训练 \n", + " device='cuda', # 使用 GPU:0 显卡执行训练\n", + " ...\n", + " )\n", "...\n", "evaluator = Evaluator(\n", - " model=model, # 模型基于 torch.nn.Module\n", - " dataloaders=evaluate_dataloader, # 加载模块基于 torch.utils.data.DataLoader\n", - " metrics={'acc': Accuracy()}, # 测评方法使用 fastNLP.core.metrics.Accuracy \n", - " ...\n", - " driver=trainer.driver, # 保持同 trainer 的 driver 一致\n", - "\tdevice=None,\n", - " ...\n", - ")\n", + " model=model, # 模型基于 torch.nn.Module\n", + " dataloaders=evaluate_dataloader, # 加载模块基于 torch.utils.data.DataLoader\n", + " metrics={'acc': Accuracy()}, # 测评方法使用 fastNLP.core.metrics.Accuracy \n", + " ...\n", + " driver=trainer.driver, # 保持同 trainer 的 driver 一致\n", + " device=None,\n", + " ...\n", + " )\n", "```" ] }, @@ -84,7 +84,7 @@ "\n", "在`fastNLP 0.8`中,**`Trainer`和`Evaluator`都依赖于具体的`driver`来完成整体的工作流程**\n", "\n", - " 具体`driver`与`Trainer`以及`Evaluator`之间的关系请参考`fastNLP 0.8`的框架设计\n", + " 具体`driver`与`Trainer`以及`Evaluator`之间的关系之后`tutorial 4`中的详细介绍\n", "\n", "注:这里给出一条建议:**在同一脚本中**,**所有的`Trainer`和`Evaluator`使用的`driver`应当保持一致**\n", "\n", @@ -106,17 +106,17 @@ "\n", "```python\n", "trainer = Trainer(\n", - " model=model,\n", - " train_dataloader=train_dataloader,\n", - " 
optimizers=optimizer,\n", - "\t...\n", - "\tdriver=\"torch\",\n", - "\tdevice='cuda',\n", - "\t...\n", - " evaluate_dataloaders=evaluate_dataloader, # 传入参数 evaluator_dataloaders\n", - " metrics={'acc': Accuracy()}, # 传入参数 metrics\n", - "\t...\n", - ")\n", + " model=model,\n", + " train_dataloader=train_dataloader,\n", + " optimizers=optimizer,\n", + " ...\n", + " driver=\"torch\",\n", + " device='cuda',\n", + " ...\n", + " evaluate_dataloaders=evaluate_dataloader, # 传入参数 evaluator_dataloaders\n", + " metrics={'acc': Accuracy()}, # 传入参数 metrics\n", + " ...\n", + " )\n", "```" ] }, @@ -570,7 +570,7 @@ "outputs": [], "source": [ "from fastNLP import Evaluator\n", - "from fastNLP.core.metrics import Accuracy\n", + "from fastNLP import Accuracy\n", "\n", "evaluator = Evaluator(\n", " model=model,\n", @@ -1310,219 +1310,6 @@ "trainer.evaluator.run()" ] }, - { - "cell_type": "code", - "execution_count": 13, - "id": "db784d5b", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['__annotations__',\n", - " '__class__',\n", - " '__delattr__',\n", - " '__dict__',\n", - " '__dir__',\n", - " '__doc__',\n", - " '__eq__',\n", - " '__format__',\n", - " '__ge__',\n", - " '__getattribute__',\n", - " '__gt__',\n", - " '__hash__',\n", - " '__init__',\n", - " '__init_subclass__',\n", - " '__le__',\n", - " '__lt__',\n", - " '__module__',\n", - " '__ne__',\n", - " '__new__',\n", - " '__reduce__',\n", - " '__reduce_ex__',\n", - " '__repr__',\n", - " '__setattr__',\n", - " '__sizeof__',\n", - " '__str__',\n", - " '__subclasshook__',\n", - " '__weakref__',\n", - " '_check_callback_called_legality',\n", - " '_check_train_batch_loop_legality',\n", - " '_custom_callbacks',\n", - " '_driver',\n", - " '_evaluate_dataloaders',\n", - " '_fetch_matched_fn_callbacks',\n", - " '_set_num_eval_batch_per_dl',\n", - " '_train_batch_loop',\n", - " '_train_dataloader',\n", - " '_train_step',\n", - " '_train_step_signature_fn',\n", - " 'accumulation_steps',\n", - " 'add_callback_fn',\n", 
- " 'backward',\n", - " 'batch_idx_in_epoch',\n", - " 'batch_step_fn',\n", - " 'callback_manager',\n", - " 'check_batch_step_fn',\n", - " 'cur_epoch_idx',\n", - " 'data_device',\n", - " 'dataloader',\n", - " 'device',\n", - " 'driver',\n", - " 'driver_name',\n", - " 'epoch_evaluate',\n", - " 'evaluate_batch_step_fn',\n", - " 'evaluate_dataloaders',\n", - " 'evaluate_every',\n", - " 'evaluate_fn',\n", - " 'evaluator',\n", - " 'extract_loss_from_outputs',\n", - " 'fp16',\n", - " 'get_no_sync_context',\n", - " 'global_forward_batches',\n", - " 'has_checked_train_batch_loop',\n", - " 'input_mapping',\n", - " 'kwargs',\n", - " 'larger_better',\n", - " 'load_checkpoint',\n", - " 'load_model',\n", - " 'marker',\n", - " 'metrics',\n", - " 'model',\n", - " 'model_device',\n", - " 'monitor',\n", - " 'move_data_to_device',\n", - " 'n_epochs',\n", - " 'num_batches_per_epoch',\n", - " 'on',\n", - " 'on_after_backward',\n", - " 'on_after_optimizers_step',\n", - " 'on_after_trainer_initialized',\n", - " 'on_after_zero_grad',\n", - " 'on_before_backward',\n", - " 'on_before_optimizers_step',\n", - " 'on_before_zero_grad',\n", - " 'on_evaluate_begin',\n", - " 'on_evaluate_end',\n", - " 'on_exception',\n", - " 'on_fetch_data_begin',\n", - " 'on_fetch_data_end',\n", - " 'on_load_checkpoint',\n", - " 'on_load_model',\n", - " 'on_sanity_check_begin',\n", - " 'on_sanity_check_end',\n", - " 'on_save_checkpoint',\n", - " 'on_save_model',\n", - " 'on_train_batch_begin',\n", - " 'on_train_batch_end',\n", - " 'on_train_begin',\n", - " 'on_train_end',\n", - " 'on_train_epoch_begin',\n", - " 'on_train_epoch_end',\n", - " 'optimizers',\n", - " 'output_mapping',\n", - " 'progress_bar',\n", - " 'run',\n", - " 'run_evaluate',\n", - " 'save_checkpoint',\n", - " 'save_model',\n", - " 'start_batch_idx_in_epoch',\n", - " 'state',\n", - " 'step',\n", - " 'step_evaluate',\n", - " 'total_batches',\n", - " 'train_batch_loop',\n", - " 'train_dataloader',\n", - " 'train_fn',\n", - " 'train_step',\n", - " 
'trainer_state',\n", - " 'zero_grad']" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dir(trainer)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "953533c4", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Help on method run in module fastNLP.core.controllers.trainer:\n", - "\n", - "run(num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, catch_KeyboardInterrupt=None) method of fastNLP.core.controllers.trainer.Trainer instance\n", - " 该函数是在 ``Trainer`` 初始化后用于真正开始训练的函数;\n", - " \n", - " 注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 resume_from 为 None,并且使用 ``CheckpointCallback``\n", - " 去保存断点重训的文件;\n", - " \n", - " :param num_train_batch_per_epoch: 每个 epoch 训练多少个 batch 后停止,*-1* 表示使用 train_dataloader 本身的长度;\n", - " :param num_eval_batch_per_dl: 每个 evaluate_dataloader 验证多少个 batch 停止,*-1* 表示使用 evaluate_dataloader 本身的长度;\n", - " :param num_eval_sanity_batch: 在训练之前运行多少个 evaluation batch 来检测一下 evaluation 的过程是否有错误。为 0 表示不检测;\n", - " :param resume_from: 从哪个路径下恢复 trainer 的状态,注意该值需要为一个文件夹,例如使用 ``CheckpointCallback`` 时帮助您创建的保存的子文件夹;\n", - " :param resume_training: 是否按照 checkpoint 中训练状态恢复。如果为 False,则只恢复 model 和 optimizers 的状态;该参数如果为 ``True``,\n", - " 在下一次断点重训的时候我们会精确到上次训练截止的具体的 sample 进行训练;否则我们只会恢复 model 和 optimizers 的状态,而 ``Trainer`` 中的\n", - " 其余状态都是保持初始化时的状态不会改变;\n", - " :param catch_KeyboardInterrupt: 是否捕获 KeyboardInterrupt;如果该参数为 ``True``,在训练时如果您使用 ``ctrl+c`` 来终止程序,\n", - " ``Trainer`` 不会抛出异常,但是会提前退出,然后 ``trainer.run()`` 之后的代码会继续运行。注意该参数在您使用分布式训练的 ``Driver``\n", - " 时无效,例如 ``TorchDDPDriver``;非分布式训练的 ``Driver`` 下该参数默认为 True;\n", - " \n", - " .. 
warning::\n", - " \n", - " 注意初始化的 ``Trainer`` 只能调用一次 ``run`` 函数,即之后的调用 ``run`` 函数实际不会运行,因为此时\n", - " ``trainer.cur_epoch_idx == trainer.n_epochs``;\n", - " \n", - " 这意味着如果您需要再次调用 ``run`` 函数,您需要重新再初始化一个 ``Trainer``;\n", - " \n", - " .. note::\n", - " \n", - " 您可以使用 ``num_train_batch_per_epoch`` 来简单地对您的训练过程进行验证,例如,当您指定 ``num_train_batch_per_epoch=10`` 后,\n", - " 每一个 epoch 下实际训练的 batch 的数量则会被修改为 10。您可以先使用该值来设定一个较小的训练长度,在验证整体的训练流程没有错误后,再将\n", - " 该值设定为 **-1** 开始真正的训练;\n", - " \n", - " ``num_eval_batch_per_dl`` 的意思和 ``num_train_batch_per_epoch`` 类似,即您可以通过设定 ``num_eval_batch_per_dl`` 来验证\n", - " 整体的验证流程是否正确;\n", - " \n", - " ``num_eval_sanity_batch`` 的作用可能会让人产生迷惑,其本质和 ``num_eval_batch_per_dl`` 作用一致,但是其只被 ``Trainer`` 使用;\n", - " 并且其只会在训练的一开始使用,意思为:我们在训练的开始时会先使用 ``Evaluator``(如果其不为 ``None``) 进行验证,此时验证的 batch 的\n", - " 数量只有 ``num_eval_sanity_batch`` 个;但是对于 ``num_eval_batch_per_dl`` 而言,其表示在实际的整体的训练过程中,每次 ``Evaluator``\n", - " 进行验证时会验证的 batch 的数量。\n", - " \n", - " 并且,在实际真正的训练中,``num_train_batch_per_epoch`` 和 ``num_eval_batch_per_dl`` 应当都被设置为 **-1**,但是 ``num_eval_sanity_batch``\n", - " 应当为一个很小的正整数,例如 2;\n", - " \n", - " .. note::\n", - " \n", - " 参数 ``resume_from`` 和 ``resume_training`` 的设立是为了支持断点重训功能;仅当 ``resume_from`` 不为 ``None`` 时,``resume_training`` 才有效;\n", - " \n", - " 断点重训的意思为将上一次训练过程中的 ``Trainer`` 的状态保存下来,包括模型和优化器的状态、当前训练过的 epoch 的数量、对于当前的 epoch\n", - " 已经训练过的 batch 的数量、callbacks 的状态等等;然后在下一次训练时直接加载这些状态,从而直接恢复到上一次训练过程的某一个具体时间点的状态开始训练;\n", - " \n", - " fastNLP 将断点重训分为了 **保存状态** 和 **恢复断点重训** 两部分:\n", - " \n", - " 1. 您需要使用 ``CheckpointCallback`` 来保存训练过程中的 ``Trainer`` 的状态;具体详见 :class:`~fastNLP.core.callbacks.CheckpointCallback`;\n", - " ``CheckpointCallback`` 会帮助您把 ``Trainer`` 的状态保存到一个具体的文件夹下,这个文件夹的名字由 ``CheckpointCallback`` 自己生成;\n", - " 2. 
在第二次训练开始时,您需要找到您想要加载的 ``Trainer`` 状态所存放的文件夹,然后传入给参数 ``resume_from``;\n", - " \n", - " 需要注意的是 **保存状态** 和 **恢复断点重训** 是互不影响的。\n", - "\n" - ] - } - ], - "source": [ - "help(trainer.run)" - ] - }, { "cell_type": "code", "execution_count": null, diff --git a/tutorials/fastnlp_tutorial_4.ipynb b/tutorials/fastnlp_tutorial_4.ipynb index ee5a0c6b..10098891 100644 --- a/tutorials/fastnlp_tutorial_4.ipynb +++ b/tutorials/fastnlp_tutorial_4.ipynb @@ -5,292 +5,1931 @@ "id": "fdd7ff16", "metadata": {}, "source": [ - "# T4. trainer 和 evaluator 的深入介绍\n", + "# T4. fastNLP 中的预定义模型\n", "\n", - " 1 fastNLP 中的更多 metric 类型\n", - "\n", - " 1.1 预定义的 metric 类型\n", + " 1 fastNLP 中 modules 的介绍\n", + " \n", + " 1.1 modules 模块、models 模块 简介\n", "\n", - " 1.2 自定义的 metric 类型\n", + " 1.2 示例一:modules 实现 LSTM 分类\n", "\n", - " 2 fastNLP 中 trainer 的补充介绍\n", + " 2 fastNLP 中 models 的介绍\n", " \n", - " 2.1 trainer 的提出构想 \n", + " 2.1 示例一:models 实现 CNN 分类\n", "\n", - " 2.2 trainer 的内部结构\n", + " 2.3 示例二:models 实现 BiLSTM 标注" + ] + }, + { + "cell_type": "markdown", + "id": "d3d65d53", + "metadata": {}, + "source": [ + "## 1. fastNLP 中 modules 模块的介绍\n", "\n", - " 2.3 实例:\n", + "### 1.1 modules 模块、models 模块 简介\n", "\n", - " 3 fastNLP 中的 driver 与 device\n", + "在`fastNLP 0.8`中,**`modules.torch`路径下定义了一些基于`pytorch`实现的基础模块**\n", "\n", - " 3.1 driver 的提出构想\n", + " 包括长短期记忆网络`LSTM`、条件随机场`CRF`、`transformer`的编解码器模块等,详见下表\n", "\n", - " 3.2 device 与多卡训练" + "|
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/6000 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" } - }, + ], + "source": [ + "import sys\n", + "sys.path.append('..')\n", + "\n", + "from fastNLP import DataSet\n", + "\n", + "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n", + "\n", + "dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'target': ins['label']}, \n", + " progress_bar=\"tqdm\")\n", + "dataset.delete_field('sentence')\n", + "dataset.delete_field('label')\n", + "dataset.delete_field('idx')\n", + "\n", + "from fastNLP import Vocabulary\n", + "\n", + "vocab = Vocabulary()\n", + "vocab.from_dataset(dataset, field_name='words')\n", + "vocab.index_dataset(dataset, field_name='words')\n", + "\n", + "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)" + ] + }, + { + "cell_type": "markdown", + "id": "96380c67", + "metadata": {}, + "source": [ + "然后,使用`tutorial-3`中的知识,**通过`prepare_torch_dataloader`处理数据集得到`dataloader`**" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "b9dd1273", + "metadata": {}, "outputs": [], - "source": [] + "source": [ + "from fastNLP import prepare_torch_dataloader\n", + "\n", + "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n", + "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)" + ] }, { "cell_type": "markdown", - "id": "ce6322b4", + "id": "96941b63", "metadata": {}, "source": [ - "### 2.3 实例:\n", + "接着,**从`fastNLP.models.torch`路径下导入`CNNText`**,初始化`CNNText`实例以及`optimizer`实例\n", + "\n", + " 注意:初始化`CNNText`时,**二元组参数`embed`、分类数量`num_classes`是必须传入的**,其中\n", "\n", - "在`fastNLP 0.8`中, " + " **`embed`表示嵌入层的嵌入抽取矩阵大小**,因此第二个元素对应的是默认隐藏层维度 `100`维" ] }, 
{ "cell_type": "code", - "execution_count": null, - "id": "43be274f", - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "execution_count": 11, + "id": "f6e76e2e", + "metadata": {}, "outputs": [], - "source": [] + "source": [ + "from fastNLP.models.torch import CNNText\n", + "\n", + "model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n", + "\n", + "from torch.optim import AdamW\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=5e-4)" + ] + }, + { + "cell_type": "markdown", + "id": "0cc5ca10", + "metadata": {}, + "source": [ + "最后,使用`trainer`模块,集成`model`、`optimizer`、`dataloader`、`metric`训练" + ] }, { "cell_type": "code", - "execution_count": null, - "id": "c348864c", - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "execution_count": 12, + "id": "50a13ee5", + "metadata": {}, "outputs": [], - "source": [] + "source": [ + "from fastNLP import Trainer, Accuracy\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver='torch',\n", + " device=0, # 'cuda'\n", + " n_epochs=10,\n", + " optimizers=optimizers,\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=evaluate_dataloader,\n", + " metrics={'acc': Accuracy()}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "28903a7d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[17:45:59] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[17:45:59]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=147745;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=708408;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.575,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 92.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.575\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m92.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.75625,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 121.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.75625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m121.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.78125,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 125.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.78125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m125.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.8,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 128.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.79375,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 127.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.79375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m127.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.80625,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 129.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m129.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.81875,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 131.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.825,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 132.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.825\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m132.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.81875,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 131.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.81875,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 131.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "f47a6a35", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'acc#acc': 0.79, 'total#acc': 900.0, 'correct#acc': 711.0}" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.evaluator.run()" + ] + }, + { + "cell_type": "markdown", + "id": "7c811257", + "metadata": {}, + "source": [ + " 注:此处使用`gc`模块删除相关变量,释放内存,为接下来新的模型训练预留存储空间" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "c1a2e2ca", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "342" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import gc\n", + "\n", + "del model\n", + "del trainer\n", + "del dataset\n", + "del sst2data\n", + "\n", + "gc.collect()" + ] }, { "cell_type": "markdown", - "id": "175d6ebb", + "id": "6aec2a19", "metadata": {}, "source": [ - "## 3. 
fastNLP 中的 driver 与 device\n", + "### 2.2 示例二:models 实现 BiLSTM 标注\n", + "\n", + " 通过两个示例一的对比可以发现,得益于`models`对模型结构的封装,使用`models`明显更加便捷\n", + "\n", + " 针对更加复杂的模型时,编码更加轻松;本示例将使用`models`中的`BiLSTMCRF`模型\n", + "\n", + " 避免`CRF`和`Viterbi`算法代码书写的困难,轻松实现`CoNLL-2003`中的命名实体识别`NER`任务\n", "\n", - "### 3.1 driver 的提出构想\n", + "模型使用方面,如上所述,这里使用**基于双向`LSTM`+条件随机场`CRF`的标注模型`BiLSTMCRF`**,结构如下所示\n", "\n", - "在`fastNLP 0.8`中, " + " 其中,隐藏层维度默认`100`维,因此对应双向`LSTM`输出`200`维,`dropout`层退学概率、`LSTM`层数可调\n", + "\n", + "```\n", + "BiLSTMCRF(\n", + " (embed): Embedding(7590, 100)\n", + " (lstm): LSTM(\n", + " (lstm): LSTM(100, 100, batch_first=True, bidirectional=True)\n", + " )\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (fc): Linear(in_features=200, out_features=9, bias=True)\n", + " (crf): ConditionalRandomField()\n", + ")\n", + "```\n", + "\n", + "数据使用方面,此处仍然**使用`datasets`模块中的`load_dataset`函数**,以如下形式,加载`CoNLL-2003`数据集\n", + "\n", + " 首次下载后会保存至`~.cache/huggingface/datasets/conll2003/conll2003/1.0.0/`目录下" ] }, { "cell_type": "code", - "execution_count": null, - "id": "47100e7a", - "metadata": { - "pycharm": { - "name": "#%%\n" + "execution_count": 16, + "id": "03e66686", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset conll2003 (/remote-home/xrliu/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/63f4ebd1bcb7148b1644497336fd74643d4ce70123334431a3c053b7ee4e96ee)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3ec9e0ce9a054339a2453420c2c9f28b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" } - }, - "outputs": [], - "source": [] + ], + "source": [ + "from datasets import load_dataset\n", + "\n", + "ner2data = load_dataset('conll2003', 'conll2003')" + ] + }, + { + "cell_type": "markdown", + "id": "fc505631", + "metadata": {}, + "source": [ + 
"紧接着,使用`tutorial-1`和`tutorial-2`中的知识,将数据集转化为`fastNLP`中的`DataSet`格式\n", + "\n", + " 完成数据集格式调整、文本序列化等操作;此处**需要`'words'`、`'seq_len'`、`'target'`三个字段**\n", + "\n", + "此外,**需要定义`NER`标签到标签序号的映射**(**词汇表`label_vocab`**),数据集中标签已经完成了序号映射\n", + "\n", + " 所以需要人工定义**`9`个标签对应之前的`9`个分类目标**;数据集说明中规定,`'O'`表示其他标签\n", + "\n", + " **后缀`'-PER'`、`'-ORG'`、`'-LOC'`、`'-MISC'`对应人名、组织名、地名、时间等其他命名**\n", + "\n", + " **前缀`'B-'`表示起始标签、`'I-'`表示终止标签**;例如,`'B-PER'`表示人名实体的起始标签" + ] }, { "cell_type": "code", - "execution_count": null, - "id": "0204a223", - "metadata": { - "pycharm": { - "name": "#%%\n" + "execution_count": 17, + "id": "1f88cad4", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/4000 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" } - }, + ], + "source": [ + "import sys\n", + "sys.path.append('..')\n", + "\n", + "from fastNLP import DataSet\n", + "\n", + "dataset = DataSet.from_pandas(ner2data['train'].to_pandas())[:4000]\n", + "\n", + "dataset.apply_more(lambda ins:{'words': ins['tokens'], 'seq_len': len(ins['tokens']), 'target': ins['ner_tags']}, \n", + " progress_bar=\"tqdm\")\n", + "dataset.delete_field('tokens')\n", + "dataset.delete_field('ner_tags')\n", + "dataset.delete_field('pos_tags')\n", + "dataset.delete_field('chunk_tags')\n", + "dataset.delete_field('id')\n", + "\n", + "from fastNLP import Vocabulary\n", + "\n", + "token_vocab = Vocabulary()\n", + "token_vocab.from_dataset(dataset, field_name='words')\n", + "token_vocab.index_dataset(dataset, field_name='words')\n", + "label_vocab = Vocabulary(padding=None, unknown=None)\n", + "label_vocab.add_word_lst(['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'])\n", + "\n", + "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)" + ] + }, + { + "cell_type": "markdown", + "id": "d9889427", + "metadata": {}, + 
"source": [ + "然后,同样使用`tutorial-3`中的知识,通过`prepare_torch_dataloader`处理数据集得到`dataloader`" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "7802a072", + "metadata": {}, "outputs": [], - "source": [] + "source": [ + "from fastNLP import prepare_torch_dataloader\n", + "\n", + "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n", + "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)" + ] }, { "cell_type": "markdown", - "id": "6e723b87", + "id": "2bc7831b", "metadata": {}, "source": [ - "### 3.2 device 与多卡训练\n", + "接着,**从`fastNLP.models.torch`路径下导入`BiLSTMCRF`**,初始化`BiLSTMCRF`实例和优化器\n", + "\n", + " 注意:初始化`BiLSTMCRF`时,和`CNNText`相同,**参数`embed`、`num_classes`是必须传入的**\n", "\n", - "在`fastNLP 0.8`中, " + " 隐藏层维度`hidden_size`默认`100`维,调整`150`维;退学概率默认`0.1`,调整`0.2`" ] }, { "cell_type": "code", - "execution_count": null, - "id": "5ad81ac7", - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "execution_count": 19, + "id": "4e12c09f", + "metadata": {}, "outputs": [], - "source": [] + "source": [ + "from fastNLP.models.torch import BiLSTMCRF\n", + "\n", + "model = BiLSTMCRF(embed=(len(token_vocab), 150), num_classes=len(label_vocab), \n", + " num_layers=1, hidden_size=150, dropout=0.2)\n", + "\n", + "from torch.optim import AdamW\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=1e-3)" + ] + }, + { + "cell_type": "markdown", + "id": "bf30608f", + "metadata": {}, + "source": [ + "最后,使用`trainer`模块,集成`model`、`optimizer`、`dataloader`、`metric`训练\n", + "\n", + " **使用`SpanFPreRecMetric`作为`NER`的评价标准**,详细请参考接下来的`tutorial-5`\n", + "\n", + " 同时,**初始化时需要添加`vocabulary`形式的标签与序号之间的映射`tag_vocab`**" + ] }, { "cell_type": "code", - "execution_count": null, - "id": "cfb28b1b", - "metadata": { - "pycharm": { - "name": "#%%\n" + "execution_count": 20, + "id": "cbd6c205", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import Trainer, SpanFPreRecMetric\n", + "\n", + "trainer = 
Trainer(\n", + " model=model,\n", + " driver='torch',\n", + " device=0, # 'cuda'\n", + " n_epochs=10,\n", + " optimizers=optimizers,\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=evaluate_dataloader,\n", + " metrics={'F1': SpanFPreRecMetric(tag_vocab=label_vocab)}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "0f8eff34", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[17:49:16] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[17:49:16]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=766109;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=787419;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.220374,\n", + " \"pre#F1\": 0.25,\n", + " \"rec#F1\": 0.197026\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.220374\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.25\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.197026\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.442857,\n", + " \"pre#F1\": 0.426117,\n", + " \"rec#F1\": 0.460967\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.442857\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.426117\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.460967\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.572954,\n", + " \"pre#F1\": 0.549488,\n", + " \"rec#F1\": 0.598513\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.572954\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.549488\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.598513\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.665399,\n", + " \"pre#F1\": 0.680934,\n", + " \"rec#F1\": 0.650558\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.665399\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.680934\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.650558\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.734694,\n", + " \"pre#F1\": 0.733333,\n", + " \"rec#F1\": 0.736059\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.734694\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.733333\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.736059\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.742647,\n", + " \"pre#F1\": 0.734545,\n", + " \"rec#F1\": 0.750929\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.742647\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.734545\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.750929\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.773585,\n", + " \"pre#F1\": 0.785441,\n", + " \"rec#F1\": 0.762082\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.773585\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.785441\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.762082\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.770115,\n", + " \"pre#F1\": 0.794466,\n", + " \"rec#F1\": 0.747212\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.770115\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.794466\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.747212\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.7603,\n", + " \"pre#F1\": 0.766038,\n", + " \"rec#F1\": 0.754647\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.7603\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.766038\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.754647\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.743682,\n", + " \"pre#F1\": 0.722807,\n", + " \"rec#F1\": 0.765799\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.743682\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.722807\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.765799\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } - }, + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "37871d6b", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'f#F1': 0.75283, 'pre#F1': 0.727438, 'rec#F1': 0.780059}" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.evaluator.run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "96bae094", + "metadata": {}, "outputs": [], "source": [] } @@ -312,15 +1951,6 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" - }, - "pycharm": { - "stem_cell": { - "cell_type": "raw", - "metadata": { - "collapsed": false - }, - "source": [] - } } }, "nbformat": 4, diff --git a/tutorials/fastnlp_tutorial_5.ipynb b/tutorials/fastnlp_tutorial_5.ipynb index cb105c89..0669a60a 100644 --- a/tutorials/fastnlp_tutorial_5.ipynb +++ b/tutorials/fastnlp_tutorial_5.ipynb @@ -5,1313 +5,448 @@ "id": "fdd7ff16", "metadata": {}, "source": [ - "# T5. fastNLP 中的预定义模型\n", + "# T5. 
trainer 和 evaluator 的深入介绍\n", "\n", - " 1 fastNLP 中 modules 的介绍\n", + " 1 fastNLP 中 driver 的补充介绍\n", " \n", - " 1.1 modules 模块、models 模块 简介\n", + " 1.1 trainer 和 driver 的构想 \n", "\n", - " 1.2 示例一:modules 实现 LSTM 分类\n", + " 1.2 device 与 多卡训练\n", "\n", - " 2 fastNLP 中 models 的介绍\n", - " \n", - " 2.1 示例一:models 实现 CNN 分类\n", + " 2 fastNLP 中的更多 metric 类型\n", + "\n", + " 2.1 预定义的 metric 类型\n", + "\n", + " 2.2 自定义的 metric 类型\n", "\n", - " 2.3 示例二:models 实现 BiLSTM 标注" + " 3 fastNLP 中 trainer 的补充介绍\n", + "\n", + " 3.1 trainer 的内部结构" ] }, { "cell_type": "markdown", - "id": "d3d65d53", - "metadata": {}, + "id": "08752c5a", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ - "## 1. fastNLP 中 modules 模块的介绍\n", + "## 1. fastNLP 中 driver 的补充介绍\n", "\n", - "### 1.1 modules 模块、models 模块 简介\n", + "### 1.1 trainer 和 driver 的构想\n", "\n", - "在`fastNLP 0.8`中,**`modules.torch`路径下定义了一些基于`pytorch`实现的基础模块**\n", + "在`fastNLP 0.8`中,模型训练最关键的模块便是**训练模块`trainer`、评测模块`evaluator`、驱动模块`driver`**,\n", "\n", - " 包括长短期记忆网络`LSTM`、条件随机场`CRF`、`transformer`的编解码器模块等,详见下表\n", + " 在`tutorial 0`中,已经简单介绍过上述三个模块:**`driver`用来控制训练评测中的`model`的最终运行**\n", "\n", - "|
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Processing: 0%| | 0/6000 [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "import sys\n", - "sys.path.append('..')\n", + " `get_metric`函数打印格式为 **`{\"f#xx\": float, 'pre#xx': float, 'rec#xx': float}`**\n", "\n", - "from fastNLP import DataSet\n", + " 三者的计算公式如下,其中`beta`默认为`1`,即`F1`值是召回率`Rec`和精确率`Pre`的调和平均数\n", "\n", - "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n", + "$$\\text{召回率}\\ Rec=\\dfrac{\\text{正确预测为正例的数量}}{\\text{所有本来是正例的数量}}\\qquad \\text{精确率}\\ Pre=\\dfrac{\\text{正确预测为正例的数量}}{\\text{所有预测为正例的数量}}$$\n", "\n", - "dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'target': ins['label']}, \n", - " progress_bar=\"tqdm\")\n", - "dataset.delete_field('sentence')\n", - "dataset.delete_field('label')\n", - "dataset.delete_field('idx')\n", + "$$F_{beta} = \\frac{(1 + {beta}^{2})*(Pre*Rec)}{({beta}^{2}*Pre + Rec)}$$\n", "\n", - "from fastNLP import Vocabulary\n", + " **第二**,可以通过参数`only_gross`为`False`,要求返回所有类别的`Rec-Pre-F1`,同时`F1`值又根据参数`f_type`又分为\n", "\n", - "vocab = Vocabulary()\n", - "vocab.from_dataset(dataset, field_name='words')\n", - "vocab.index_dataset(dataset, field_name='words')\n", + " **`micro F1`**(**直接统计所有类别的`Rec-Pre-F1`**)、**`macro F1`**(**统计各类别的`Rec-Pre-F1`再算术平均**)\n", "\n", - "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)" - ] - }, - { - "cell_type": "markdown", - "id": "96380c67", - "metadata": {}, - "source": [ - "然后,使用`tutorial-3`中的知识,**通过`prepare_torch_dataloader`处理数据集得到`dataloader`**" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "b9dd1273", - "metadata": {}, - "outputs": [], - "source": [ - "from fastNLP import prepare_torch_dataloader\n", + " 
**第三**,两者在初始化时还可以**传入基于`fastNLP.Vocabulary`的`tag_vocab`参数记录数据集中的标签序号**\n", "\n", - "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n", - "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)" - ] - }, - { - "cell_type": "markdown", - "id": "96941b63", - "metadata": {}, - "source": [ - "接着,**从`fastNLP.models.torch`路径下导入`CNNText`**,初始化`CNNText`实例以及`optimizer`实例\n", + " **与标签名称之间的映射**,通过字符串列表`ignore_labels`参数,指定若干标签不用于`Rec-Pre-F1`的计算\n", "\n", - " 注意:初始化`CNNText`时,**二元组参数`embed`、分类数量`num_classes`是必须传入的**,其中\n", + "两者的不同之处在于:`ClassifyFPreRecMetric`针对简单的分类问题,每个分类标签之间彼此独立,不构成标签对\n", "\n", - " **`embed`表示嵌入层的嵌入抽取矩阵大小**,因此第二个元素对应的是默认隐藏层维度 `100`维" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "f6e76e2e", - "metadata": {}, - "outputs": [], - "source": [ - "from fastNLP.models.torch import CNNText\n", + " **`SpanFPreRecMetric`针对更复杂的抽取问题**,**规定标签`B-xx`和`I-xx`或`B-xx`和`E-xx`构成标签对**\n", "\n", - "model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n", + " 在计算`Rec-Pre-F1`时,`ClassifyFPreRecMetric`只需要考虑标签本身是否正确这就足够了,但是\n", "\n", - "from torch.optim import AdamW\n", + " 对于`SpanFPreRecMetric`,需要保证**标签符合规则且覆盖的区间与正确结果重合才算正确**\n", "\n", - "optimizers = AdamW(params=model.parameters(), lr=5e-4)" - ] - }, - { - "cell_type": "markdown", - "id": "0cc5ca10", - "metadata": {}, - "source": [ - "最后,使用`trainer`模块,集成`model`、`optimizer`、`dataloader`、`metric`训练" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "50a13ee5", - "metadata": {}, - "outputs": [], - "source": [ - "from fastNLP import Trainer, Accuracy\n", + " 因此回到`tutorial-4`中`CoNLL-2003`的`NER`任务,如果评测方法选择`ClassifyFPreRecMetric`\n", "\n", - "trainer = Trainer(\n", - " model=model,\n", - " driver='torch',\n", - " device=0, # 'cuda'\n", - " n_epochs=10,\n", - " optimizers=optimizers,\n", - " train_dataloader=train_dataloader,\n", - " evaluate_dataloaders=evaluate_dataloader,\n", - " metrics={'acc': Accuracy()}\n", - 
")" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "28903a7d", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
[17:45:59] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", - "\n" - ], - "text/plain": [ - "\u001b[2;36m[17:45:59]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=147745;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=708408;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.575,\n", - " \"total#acc\": 160.0,\n", - " \"correct#acc\": 92.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.575\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m92.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.75625,\n", - " \"total#acc\": 160.0,\n", - " \"correct#acc\": 121.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.75625\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m121.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.78125,\n", - " \"total#acc\": 160.0,\n", - " \"correct#acc\": 125.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.78125\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m125.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.8,\n", - " \"total#acc\": 160.0,\n", - " \"correct#acc\": 128.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.79375,\n", - " \"total#acc\": 160.0,\n", - " \"correct#acc\": 127.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.79375\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m127.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.80625,\n", - " \"total#acc\": 160.0,\n", - " \"correct#acc\": 129.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m129.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.81875,\n", - " \"total#acc\": 160.0,\n", - " \"correct#acc\": 131.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.825,\n", - " \"total#acc\": 160.0,\n", - " \"correct#acc\": 132.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.825\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m132.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.81875,\n", - " \"total#acc\": 160.0,\n", - " \"correct#acc\": 131.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.81875,\n", - " \"total#acc\": 160.0,\n", - " \"correct#acc\": 131.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "trainer.run(num_eval_batch_per_dl=10)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "f47a6a35", - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "{'acc#acc': 0.79, 'total#acc': 900.0, 'correct#acc': 711.0}" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "trainer.evaluator.run()" - ] - }, - { - "cell_type": "markdown", - "id": "7c811257", - "metadata": {}, - "source": [ - " 注:此处使用`gc`模块删除相关变量,释放内存,为接下来新的模型训练预留存储空间" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "c1a2e2ca", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "342" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import gc\n", + " 或者`Accuracy`,会发现虽然评测结果显示很高,这是因为选择的评测方法要求太低\n", + "\n", + " 最后通过`CoNLL-2003`的词性标注`POS`任务简单演示下`ClassifyFPreRecMetric`相关的使用\n", + "\n", + "```python\n", + "from fastNLP import Vocabulary\n", + "from fastNLP import ClassifyFPreRecMetric\n", "\n", - "del model\n", - "del trainer\n", - "del dataset\n", - "del sst2data\n", + "tag_vocab = Vocabulary(padding=None, unknown=None) # 记录序号与标签之间的映射\n", + "tag_vocab.add_word_lst(['\"', \"''\", '#', '$', '(', ')', ',', '.', ':', '``', \n", + " 'CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', \n", + " 'MD', 'NN', 'NNP', 'NNPS', 'NNS', 'NN|SYM', 'PDT', 'POS', 'PRP', 'PRP$', \n", + " 'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', \n", + " 
'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP+', 'WRB', ]) # CoNLL-2003 中的 pos_tags\n", + "ignore_labels = ['\"', \"''\", '#', '$', '(', ')', ',', '.', ':', '``', ]\n", "\n", - "gc.collect()" + "FPreRec = ClassifyFPreRecMetric(tag_vocab=tag_vocab, \n", + " ignore_labels=ignore_labels, # 表示评测/优化中不考虑上述标签的正误/损失\n", + " only_gross=True, # 默认为 True 表示输出所有类别的综合统计结果\n", + " f_type='micro') # 默认为 'micro' 表示统计所有类别的 Rec-Pre-F1\n", + "metrics = {'F1': FPreRec}\n", + "```" ] }, { "cell_type": "markdown", - "id": "6aec2a19", + "id": "8a22f522", "metadata": {}, "source": [ - "### 2.2 示例二:models 实现 BiLSTM 标注\n", + "### 2.2 自定义的 metric 类型\n", + "\n", + "如上文所述,`Metric`作为所有`metric`的基类,`Accuracy`等都是其子类,同样地,对于**自定义的`metric`类型**\n", + "\n", + " 也**需要继承自`Metric`类**,同时**内部自定义好`__init__`、`update`和`get_metric`函数**\n", "\n", - " 通过两个示例一的对比可以发现,得益于`models`对模型结构的封装,使用`models`明显更加便捷\n", + " 在`__init__`函数中,根据需求定义评测时需要用到的变量,此处沿用`Accuracy`中的`total_num`和`right_num`\n", "\n", - " 针对更加复杂的模型时,编码更加轻松;本示例将使用`models`中的`BiLSTMCRF`模型\n", + " 在`update`函数中,根据需求定义评测变量的更新方式,需要注意的是如`tutorial-0`中所述,**`update`的参数名**\n", "\n", - " 避免`CRF`和`Viterbi`算法代码书写的困难,轻松实现`CoNLL-2003`中的命名实体识别`NER`任务\n", + " **需要待评估模型在`evaluate_step`中的输出名称一致**,由此**和数据集中对应字段名称一致**,即**参数匹配**\n", "\n", - "模型使用方面,如上所述,这里使用**基于双向`LSTM`+条件随机场`CRF`的标注模型`BiLSTMCRF`**,结构如下所示\n", + " 在`fastNLP v0.8`中,`update`函数的默认输入参数:`pred`,对应预测值;`target`,对应真实值\n", "\n", - " 其中,隐藏层维度默认`100`维,因此对应双向`LSTM`输出`200`维,`dropout`层退学概率、`LSTM`层数可调\n", + " 此处刻意调整为:`pred`,对应预测值,和模型输出一致;`true`,对应真实值,数据集字段需要调整\n", "\n", - "```\n", - "BiLSTMCRF(\n", - " (embed): Embedding(7590, 100)\n", - " (lstm): LSTM(\n", - " (lstm): LSTM(100, 100, batch_first=True, bidirectional=True)\n", - " )\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (fc): Linear(in_features=200, out_features=9, bias=True)\n", - " (crf): ConditionalRandomField()\n", - ")\n", - "```\n", + " 在`get_metric`函数中,根据需求定义评测指标最终的计算,此处直接计算准确率,该函数必须返回一个字典\n", "\n", - 
"数据使用方面，此处仍然**使用`datasets`模块中的`load_dataset`函数**，以如下形式，加载`CoNLL-2003`数据集\n", + "  其中，字串`'prefix'`表示该`metric`的名称，会对应显示到`trainer`的`progress bar`中\n", "\n", - "  首次下载后会保存至`~.cache/huggingface/datasets/conll2003/conll2003/1.0.0/`目录下" + "根据上述要求，这里简单定义了一个名为`MyMetric`的评测模块，用于分类问题的评测，以此展开一个实例展示" ] }, { "cell_type": "code", - "execution_count": 16, - "id": "03e66686", + "execution_count": null, + "id": "08a872e9", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Reusing dataset conll2003 (/remote-home/xrliu/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/63f4ebd1bcb7148b1644497336fd74643d4ce70123334431a3c053b7ee4e96ee)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "3ec9e0ce9a054339a2453420c2c9f28b", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/3 [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ - "from datasets import load_dataset\n", + "from fastNLP import Metric\n", + "\n", + "class MyMetric(Metric):\n", + "\n", + "    def __init__(self):\n", + "        super().__init__()\n", + "        self.total_num = 0\n", + "        self.right_num = 0\n", + "\n", + "    def update(self, pred, true):\n", + "        self.total_num += true.size(0)\n", + "        self.right_num += true.eq(pred).sum().item()\n", "\n", - "ner2data = load_dataset('conll2003', 'conll2003')" + "    def get_metric(self, reset=True):\n", + "        acc = self.right_num / self.total_num\n", + "        if reset:\n", + "            self.total_num = 0\n", + "            self.right_num = 0\n", + "        return {'prefix': acc}" ] }, { "cell_type": "markdown", - "id": "fc505631", + "id": "af3f8c63", "metadata": {}, "source": [ - "紧接着，使用`tutorial-1`和`tutorial-2`中的知识，将数据集转化为`fastNLP`中的`DataSet`格式\n", "\n", - "  完成数据集格式调整、文本序列化等操作；此处**需要`'words'`、`'seq_len'`、`'target'`三个字段**\n", "\n", - "此外，**需要定义`NER`标签到标签序号的映射**（**词汇表`label_vocab`**），数据集中标签已经完成了序号映射\n", "\n", - 
所以需要人工定义**`9`个标签对应之前的`9`个分类目标**;数据集说明中规定,`'O'`表示其他标签\n", - "\n", - " **后缀`'-PER'`、`'-ORG'`、`'-LOC'`、`'-MISC'`对应人名、组织名、地名、时间等其他命名**\n", - "\n", - " **前缀`'B-'`表示起始标签、`'I-'`表示终止标签**;例如,`'B-PER'`表示人名实体的起始标签" + " 模型使用方面,此处仍然使用`tutorial-4`中介绍过的预定义`CNNText`模型,实现`SST-2`二分类" ] }, { "cell_type": "code", - "execution_count": 17, - "id": "1f88cad4", + "execution_count": null, + "id": "2fd210c5", "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Processing: 0%| | 0/4000 [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "import sys\n", "sys.path.append('..')\n", "\n", - "from fastNLP import DataSet\n", - "\n", - "dataset = DataSet.from_pandas(ner2data['train'].to_pandas())[:4000]\n", - "\n", - "dataset.apply_more(lambda ins:{'words': ins['tokens'], 'seq_len': len(ins['tokens']), 'target': ins['ner_tags']}, \n", - " progress_bar=\"tqdm\")\n", - "dataset.delete_field('tokens')\n", - "dataset.delete_field('ner_tags')\n", - "dataset.delete_field('pos_tags')\n", - "dataset.delete_field('chunk_tags')\n", - "dataset.delete_field('id')\n", + "from fastNLP.models.torch import CNNText\n", "\n", - "from fastNLP import Vocabulary\n", + "model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n", "\n", - "token_vocab = Vocabulary()\n", - "token_vocab.from_dataset(dataset, field_name='words')\n", - "token_vocab.index_dataset(dataset, field_name='words')\n", - "label_vocab = Vocabulary(padding=None, unknown=None)\n", - "label_vocab.add_word_lst(['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'])\n", + "from torch.optim import AdamW\n", "\n", - "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)" + "optimizers = AdamW(params=model.parameters(), lr=5e-4)" ] }, { "cell_type": "markdown", - "id": "d9889427", + "id": "0155f447", "metadata": {}, "source": [ 
- "然后,同样使用`tutorial-3`中的知识,通过`prepare_torch_dataloader`处理数据集得到`dataloader`" + " 数据使用方面,此处仍然使用`datasets`模块中的`load_dataset`函数,加载`SST-2`二分类数据集" ] }, { "cell_type": "code", - "execution_count": 18, - "id": "7802a072", - "metadata": {}, + "execution_count": null, + "id": "5ad81ac7", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ - "from fastNLP import prepare_torch_dataloader\n", + "from datasets import load_dataset\n", "\n", - "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n", - "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)" + "sst2data = load_dataset('glue', 'sst2')" ] }, { "cell_type": "markdown", - "id": "2bc7831b", + "id": "e9d81760", "metadata": {}, "source": [ - "接着,**从`fastNLP.models.torch`路径下导入`BiLSTMCRF`**,初始化`BiLSTMCRF`实例和优化器\n", - "\n", - " 注意:初始化`BiLSTMCRF`时,和`CNNText`相同,**参数`embed`、`num_classes`是必须传入的**\n", + "接着是数据预处理,需要注意的是,由于`MyMetric`的`update`函数中,输入参数名称为`pred`和`true`\n", "\n", - " 隐藏层维度`hidden_size`默认`100`维,调整`150`维;退学概率默认`0.1`,调整`0.2`" + " 对应地,需要将数据集中表示预测目标的字段,调整为`true`(预定义的`metric`,应调整为`target`" ] }, { "cell_type": "code", - "execution_count": 19, - "id": "4e12c09f", - "metadata": {}, + "execution_count": null, + "id": "cfb28b1b", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ - "from fastNLP.models.torch import BiLSTMCRF\n", + "from fastNLP import DataSet\n", + "\n", + "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n", "\n", - "model = BiLSTMCRF(embed=(len(token_vocab), 150), num_classes=len(label_vocab), \n", - " num_layers=1, hidden_size=150, dropout=0.2)\n", + "dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'true': ins['label']}, \n", + " progress_bar=\"tqdm\")\n", + "dataset.delete_field('sentence')\n", + "dataset.delete_field('label')\n", + "dataset.delete_field('idx')\n", "\n", - "from torch.optim import AdamW\n", + "from fastNLP import 
Vocabulary\n", + "\n", + "vocab = Vocabulary()\n", + "vocab.from_dataset(dataset, field_name='words')\n", + "vocab.index_dataset(dataset, field_name='words')\n", "\n", - "optimizers = AdamW(params=model.parameters(), lr=1e-3)" + "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)\n", + "\n", + "from fastNLP import prepare_torch_dataloader\n", + "\n", + "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n", + "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)" ] }, { "cell_type": "markdown", - "id": "bf30608f", + "id": "1e21df35", "metadata": {}, "source": [ - "最后,使用`trainer`模块,集成`model`、`optimizer`、`dataloader`、`metric`训练\n", - "\n", - " 参考`tutorial-4`中的内容,**使用`SpanFPreRecMetric`作为`NER`的评价标准**\n", + "然后就是初始化`trainer`实例,其中`metrics`变量输入的键值对,字串`'suffix'`和之前定义的字串`'prefix'`\n", "\n", - " 同时,**初始化时需要添加`vocabulary`形式的标签与序号之间的映射`tag_vocab`**" + " 将拼接在一起显示到`trainer`的`progress bar`中,故完整的输出形式为`{'prefix#suffix': float}`" ] }, { "cell_type": "code", - "execution_count": 20, - "id": "cbd6c205", + "execution_count": null, + "id": "926a9c50", "metadata": {}, "outputs": [], "source": [ - "from fastNLP import Trainer, SpanFPreRecMetric\n", + "from fastNLP import Trainer\n", "\n", "trainer = Trainer(\n", " model=model,\n", @@ -1321,615 +456,67 @@ " optimizers=optimizers,\n", " train_dataloader=train_dataloader,\n", " evaluate_dataloaders=evaluate_dataloader,\n", - " metrics={'F1': SpanFPreRecMetric(tag_vocab=label_vocab)}\n", + " metrics={'suffix': MyMetric()}\n", ")" ] }, { - "cell_type": "code", - "execution_count": 21, - "id": "0f8eff34", + "cell_type": "markdown", + "id": "6e723b87", "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
[17:49:16] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", - "\n" - ], - "text/plain": [ - "\u001b[2;36m[17:49:16]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=766109;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=787419;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"f#F1\": 0.220374,\n", - " \"pre#F1\": 0.25,\n", - " \"rec#F1\": 0.197026\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.220374\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.25\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.197026\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"f#F1\": 0.442857,\n", - " \"pre#F1\": 0.426117,\n", - " \"rec#F1\": 0.460967\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.442857\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.426117\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.460967\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"f#F1\": 0.572954,\n", - " \"pre#F1\": 0.549488,\n", - " \"rec#F1\": 0.598513\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.572954\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.549488\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.598513\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"f#F1\": 0.665399,\n", - " \"pre#F1\": 0.680934,\n", - " \"rec#F1\": 0.650558\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.665399\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.680934\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.650558\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"f#F1\": 0.734694,\n", - " \"pre#F1\": 0.733333,\n", - " \"rec#F1\": 0.736059\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.734694\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.733333\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.736059\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"f#F1\": 0.742647,\n", - " \"pre#F1\": 0.734545,\n", - " \"rec#F1\": 0.750929\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.742647\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.734545\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.750929\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"f#F1\": 0.773585,\n", - " \"pre#F1\": 0.785441,\n", - " \"rec#F1\": 0.762082\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.773585\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.785441\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.762082\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"f#F1\": 0.770115,\n", - " \"pre#F1\": 0.794466,\n", - " \"rec#F1\": 0.747212\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.770115\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.794466\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.747212\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"f#F1\": 0.7603,\n", - " \"pre#F1\": 0.766038,\n", - " \"rec#F1\": 0.754647\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.7603\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.766038\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.754647\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"f#F1\": 0.743682,\n", - " \"pre#F1\": 0.722807,\n", - " \"rec#F1\": 0.765799\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.743682\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.722807\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.765799\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], "source": [ - "trainer.run(num_eval_batch_per_dl=10)" + "## 3. fastNLP 中 trainer 的补充介绍\n", + "\n", + "### 3.1 trainer 的内部结构\n", + "\n", + "在`tutorial-0`中,我们已经介绍了`trainer`的基本使用,从`tutorial-1`到`tutorial-4`,我们也已经\n", + "\n", + " 展示了很多关于`trainer`的使用案例,以下我们先补充介绍训练模块`trainer`的一些内部结构\n", + "\n", + "\n", + "\n", + "'accumulation_steps', 'add_callback_fn', 'backward', 'batch_idx_in_epoch', 'batch_step_fn',\n", + "'callback_manager', 'check_batch_step_fn', 'cur_epoch_idx', 'data_device', 'dataloader',\n", + "'device', 'driver', 'driver_name', 'epoch_evaluate', 'evaluate_batch_step_fn', 'evaluate_dataloaders',\n", + "'evaluate_every', 'evaluate_fn', 'evaluator', 'extract_loss_from_outputs', 'fp16',\n", + "'get_no_sync_context', 'global_forward_batches', 'has_checked_train_batch_loop',\n", + "'input_mapping', 'kwargs', 'larger_better', 'load_checkpoint', 'load_model', 'marker',\n", + "'metrics', 'model', 'model_device', 'monitor', 'move_data_to_device', 'n_epochs', 'num_batches_per_epoch',\n", + "'on', 'on_after_backward', 'on_after_optimizers_step', 'on_after_trainer_initialized',\n", + "'on_after_zero_grad', 'on_before_backward', 'on_before_optimizers_step', 'on_before_zero_grad',\n", + "'on_evaluate_begin', 'on_evaluate_end', 'on_exception', 'on_fetch_data_begin', 'on_fetch_data_end',\n", + "'on_load_checkpoint', 'on_load_model', 'on_sanity_check_begin', 'on_sanity_check_end',\n", + "'on_save_checkpoint', 'on_save_model', 'on_train_batch_begin', 'on_train_batch_end',\n", + "'on_train_begin', 'on_train_end', 'on_train_epoch_begin', 'on_train_epoch_end',\n", + "'optimizers', 'output_mapping', 'progress_bar', 'run', 'run_evaluate',\n", + "'save_checkpoint', 'save_model', 'start_batch_idx_in_epoch', 'state',\n", + "'step', 'step_evaluate', 'total_batches', 'train_batch_loop', 'train_dataloader', 'train_fn', 'train_step',\n", + "'trainer_state', 'zero_grad'\n", + "\n", + 
" run(num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, catch_KeyboardInterrupt=None)" ] }, { "cell_type": "code", - "execution_count": 22, - "id": "37871d6b", - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "{'f#F1': 0.75283, 'pre#F1': 0.727438, 'rec#F1': 0.780059}" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" + "execution_count": null, + "id": "c348864c", + "metadata": { + "pycharm": { + "name": "#%%\n" } - ], - "source": [ - "trainer.evaluator.run()" - ] + }, + "outputs": [], + "source": [] }, { "cell_type": "code", "execution_count": null, - "id": "96bae094", - "metadata": {}, + "id": "43be274f", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [] } @@ -1951,6 +538,15 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } } }, "nbformat": 4,