From fd91dc373481942d2871b8a4b0159d34424e4405 Mon Sep 17 00:00:00 2001 From: lxr-tech <1838593642@qq.com> Date: Thu, 2 Jun 2022 22:43:48 +0800 Subject: [PATCH] update tutorial-045 lxr 220602 --- tutorials/fastnlp_tutorial_0.ipynb | 271 +--- tutorials/fastnlp_tutorial_4.ipynb | 1966 +++++++++++++++++++++++--- tutorials/fastnlp_tutorial_5.ipynb | 2042 +++++----------------------- 3 files changed, 2146 insertions(+), 2133 deletions(-) diff --git a/tutorials/fastnlp_tutorial_0.ipynb b/tutorials/fastnlp_tutorial_0.ipynb index 8312353b..2e315d73 100644 --- a/tutorials/fastnlp_tutorial_0.ipynb +++ b/tutorials/fastnlp_tutorial_0.ipynb @@ -50,24 +50,24 @@ "\n", "```python\n", "trainer = Trainer(\n", - " model=model, # 模型基于 torch.nn.Module\n", - " train_dataloader=train_dataloader, # 加载模块基于 torch.utils.data.DataLoader \n", - " optimizers=optimizer, # 优化模块基于 torch.optim.*\n", - "\t...\n", - "\tdriver=\"torch\", # 使用 pytorch 模块进行训练 \n", - "\tdevice='cuda', # 使用 GPU:0 显卡执行训练\n", - "\t...\n", - ")\n", + " model=model, # 模型基于 torch.nn.Module\n", + " train_dataloader=train_dataloader, # 加载模块基于 torch.utils.data.DataLoader \n", + " optimizers=optimizer, # 优化模块基于 torch.optim.*\n", + " ...\n", + " driver=\"torch\", # 使用 pytorch 模块进行训练 \n", + " device='cuda', # 使用 GPU:0 显卡执行训练\n", + " ...\n", + " )\n", "...\n", "evaluator = Evaluator(\n", - " model=model, # 模型基于 torch.nn.Module\n", - " dataloaders=evaluate_dataloader, # 加载模块基于 torch.utils.data.DataLoader\n", - " metrics={'acc': Accuracy()}, # 测评方法使用 fastNLP.core.metrics.Accuracy \n", - " ...\n", - " driver=trainer.driver, # 保持同 trainer 的 driver 一致\n", - "\tdevice=None,\n", - " ...\n", - ")\n", + " model=model, # 模型基于 torch.nn.Module\n", + " dataloaders=evaluate_dataloader, # 加载模块基于 torch.utils.data.DataLoader\n", + " metrics={'acc': Accuracy()}, # 测评方法使用 fastNLP.core.metrics.Accuracy \n", + " ...\n", + " driver=trainer.driver, # 保持同 trainer 的 driver 一致\n", + " device=None,\n", + " ...\n", + " )\n", "```" ] }, @@ -84,7 +84,7 @@ "\n", 
"在`fastNLP 0.8`中,**`Trainer`和`Evaluator`都依赖于具体的`driver`来完成整体的工作流程**\n", "\n", - "  具体`driver`与`Trainer`以及`Evaluator`之间的关系请参考`fastNLP 0.8`的框架设计\n", + "  具体`driver`与`Trainer`以及`Evaluator`之间的关系之后`tutorial 4`中的详细介绍\n", "\n", "注:这里给出一条建议:**在同一脚本中**,**所有的`Trainer`和`Evaluator`使用的`driver`应当保持一致**\n", "\n", @@ -106,17 +106,17 @@ "\n", "```python\n", "trainer = Trainer(\n", - " model=model,\n", - " train_dataloader=train_dataloader,\n", - " optimizers=optimizer,\n", - "\t...\n", - "\tdriver=\"torch\",\n", - "\tdevice='cuda',\n", - "\t...\n", - " evaluate_dataloaders=evaluate_dataloader, # 传入参数 evaluator_dataloaders\n", - " metrics={'acc': Accuracy()}, # 传入参数 metrics\n", - "\t...\n", - ")\n", + " model=model,\n", + " train_dataloader=train_dataloader,\n", + " optimizers=optimizer,\n", + " ...\n", + " driver=\"torch\",\n", + " device='cuda',\n", + " ...\n", + " evaluate_dataloaders=evaluate_dataloader, # 传入参数 evaluator_dataloaders\n", + " metrics={'acc': Accuracy()}, # 传入参数 metrics\n", + " ...\n", + " )\n", "```" ] }, @@ -570,7 +570,7 @@ "outputs": [], "source": [ "from fastNLP import Evaluator\n", - "from fastNLP.core.metrics import Accuracy\n", + "from fastNLP import Accuracy\n", "\n", "evaluator = Evaluator(\n", " model=model,\n", @@ -1310,219 +1310,6 @@ "trainer.evaluator.run()" ] }, - { - "cell_type": "code", - "execution_count": 13, - "id": "db784d5b", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['__annotations__',\n", - " '__class__',\n", - " '__delattr__',\n", - " '__dict__',\n", - " '__dir__',\n", - " '__doc__',\n", - " '__eq__',\n", - " '__format__',\n", - " '__ge__',\n", - " '__getattribute__',\n", - " '__gt__',\n", - " '__hash__',\n", - " '__init__',\n", - " '__init_subclass__',\n", - " '__le__',\n", - " '__lt__',\n", - " '__module__',\n", - " '__ne__',\n", - " '__new__',\n", - " '__reduce__',\n", - " '__reduce_ex__',\n", - " '__repr__',\n", - " '__setattr__',\n", - " '__sizeof__',\n", - " '__str__',\n", - " '__subclasshook__',\n", - " 
'__weakref__',\n", - " '_check_callback_called_legality',\n", - " '_check_train_batch_loop_legality',\n", - " '_custom_callbacks',\n", - " '_driver',\n", - " '_evaluate_dataloaders',\n", - " '_fetch_matched_fn_callbacks',\n", - " '_set_num_eval_batch_per_dl',\n", - " '_train_batch_loop',\n", - " '_train_dataloader',\n", - " '_train_step',\n", - " '_train_step_signature_fn',\n", - " 'accumulation_steps',\n", - " 'add_callback_fn',\n", - " 'backward',\n", - " 'batch_idx_in_epoch',\n", - " 'batch_step_fn',\n", - " 'callback_manager',\n", - " 'check_batch_step_fn',\n", - " 'cur_epoch_idx',\n", - " 'data_device',\n", - " 'dataloader',\n", - " 'device',\n", - " 'driver',\n", - " 'driver_name',\n", - " 'epoch_evaluate',\n", - " 'evaluate_batch_step_fn',\n", - " 'evaluate_dataloaders',\n", - " 'evaluate_every',\n", - " 'evaluate_fn',\n", - " 'evaluator',\n", - " 'extract_loss_from_outputs',\n", - " 'fp16',\n", - " 'get_no_sync_context',\n", - " 'global_forward_batches',\n", - " 'has_checked_train_batch_loop',\n", - " 'input_mapping',\n", - " 'kwargs',\n", - " 'larger_better',\n", - " 'load_checkpoint',\n", - " 'load_model',\n", - " 'marker',\n", - " 'metrics',\n", - " 'model',\n", - " 'model_device',\n", - " 'monitor',\n", - " 'move_data_to_device',\n", - " 'n_epochs',\n", - " 'num_batches_per_epoch',\n", - " 'on',\n", - " 'on_after_backward',\n", - " 'on_after_optimizers_step',\n", - " 'on_after_trainer_initialized',\n", - " 'on_after_zero_grad',\n", - " 'on_before_backward',\n", - " 'on_before_optimizers_step',\n", - " 'on_before_zero_grad',\n", - " 'on_evaluate_begin',\n", - " 'on_evaluate_end',\n", - " 'on_exception',\n", - " 'on_fetch_data_begin',\n", - " 'on_fetch_data_end',\n", - " 'on_load_checkpoint',\n", - " 'on_load_model',\n", - " 'on_sanity_check_begin',\n", - " 'on_sanity_check_end',\n", - " 'on_save_checkpoint',\n", - " 'on_save_model',\n", - " 'on_train_batch_begin',\n", - " 'on_train_batch_end',\n", - " 'on_train_begin',\n", - " 'on_train_end',\n", - " 
'on_train_epoch_begin',\n", - " 'on_train_epoch_end',\n", - " 'optimizers',\n", - " 'output_mapping',\n", - " 'progress_bar',\n", - " 'run',\n", - " 'run_evaluate',\n", - " 'save_checkpoint',\n", - " 'save_model',\n", - " 'start_batch_idx_in_epoch',\n", - " 'state',\n", - " 'step',\n", - " 'step_evaluate',\n", - " 'total_batches',\n", - " 'train_batch_loop',\n", - " 'train_dataloader',\n", - " 'train_fn',\n", - " 'train_step',\n", - " 'trainer_state',\n", - " 'zero_grad']" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dir(trainer)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "953533c4", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Help on method run in module fastNLP.core.controllers.trainer:\n", - "\n", - "run(num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, catch_KeyboardInterrupt=None) method of fastNLP.core.controllers.trainer.Trainer instance\n", - " 该函数是在 ``Trainer`` 初始化后用于真正开始训练的函数;\n", - " \n", - " 注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 resume_from 为 None,并且使用 ``CheckpointCallback``\n", - " 去保存断点重训的文件;\n", - " \n", - " :param num_train_batch_per_epoch: 每个 epoch 训练多少个 batch 后停止,*-1* 表示使用 train_dataloader 本身的长度;\n", - " :param num_eval_batch_per_dl: 每个 evaluate_dataloader 验证多少个 batch 停止,*-1* 表示使用 evaluate_dataloader 本身的长度;\n", - " :param num_eval_sanity_batch: 在训练之前运行多少个 evaluation batch 来检测一下 evaluation 的过程是否有错误。为 0 表示不检测;\n", - " :param resume_from: 从哪个路径下恢复 trainer 的状态,注意该值需要为一个文件夹,例如使用 ``CheckpointCallback`` 时帮助您创建的保存的子文件夹;\n", - " :param resume_training: 是否按照 checkpoint 中训练状态恢复。如果为 False,则只恢复 model 和 optimizers 的状态;该参数如果为 ``True``,\n", - " 在下一次断点重训的时候我们会精确到上次训练截止的具体的 sample 进行训练;否则我们只会恢复 model 和 optimizers 的状态,而 ``Trainer`` 中的\n", - " 其余状态都是保持初始化时的状态不会改变;\n", - " :param catch_KeyboardInterrupt: 是否捕获 
KeyboardInterrupt;如果该参数为 ``True``,在训练时如果您使用 ``ctrl+c`` 来终止程序,\n", - " ``Trainer`` 不会抛出异常,但是会提前退出,然后 ``trainer.run()`` 之后的代码会继续运行。注意该参数在您使用分布式训练的 ``Driver``\n", - " 时无效,例如 ``TorchDDPDriver``;非分布式训练的 ``Driver`` 下该参数默认为 True;\n", - " \n", - " .. warning::\n", - " \n", - " 注意初始化的 ``Trainer`` 只能调用一次 ``run`` 函数,即之后的调用 ``run`` 函数实际不会运行,因为此时\n", - " ``trainer.cur_epoch_idx == trainer.n_epochs``;\n", - " \n", - " 这意味着如果您需要再次调用 ``run`` 函数,您需要重新再初始化一个 ``Trainer``;\n", - " \n", - " .. note::\n", - " \n", - " 您可以使用 ``num_train_batch_per_epoch`` 来简单地对您的训练过程进行验证,例如,当您指定 ``num_train_batch_per_epoch=10`` 后,\n", - " 每一个 epoch 下实际训练的 batch 的数量则会被修改为 10。您可以先使用该值来设定一个较小的训练长度,在验证整体的训练流程没有错误后,再将\n", - " 该值设定为 **-1** 开始真正的训练;\n", - " \n", - " ``num_eval_batch_per_dl`` 的意思和 ``num_train_batch_per_epoch`` 类似,即您可以通过设定 ``num_eval_batch_per_dl`` 来验证\n", - " 整体的验证流程是否正确;\n", - " \n", - " ``num_eval_sanity_batch`` 的作用可能会让人产生迷惑,其本质和 ``num_eval_batch_per_dl`` 作用一致,但是其只被 ``Trainer`` 使用;\n", - " 并且其只会在训练的一开始使用,意思为:我们在训练的开始时会先使用 ``Evaluator``(如果其不为 ``None``) 进行验证,此时验证的 batch 的\n", - " 数量只有 ``num_eval_sanity_batch`` 个;但是对于 ``num_eval_batch_per_dl`` 而言,其表示在实际的整体的训练过程中,每次 ``Evaluator``\n", - " 进行验证时会验证的 batch 的数量。\n", - " \n", - " 并且,在实际真正的训练中,``num_train_batch_per_epoch`` 和 ``num_eval_batch_per_dl`` 应当都被设置为 **-1**,但是 ``num_eval_sanity_batch``\n", - " 应当为一个很小的正整数,例如 2;\n", - " \n", - " .. note::\n", - " \n", - " 参数 ``resume_from`` 和 ``resume_training`` 的设立是为了支持断点重训功能;仅当 ``resume_from`` 不为 ``None`` 时,``resume_training`` 才有效;\n", - " \n", - " 断点重训的意思为将上一次训练过程中的 ``Trainer`` 的状态保存下来,包括模型和优化器的状态、当前训练过的 epoch 的数量、对于当前的 epoch\n", - " 已经训练过的 batch 的数量、callbacks 的状态等等;然后在下一次训练时直接加载这些状态,从而直接恢复到上一次训练过程的某一个具体时间点的状态开始训练;\n", - " \n", - " fastNLP 将断点重训分为了 **保存状态** 和 **恢复断点重训** 两部分:\n", - " \n", - " 1. 
您需要使用 ``CheckpointCallback`` 来保存训练过程中的 ``Trainer`` 的状态;具体详见 :class:`~fastNLP.core.callbacks.CheckpointCallback`;\n", - " ``CheckpointCallback`` 会帮助您把 ``Trainer`` 的状态保存到一个具体的文件夹下,这个文件夹的名字由 ``CheckpointCallback`` 自己生成;\n", - " 2. 在第二次训练开始时,您需要找到您想要加载的 ``Trainer`` 状态所存放的文件夹,然后传入给参数 ``resume_from``;\n", - " \n", - " 需要注意的是 **保存状态** 和 **恢复断点重训** 是互不影响的。\n", - "\n" - ] - } - ], - "source": [ - "help(trainer.run)" - ] - }, { "cell_type": "code", "execution_count": null, diff --git a/tutorials/fastnlp_tutorial_4.ipynb b/tutorials/fastnlp_tutorial_4.ipynb index ee5a0c6b..10098891 100644 --- a/tutorials/fastnlp_tutorial_4.ipynb +++ b/tutorials/fastnlp_tutorial_4.ipynb @@ -5,292 +5,1931 @@ "id": "fdd7ff16", "metadata": {}, "source": [ - "# T4. trainer 和 evaluator 的深入介绍\n", + "# T4. fastNLP 中的预定义模型\n", "\n", - "  1   fastNLP 中的更多 metric 类型\n", - "\n", - "    1.1   预定义的 metric 类型\n", + "  1   fastNLP 中 modules 的介绍\n", + " \n", + "    1.1   modules 模块、models 模块 简介\n", "\n", - "    1.2   自定义的 metric 类型\n", + "    1.2   示例一:modules 实现 LSTM 分类\n", "\n", - "  2   fastNLP 中 trainer 的补充介绍\n", + "  2   fastNLP 中 models 的介绍\n", " \n", - "    2.1   trainer 的提出构想 \n", + "    2.1   示例一:models 实现 CNN 分类\n", "\n", - "    2.2   trainer 的内部结构\n", + "    2.3   示例二:models 实现 BiLSTM 标注" + ] + }, + { + "cell_type": "markdown", + "id": "d3d65d53", + "metadata": {}, + "source": [ + "## 1. fastNLP 中 modules 模块的介绍\n", "\n", - "    2.3   实例:\n", + "### 1.1 modules 模块、models 模块 简介\n", "\n", - "  3   fastNLP 中的 driver 与 device\n", + "在`fastNLP 0.8`中,**`modules.torch`路径下定义了一些基于`pytorch`实现的基础模块**\n", "\n", - "    3.1   driver 的提出构想\n", + "    包括长短期记忆网络`LSTM`、条件随机场`CRF`、`transformer`的编解码器模块等,详见下表\n", "\n", - "    3.2   device 与多卡训练" + "|
代码名称
|
简要介绍
|
代码路径
|\n", + "|:--|:--|:--|\n", + "| `LSTM` | 轻量封装`pytorch`的`LSTM` | `/modules/torch/encoder/lstm.py` |\n", + "| `Seq2SeqEncoder` | 序列变换编码器,基类 | `/modules/torch/encoder/seq2seq_encoder.py` |\n", + "| `LSTMSeq2SeqEncoder` | 序列变换编码器,基于`LSTM` | `/modules/torch/encoder/seq2seq_encoder.py` |\n", + "| `TransformerSeq2SeqEncoder` | 序列变换编码器,基于`transformer` | `/modules/torch/encoder/seq2seq_encoder.py` |\n", + "| `StarTransformer` | `Star-Transformer`的编码器部分 | `/modules/torch/encoder/star_transformer.py` |\n", + "| `VarRNN` | 实现`Variational Dropout RNN` | `/modules/torch/encoder/variational_rnn.py` |\n", + "| `VarLSTM` | 实现`Variational Dropout LSTM` | `/modules/torch/encoder/variational_rnn.py` |\n", + "| `VarGRU` | 实现`Variational Dropout GRU` | `/modules/torch/encoder/variational_rnn.py` |\n", + "| `ConditionalRandomField` | 条件随机场模型 | `/modules/torch/decoder/crf.py` |\n", + "| `Seq2SeqDecoder` | 序列变换解码器,基类 | `/modules/torch/decoder/seq2seq_decoder.py` |\n", + "| `LSTMSeq2SeqDecoder` | 序列变换解码器,基于`LSTM` | `/modules/torch/decoder/seq2seq_decoder.py` |\n", + "| `TransformerSeq2SeqDecoder` | 序列变换解码器,基于`transformer` | `/modules/torch/decoder/seq2seq_decoder.py` |\n", + "| `SequenceGenerator` | 序列生成,封装`Seq2SeqDecoder` | `/models/torch/sequence_labeling.py` |\n", + "| `TimestepDropout` | 在每个`timestamp`上`dropout` | `/modules/torch/dropout.py` |" ] }, { "cell_type": "markdown", - "id": "8d19220c", + "id": "89ffcf07", "metadata": {}, "source": [ - "## 1. fastNLP 中的更多 metric 类型\n", + "  **`models.torch`路径下定义了一些基于`pytorch`、`modules`实现的预定义模型** \n", + "\n", + "    例如基于`CNN`的分类模型、基于`BiLSTM+CRF`的标注模型、基于[双仿射注意力机制](https://arxiv.org/pdf/1611.01734.pdf)的分析模型\n", "\n", - "### 1.1 预定义的 metric 类型\n", + "    基于`modules.torch`中的`LSTM`/`transformer`编/解码器模块的序列变换/生成模型,详见下表\n", "\n", - "在`fastNLP 0.8`中,除了前几篇`tutorial`中经常见到的**正确率`Accuracy`**,还有其他**预定义的评价标准`metric`**\n", + "|
代码名称
|
简要介绍
|
代码路径
|\n", + "|:--|:--|:--|\n", + "| `BiaffineParser` | 句法分析模型,基于双仿射注意力 | `/models/torch/biaffine_parser.py` |\n", + "| `CNNText` | 文本分类模型,基于`CNN` | `/models/torch/cnn_text_classification.py` |\n", + "| `Seq2SeqModel` | 序列变换,基类`encoder+decoder` | `/models/torch/seq2seq_model.py` |\n", + "| `LSTMSeq2SeqModel` | 序列变换,基于`LSTM` | `/models/torch/seq2seq_model.py` |\n", + "| `TransformerSeq2SeqModel` | 序列变换,基于`transformer` | `/models/torch/seq2seq_model.py` |\n", + "| `SequenceGeneratorModel` | 封装`Seq2SeqModel`,结合`SequenceGenerator` | `/models/torch/seq2seq_generator.py` |\n", + "| `SeqLabeling` | 标注模型,基类`LSTM+FC+CRF` | `/models/torch/sequence_labeling.py` |\n", + "| `BiLSTMCRF` | 标注模型,`BiLSTM+FC+CRF` | `/models/torch/sequence_labeling.py` |\n", + "| `AdvSeqLabel` | 标注模型,`LN+BiLSTM*2+LN+FC+CRF` | `/models/torch/sequence_labeling.py` |" + ] + }, + { + "cell_type": "markdown", + "id": "61318354", + "metadata": {}, + "source": [ + "上述`fastNLP`模块,不仅**为入门级用户提供了简单易用的工具**,以解决各种`NLP`任务,或复现相关论文\n", "\n", - "  包括**所有`metric`的基类`Metric`**、适配`Transformers`中相关模型的正确率`TransformersAccuracy`\n", + "  同时**也为专业研究人员提供了便捷可操作的接口**,封装部分代码的同时,也能指定参数修改细节\n", "\n", - "    **适用于分类语境下的`F1`值`ClassifyFPreRecMetric`**(其中也包括**召回率`Pre`**、**精确率`Rec`**\n", + "  在接下来的`tutorial`中,我们将通过`SST-2`分类和`CoNLL-2003`标注,展示相关模型使用\n", "\n", - "    **适用于抽取语境下的`F1`值`SpanFPreRecMetric`**;相关基本信息内容见下表,之后是详细分析\n", + "注一:**`SST`**,**单句情感分类**数据集,包含电影评论和对应情感极性,1 对应正面情感,0 对应负面情感\n", "\n", - "|
代码名称
|
简要介绍
|
代码路径
|\n", - "|:--|:--|:--|\n", - "| `Metric` | 定义`metrics`时继承的基类 | `/core/metrics/metric.py` |\n", - "| `Accuracy` | 正确率,最为常用 | `/core/metrics/accuracy.py` |\n", - "| `TransformersAccuracy` | 正确率,为了兼容`Transformers`中相关模型 | `/core/metrics/accuracy.py` |\n", - "| `ClassifyFPreRecMetric` | 召回率、精确率、F1值,适用于**分类问题** | `/core/metrics/classify_f1_pre_rec_metric.py` |\n", - "| `SpanFPreRecMetric` | 召回率、精确率、F1值,适用于**抽取问题** | `/core/metrics/span_f1_pre_rec_metric.py` |" + "  数据集包括三部分:训练集 67350 条,验证集 873 条,测试集 1821 条,更多参考[下载链接](https://gluebenchmark.com/tasks)\n", + "\n", + "注二:**`CoNLL-2003`**,**文本语法标注**数据集,包含语句和对应的词性标签`pos_tags`(名动形数量代)\n", + "\n", + "  语法结构标签`chunk_tags`(主谓宾定状补)、命名实体标签`ner_tags`(人名、组织名、地名、时间等)\n", + "\n", + "  数据集包括三部分:训练集 14041 条,验证集 3250 条,测试集 3453 条,更多参考[原始论文](https://aclanthology.org/W03-0419.pdf)" ] }, { "cell_type": "markdown", - "id": "fdc083a3", - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "id": "2a36bbe4", + "metadata": {}, "source": [ - "大概的描述一下,给出各个正确率的计算公式" + "### 1.2 示例一:modules 实现 LSTM 分类" ] }, { "cell_type": "code", - "execution_count": null, - "id": "9775ea5e", + "execution_count": 1, + "id": "40e66b21", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "# import sys\n", + "# sys.path.append('..')\n", + "\n", + "# from fastNLP.io import SST2Pipe # 没有 SST2Pipe 会运行很长时间,并且还会报错\n", + "\n", + "# databundle = SST2Pipe(tokenizer='raw').process_from_file()\n", + "\n", + "# dataset = databundle.get_dataset('train')[:6000]\n", + "\n", + "# dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'target': ins['label']}, \n", + "# progress_bar=\"tqdm\")\n", + "# dataset.delete_field('sentence')\n", + "# dataset.delete_field('label')\n", + "# dataset.delete_field('idx')\n", + "\n", + "# from fastNLP import Vocabulary\n", + "\n", + "# vocab = Vocabulary()\n", + "# vocab.from_dataset(dataset, field_name='words')\n", + "# vocab.index_dataset(dataset, field_name='words')\n", + "\n", + "# train_dataset, evaluate_dataset 
= dataset.split(ratio=0.85)" + ] }, { - "cell_type": "markdown", - "id": "8a22f522", + "cell_type": "code", + "execution_count": 2, + "id": "50960476", "metadata": {}, + "outputs": [], "source": [ - "### 2.2 自定义的 metric 类型\n", + "# from fastNLP import prepare_torch_dataloader\n", "\n", - "在`fastNLP 0.8`中,  给一个案例,训练部分留到trainer部分" + "# train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n", + "# evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)" ] }, { "cell_type": "code", - "execution_count": null, - "id": "d8caba1d", + "execution_count": 3, + "id": "0b25b25c", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "# import torch\n", + "# import torch.nn as nn\n", + "\n", + "# from fastNLP.modules.torch import LSTM, MLP # 没有 MLP\n", + "# from fastNLP import Embedding, CrossEntropyLoss\n", + "\n", + "\n", + "# class ClsByModules(nn.Module):\n", + "# def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n", + "# nn.Module.__init__(self)\n", + "\n", + "# self.embedding = Embedding((vocab_size, embedding_dim))\n", + "# self.lstm = LSTM(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=True)\n", + "# self.mlp = MLP([hidden_dim * 2, output_dim], dropout=dropout)\n", + " \n", + "# self.loss_fn = CrossEntropyLoss()\n", + "\n", + "# def forward(self, words):\n", + "# output = self.embedding(words)\n", + "# output, (hidden, cell) = self.lstm(output)\n", + "# output = self.mlp(torch.cat((hidden[-1], hidden[-2]), dim=1))\n", + "# return output\n", + " \n", + "# def train_step(self, words, target):\n", + "# pred = self(words)\n", + "# return {\"loss\": self.loss_fn(pred, target)}\n", + "\n", + "# def evaluate_step(self, words, target):\n", + "# pred = self(words)\n", + "# pred = torch.max(pred, dim=-1)[1]\n", + "# return {\"pred\": pred, \"target\": target}" + ] }, { "cell_type": "code", - "execution_count": null, - "id": "4e6247dd", + 
"execution_count": 4, + "id": "9dbbf50d", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "# model = ClsByModules(vocab_size=len(vocabulary), embedding_dim=100, output_dim=2)\n", + "\n", + "# from torch.optim import AdamW\n", + "\n", + "# optimizers = AdamW(params=model.parameters(), lr=5e-5)" + ] }, { - "cell_type": "markdown", - "id": "08752c5a", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "cell_type": "code", + "execution_count": 5, + "id": "7a93432f", + "metadata": {}, + "outputs": [], "source": [ - "## 2. fastNLP 中 trainer 的补充介绍\n", - "\n", - "### 2.1 trainer 的提出构想\n", + "# from fastNLP import Trainer, Accuracy\n", "\n", - "在`fastNLP 0.8`中,  " + "# trainer = Trainer(\n", + "# model=model,\n", + "# driver='torch',\n", + "# device=0, # 'cuda'\n", + "# n_epochs=10,\n", + "# optimizers=optimizers,\n", + "# train_dataloader=train_dataloader,\n", + "# evaluate_dataloaders=evaluate_dataloader,\n", + "# metrics={'acc': Accuracy()}\n", + "# )" ] }, { "cell_type": "code", - "execution_count": null, - "id": "977a6355", + "execution_count": 6, + "id": "31102e0f", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "# trainer.run(num_eval_batch_per_dl=10)" + ] }, { "cell_type": "code", - "execution_count": null, - "id": "69203cdc", - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "execution_count": 7, + "id": "8bc4bfb2", + "metadata": {}, "outputs": [], - "source": [] + "source": [ + "# trainer.evaluator.run()" + ] }, { "cell_type": "markdown", - "id": "ab1cea7d", + "id": "d9443213", "metadata": {}, "source": [ - "### 2.2 trainer 的内部结构\n", + "## 2. 
fastNLP 中 models 模块的介绍\n", + "\n", + "### 2.1 示例一:models 实现 CNN 分类\n", + "\n", + "  本示例使用`fastNLP 0.8`中预定义模型`models`中的`CNNText`模型,实现`SST-2`文本二分类任务\n", "\n", - "在`fastNLP 0.8`中,  \n", + "模型使用方面,如上所述,这里使用**基于卷积神经网络`CNN`的预定义文本分类模型`CNNText`**,结构如下所示\n", "\n", - "'accumulation_steps', 'add_callback_fn', 'backward', 'batch_idx_in_epoch', 'batch_step_fn',\n", - "'callback_manager', 'check_batch_step_fn', 'cur_epoch_idx', 'data_device', 'dataloader',\n", - "'device', 'driver', 'driver_name', 'epoch_evaluate', 'evaluate_batch_step_fn', 'evaluate_dataloaders',\n", - "'evaluate_every', 'evaluate_fn', 'evaluator', 'extract_loss_from_outputs', 'fp16',\n", - "'get_no_sync_context', 'global_forward_batches', 'has_checked_train_batch_loop',\n", - "'input_mapping', 'kwargs', 'larger_better', 'load_checkpoint', 'load_model', 'marker',\n", - "'metrics', 'model', 'model_device', 'monitor', 'move_data_to_device', 'n_epochs', 'num_batches_per_epoch',\n", - "'on', 'on_after_backward', 'on_after_optimizers_step', 'on_after_trainer_initialized',\n", - "'on_after_zero_grad', 'on_before_backward', 'on_before_optimizers_step', 'on_before_zero_grad',\n", - "'on_evaluate_begin', 'on_evaluate_end', 'on_exception', 'on_fetch_data_begin', 'on_fetch_data_end',\n", - "'on_load_checkpoint', 'on_load_model', 'on_sanity_check_begin', 'on_sanity_check_end',\n", - "'on_save_checkpoint', 'on_save_model', 'on_train_batch_begin', 'on_train_batch_end',\n", - "'on_train_begin', 'on_train_end', 'on_train_epoch_begin', 'on_train_epoch_end',\n", - "'optimizers', 'output_mapping', 'progress_bar', 'run', 'run_evaluate',\n", - "'save_checkpoint', 'save_model', 'start_batch_idx_in_epoch', 'state',\n", - "'step', 'step_evaluate', 'total_batches', 'train_batch_loop', 'train_dataloader', 'train_fn', 'train_step',\n", - "'trainer_state', 'zero_grad'\n", + "  首先是内置的`100`维嵌入层、`dropout`层、紧接着是三个一维卷积,将`100`维嵌入特征,分别通过\n", "\n", - "  run(num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, 
num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, catch_KeyboardInterrupt=None)" + "    **感受野为`1`、`3`、`5`的卷积算子变换至`30`维、`40`维、`50`维的卷积特征**,再将三者拼接\n", + "\n", + "  最终再次通过`dropout`层、线性变换层,映射至二元的输出值,对应两个分类结果上的几率`logits`\n", + "\n", + "```\n", + "CNNText(\n", + " (embed): Embedding(\n", + " (embed): Embedding(5194, 100)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (conv_pool): ConvMaxpool(\n", + " (convs): ModuleList(\n", + " (0): Conv1d(100, 30, kernel_size=(1,), stride=(1,), bias=False)\n", + " (1): Conv1d(100, 40, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n", + " (2): Conv1d(100, 50, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", + " )\n", + " )\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (fc): Linear(in_features=120, out_features=2, bias=True)\n", + ")\n", + "```\n", + "\n", + "数据使用方面,此处**使用`datasets`模块中的`load_dataset`函数**,以如下形式,指定`SST-2`数据集自动加载\n", + "\n", + "  首次下载后会保存至`~/.cache/huggingface/modules/datasets_modules/datasets/glue/`目录下" ] }, { "cell_type": "code", - "execution_count": null, - "id": "b3c8342e", - "metadata": { - "pycharm": { - "name": "#%%\n" + "execution_count": 8, + "id": "1aa5cf6d", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using the latest cached version of the module from /remote-home/xrliu/.cache/huggingface/modules/datasets_modules/datasets/glue/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad (last modified on Thu May 26 15:30:15 2022) since it couldn't be found locally at glue., or remotely on the Hugging Face Hub.\n", + "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "70cde65067c64fdba1d5e798e2b8d631", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 
0/3 [00:00\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/6000 [00:00[17:45:59] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[17:45:59]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=147745;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=708408;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.575,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 92.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.575\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m92.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.75625,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 121.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.75625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m121.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.78125,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 125.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.78125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m125.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.8,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 128.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.79375,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 127.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.79375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m127.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.80625,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 129.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m129.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.81875,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 131.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.825,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 132.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.825\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m132.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.81875,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 131.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.81875,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 131.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "f47a6a35", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'acc#acc': 0.79, 'total#acc': 900.0, 'correct#acc': 711.0}"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "trainer.evaluator.run()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "7c811257",
+   "metadata": {},
+   "source": [
+    "  注:此处使用`gc`模块删除相关变量,释放内存,为接下来新的模型训练预留存储空间"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "c1a2e2ca",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "342"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import gc\n",
+    "\n",
+    "del model\n",
+    "del trainer\n",
+    "del dataset\n",
+    "del sst2data\n",
+    "\n",
+    "gc.collect()"
+   ]
   },
   {
    "cell_type": "markdown",
-   "id": "175d6ebb",
+   "id": "6aec2a19",
    "metadata": {},
    "source": [
-    "## 3. fastNLP 中的 driver 与 device\n",
+    "### 2.2  示例二:models 实现 BiLSTM 标注\n",
+    "\n",
+    "  对比前面两个“示例一”可以发现,得益于`models`对模型结构的封装,使用`models`明显更加便捷\n",
+    "\n",
+    "    针对更加复杂的模型时,编码更加轻松;本示例将使用`models`中的`BiLSTMCRF`模型\n",
+    "\n",
+    "  避免`CRF`和`Viterbi`算法代码书写的困难,轻松实现`CoNLL-2003`中的命名实体识别`NER`任务\n",
     "\n",
-    "### 3.1  driver 的提出构想\n",
+    "模型使用方面,如上所述,这里使用**基于双向`LSTM`+条件随机场`CRF`的标注模型`BiLSTMCRF`**,结构如下所示\n",
     "\n",
-    "在`fastNLP 0.8`中,  "
+    "  其中,隐藏层维度默认`100`维,因此对应双向`LSTM`输出`200`维,`dropout`层丢弃概率、`LSTM`层数可调\n",
+    "\n",
+    "```\n",
+    "BiLSTMCRF(\n",
+    "  (embed): Embedding(7590, 100)\n",
+    "  (lstm): LSTM(\n",
+    "    (lstm): LSTM(100, 100, batch_first=True, bidirectional=True)\n",
+    "  )\n",
+    "  (dropout): Dropout(p=0.1, inplace=False)\n",
+    "  (fc): Linear(in_features=200, out_features=9, bias=True)\n",
+    "  (crf): ConditionalRandomField()\n",
+    ")\n",
+    "```\n",
+    "\n",
+    "数据使用方面,此处仍然**使用`datasets`模块中的`load_dataset`函数**,以如下形式,加载`CoNLL-2003`数据集\n",
+    "\n",
+    "  首次下载后会保存至`~/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/`目录下"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "id": "47100e7a",
-   "metadata": {
-    "pycharm": {
-     "name": "#%%\n"
+   "execution_count": 16,
+   "id": "03e66686",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Reusing dataset conll2003 (/remote-home/xrliu/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/63f4ebd1bcb7148b1644497336fd74643d4ce70123334431a3c053b7ee4e96ee)\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3ec9e0ce9a054339a2453420c2c9f28b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/3 [00:00[17:49:16] INFO     Running evaluator sanity check for 2 batches.              trainer.py:592\n",
+       "\n"
+      ],
+      "text/plain": [
+       "\u001b[2;36m[17:49:16]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO    \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches.              \u001b]8;id=766109;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=787419;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.220374,\n",
+       "  \"pre#F1\": 0.25,\n",
+       "  \"rec#F1\": 0.197026\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.220374\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.25\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.197026\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.442857,\n",
+       "  \"pre#F1\": 0.426117,\n",
+       "  \"rec#F1\": 0.460967\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.442857\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.426117\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.460967\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.572954,\n",
+       "  \"pre#F1\": 0.549488,\n",
+       "  \"rec#F1\": 0.598513\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.572954\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.549488\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.598513\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.665399,\n",
+       "  \"pre#F1\": 0.680934,\n",
+       "  \"rec#F1\": 0.650558\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.665399\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.680934\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.650558\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.734694,\n",
+       "  \"pre#F1\": 0.733333,\n",
+       "  \"rec#F1\": 0.736059\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.734694\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.733333\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.736059\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.742647,\n",
+       "  \"pre#F1\": 0.734545,\n",
+       "  \"rec#F1\": 0.750929\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.742647\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.734545\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.750929\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.773585,\n",
+       "  \"pre#F1\": 0.785441,\n",
+       "  \"rec#F1\": 0.762082\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.773585\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.785441\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.762082\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.770115,\n",
+       "  \"pre#F1\": 0.794466,\n",
+       "  \"rec#F1\": 0.747212\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.770115\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.794466\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.747212\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.7603,\n",
+       "  \"pre#F1\": 0.766038,\n",
+       "  \"rec#F1\": 0.754647\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.7603\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.766038\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.754647\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.743682,\n",
+       "  \"pre#F1\": 0.722807,\n",
+       "  \"rec#F1\": 0.765799\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.743682\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.722807\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.765799\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } - }, + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "37871d6b", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'f#F1': 0.75283, 'pre#F1': 0.727438, 'rec#F1': 0.780059}"
+      ]
+     },
+     "execution_count": 22,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "trainer.evaluator.run()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "96bae094",
+   "metadata": {},
    "outputs": [],
    "source": []
   }
@@ -312,15 +1951,6 @@
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "version": "3.7.13"
-  },
-  "pycharm": {
-   "stem_cell": {
-    "cell_type": "raw",
-    "metadata": {
-     "collapsed": false
-    },
-    "source": []
-   }
   }
  },
  "nbformat": 4,
diff --git a/tutorials/fastnlp_tutorial_5.ipynb b/tutorials/fastnlp_tutorial_5.ipynb
index cb105c89..0669a60a 100644
--- a/tutorials/fastnlp_tutorial_5.ipynb
+++ b/tutorials/fastnlp_tutorial_5.ipynb
@@ -5,1313 +5,448 @@
    "id": "fdd7ff16",
    "metadata": {},
    "source": [
-    "# T5. fastNLP 中的预定义模型\n",
+    "# T5. trainer 和 evaluator 的深入介绍\n",
     "\n",
-    "  1   fastNLP 中 modules 的介绍\n",
+    "  1   fastNLP 中 driver 的补充介绍\n",
     " \n",
-    "    1.1   modules 模块、models 模块 简介\n",
+    "    1.1   trainer 和 driver 的构想 \n",
     "\n",
-    "    1.2   示例一:modules 实现 LSTM 分类\n",
+    "    1.2   device 与 多卡训练\n",
     "\n",
-    "  2   fastNLP 中 models 的介绍\n",
-    " \n",
-    "    2.1   示例一:models 实现 CNN 分类\n",
+    "  2   fastNLP 中的更多 metric 类型\n",
+    "\n",
+    "    2.1   预定义的 metric 类型\n",
+    "\n",
+    "    2.2   自定义的 metric 类型\n",
     "\n",
-    "    2.3   示例二:models 实现 BiLSTM 标注"
+    "  3   fastNLP 中 trainer 的补充介绍\n",
+    "\n",
+    "    3.1   trainer 的内部结构"
    ]
   },
   {
    "cell_type": "markdown",
-   "id": "d3d65d53",
-   "metadata": {},
+   "id": "08752c5a",
+   "metadata": {
+    "pycharm": {
+     "name": "#%% md\n"
+    }
+   },
    "source": [
-    "## 1. fastNLP 中 modules 模块的介绍\n",
+    "## 1. fastNLP 中 driver 的补充介绍\n",
     "\n",
-    "### 1.1  modules 模块、models 模块 简介\n",
+    "### 1.1  trainer 和 driver 的构想\n",
     "\n",
-    "在`fastNLP 0.8`中,**`modules.torch`路径下定义了一些基于`pytorch`实现的基础模块**\n",
+    "在`fastNLP 0.8`中,模型训练最关键的模块便是**训练模块`trainer`、评测模块`evaluator`、驱动模块`driver`**,\n",
     "\n",
-    "    包括长短期记忆网络`LSTM`、条件随机场`CRF`、`transformer`的编解码器模块等,详见下表\n",
+    "  在`tutorial 0`中,已经简单介绍过上述三个模块:**`driver`用来控制训练评测中的`model`的最终运行**\n",
     "\n",
-    "| 
代码名称
|
简要介绍
|
代码路径
|\n", + "    **`evaluator`封装评测的`metric`**,**`trainer`封装训练的`optimizer`**,**也可以包括`evaluator`**\n", + "\n", + "之所以做出上述的划分,其根本目的在于要**达成对于多个`python`学习框架**,**例如`pytorch`、`paddle`、`jittor`的兼容**\n", + "\n", + "  对于训练环节,其伪代码如下方左边紫色一栏所示,由于**不同框架对模型、损失、张量的定义各有不同**,所以将训练环节\n", + "\n", + "    划分为**框架无关的循环控制、批量分发部分**,**由`trainer`模块负责**实现,对应的伪代码如下方中间蓝色一栏所示\n", + "\n", + "    以及**随框架不同的模型调用、数值优化部分**,**由`driver`模块负责**实现,对应的伪代码如下方右边红色一栏所示\n", + "\n", + "|
训练过程
|
框架无关 对应`trainer`
|
框架相关 对应`driver`
|\n", "|:--|:--|:--|\n", - "| `LSTM` | 轻量封装`pytorch`的`LSTM` | `/modules/torch/encoder/lstm.py` |\n", - "| `Seq2SeqEncoder` | 序列变换编码器,基类 | `/modules/torch/encoder/seq2seq_encoder.py` |\n", - "| `LSTMSeq2SeqEncoder` | 序列变换编码器,基于`LSTM` | `/modules/torch/encoder/seq2seq_encoder.py` |\n", - "| `TransformerSeq2SeqEncoder` | 序列变换编码器,基于`transformer` | `/modules/torch/encoder/seq2seq_encoder.py` |\n", - "| `StarTransformer` | `Star-Transformer`的编码器部分 | `/modules/torch/encoder/star_transformer.py` |\n", - "| `VarRNN` | 实现`Variational Dropout RNN` | `/modules/torch/encoder/variational_rnn.py` |\n", - "| `VarLSTM` | 实现`Variational Dropout LSTM` | `/modules/torch/encoder/variational_rnn.py` |\n", - "| `VarGRU` | 实现`Variational Dropout GRU` | `/modules/torch/encoder/variational_rnn.py` |\n", - "| `ConditionalRandomField` | 条件随机场模型 | `/modules/torch/decoder/crf.py` |\n", - "| `Seq2SeqDecoder` | 序列变换解码器,基类 | `/modules/torch/decoder/seq2seq_decoder.py` |\n", - "| `LSTMSeq2SeqDecoder` | 序列变换解码器,基于`LSTM` | `/modules/torch/decoder/seq2seq_decoder.py` |\n", - "| `TransformerSeq2SeqDecoder` | 序列变换解码器,基于`transformer` | `/modules/torch/decoder/seq2seq_decoder.py` |\n", - "| `SequenceGenerator` | 序列生成,封装`Seq2SeqDecoder` | `/models/torch/sequence_labeling.py` |\n", - "| `TimestepDropout` | 在每个`timestamp`上`dropout` | `/modules/torch/dropout.py` |" + "|
try:
|
try:
| |\n", + "|
for epoch in 1:n_epochs:
|
for epoch in 1:n_epochs:
| |\n", + "|
for step in 1:total_steps:
|
for step in 1:total_steps:
| |\n", + "|
batch = fetch_batch()
|
batch = fetch_batch()
| |\n", + "|
loss = model.forward(batch) 
| |
loss = model.forward(batch) 
|\n", + "|
loss.backward()
| |
loss.backward()
|\n", + "|
model.clear_grad()
| |
model.clear_grad()
|\n", + "|
model.update()
| |
model.update()
|\n", + "|
if need_save:
|
if need_save:
| |\n", + "|
model.save()
| |
model.save()
|\n", + "|
except:
|
except:
| |\n", + "|
process_exception()
|
process_exception()
| |" ] }, { "cell_type": "markdown", - "id": "89ffcf07", + "id": "3e55f07b", "metadata": {}, "source": [ - "  **`models.torch`路径下定义了一些基于`pytorch`、`modules`实现的预定义模型** \n", + "  对于评测环节,其伪代码如下方左边紫色一栏所示,同样由于不同框架对模型、损失、张量的定义各有不同,所以将评测环节\n", "\n", - "    例如基于`CNN`的分类模型、基于`BiLSTM+CRF`的标注模型、基于[双仿射注意力机制](https://arxiv.org/pdf/1611.01734.pdf)的分析模型\n", + "    划分为**框架无关的循环控制、分发汇总部分**,**由`evaluator`模块负责**实现,对应的伪代码如下方中间蓝色一栏所示\n", "\n", - "    基于`modules.torch`中的`LSTM`/`transformer`编/解码器模块的序列变换/生成模型,详见下表\n", + "    以及**随框架不同的模型调用、评测计算部分**,同样**由`driver`模块负责**实现,对应的伪代码如下方右边红色一栏所示\n", "\n", - "|
代码名称
|
简要介绍
|
代码路径
|\n", + "|
评测过程
|
框架无关 对应`evaluator`
|
框架相关 对应`driver`
|\n", "|:--|:--|:--|\n", - "| `BiaffineParser` | 句法分析模型,基于双仿射注意力 | `/models/torch/biaffine_parser.py` |\n", - "| `CNNText` | 文本分类模型,基于`CNN` | `/models/torch/cnn_text_classification.py` |\n", - "| `Seq2SeqModel` | 序列变换,基类`encoder+decoder` | `/models/torch/seq2seq_model.py` |\n", - "| `LSTMSeq2SeqModel` | 序列变换,基于`LSTM` | `/models/torch/seq2seq_model.py` |\n", - "| `TransformerSeq2SeqModel` | 序列变换,基于`transformer` | `/models/torch/seq2seq_model.py` |\n", - "| `SequenceGeneratorModel` | 封装`Seq2SeqModel`,结合`SequenceGenerator` | `/models/torch/seq2seq_generator.py` |\n", - "| `SeqLabeling` | 标注模型,基类`LSTM+FC+CRF` | `/models/torch/sequence_labeling.py` |\n", - "| `BiLSTMCRF` | 标注模型,`BiLSTM+FC+CRF` | `/models/torch/sequence_labeling.py` |\n", - "| `AdvSeqLabel` | 标注模型,`LN+BiLSTM*2+LN+FC+CRF` | `/models/torch/sequence_labeling.py` |" + "|
try:
|
try:
| |\n", + "|
model.set_eval()
|
model.set_eval()
| |\n", + "|
for step in 1:total_steps:
|
for step in 1:total_steps:
| |\n", + "|
batch = fetch_batch()
|
batch = fetch_batch()
| |\n", + "|
outputs = model.evaluate(batch) 
| |
outputs = model.evaluate(batch) 
|\n", + "|
metric.compute(batch, outputs)
| |
metric.compute(batch, outputs)
|\n", + "|
results = metric.get_metric()
|
results = metric.get_metric()
| |\n", + "|
except:
|
except:
| |\n", + "|
process_exception()
|
process_exception()
| |" ] }, { "cell_type": "markdown", - "id": "61318354", - "metadata": {}, + "id": "94ba11c6", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "source": [ - "上述`fastNLP`模块,不仅**为入门级用户提供了简单易用的工具**,以解决各种`NLP`任务,或复现相关论文\n", - "\n", - "  同时**也为专业研究人员提供了便捷可操作的接口**,封装部分代码的同时,也能指定参数修改细节\n", + "由此,从程序员的角度,`fastNLP v0.8`**通过一个`driver`让基于`pytorch`、`paddle`、`jittor`框架的模型**\n", "\n", - "  在接下来的`tutorial`中,我们将通过`SST-2`分类和`CoNLL-2003`标注,展示相关模型使用\n", + "    **都能在相同的`trainer`和`evaluator`上运行**,这也**是`fastNLP v0.8`相比于之前版本的一大亮点**\n", "\n", - "注一:**`SST`**,**单句情感分类**数据集,包含电影评论和对应情感极性,1 对应正面情感,0 对应负面情感\n", + "  而从`driver`的角度,`fastNLP v0.8`通过定义一个`driver`基类,**将所有张量转化为`numpy.tensor`**\n", "\n", - "  数据集包括三部分:训练集 67350 条,验证集 873 条,测试集 1821 条,更多参考[下载链接](https://gluebenchmark.com/tasks)\n", + "    并由此泛化出`torch_driver`、`paddle_driver`、`jittor_driver`三个子类,从而实现了\n", "\n", - "注二:**`CoNLL-2003`**,**文本语法标注**数据集,包含语句和对应的词性标签`pos_tags`(名动形数量代)\n", - "\n", - "  语法结构标签`chunk_tags`(主谓宾定状补)、命名实体标签`ner_tags`(人名、组织名、地名、时间等)\n", - "\n", - "  数据集包括三部分:训练集 14041 条,验证集 3250 条,测试集 3453 条,更多参考[原始论文](https://aclanthology.org/W03-0419.pdf)" + "    对`pytorch`、`paddle`、`jittor`的兼容,有关后两者的实践请参考接下来的`tutorial-6`" ] }, { "cell_type": "markdown", - "id": "2a36bbe4", + "id": "ab1cea7d", "metadata": {}, "source": [ - "### 1.2 示例一:modules 实现 LSTM 分类" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "40e66b21", - "metadata": {}, - "outputs": [], - "source": [ - "# import sys\n", - "# sys.path.append('..')\n", + "### 1.2 device 与 多卡训练\n", "\n", - "# from fastNLP.io import SST2Pipe # 没有 SST2Pipe 会运行很长时间,并且还会报错\n", + "**`fastNLP v0.8`支持多卡训练**,实现方法则是**通过将`trainer`中的`device`设置为对应显卡的序号列表**\n", "\n", - "# databundle = SST2Pipe(tokenizer='raw').process_from_file()\n", + "  由单卡切换成多卡,无论是数据、模型还是评测都会面临一定的调整,`fastNLP v0.8`保证:\n", "\n", - "# dataset = databundle.get_dataset('train')[:6000]\n", + "    数据拆分时,不同卡之间相互协调,所有数据都可以被训练,且不会使用到相同的数据\n", "\n", - "# dataset.apply_more(lambda ins:{'words': 
ins['sentence'].lower().split(), 'target': ins['label']}, \n", - "# progress_bar=\"tqdm\")\n", - "# dataset.delete_field('sentence')\n", - "# dataset.delete_field('label')\n", - "# dataset.delete_field('idx')\n", + "    模型训练时,模型之间需要交换梯度;评测计算时,每张卡先各自计算,再汇总结果\n", "\n", - "# from fastNLP import Vocabulary\n", + "  例如,在评测计算运行`get_metric`函数时,`fastNLP v0.8`将自动按照`self.right`和`self.total`\n", "\n", - "# vocab = Vocabulary()\n", - "# vocab.from_dataset(dataset, field_name='words')\n", - "# vocab.index_dataset(dataset, field_name='words')\n", + "    指定的**`aggregate_method`方法**,默认为`sum`,将每张卡上结果汇总起来,因此最终\n", "\n", - "# train_dataset, evaluate_dataset = dataset.split(ratio=0.85)" + "    在调用`get_metric`方法时,`Accuracy`类能够返回全部的统计结果,代码如下\n", + " \n", + "```python\n", + "trainer = Trainer(\n", + " model=model, # model 基于 pytorch 实现 \n", + " train_dataloader=train_dataloader,\n", + " optimizers=optimizer,\n", + " ...\n", + " driver='torch', # driver 使用 torch_driver \n", + " device=[0, 1], # gpu 选择 cuda:0 + cuda:1\n", + " ...\n", + " evaluate_dataloaders=evaluate_dataloader,\n", + " metrics={'acc': Accuracy()},\n", + " ...\n", + " )\n", + "\n", + "class Accuracy(Metric):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.register_element(name='total', value=0, aggregate_method='sum')\n", + " self.register_element(name='right', value=0, aggregate_method='sum')\n", + "```\n" ] }, { - "cell_type": "code", - "execution_count": 2, - "id": "50960476", - "metadata": {}, - "outputs": [], + "cell_type": "markdown", + "id": "e2e0a210", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "source": [ - "# from fastNLP import prepare_torch_dataloader\n", - "\n", - "# train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n", - "# evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)" + "注:`fastNLP v0.8`中要求`jupyter`不能多卡,仅能单卡,故在所有`tutorial`中均不作相关演示" ] }, { - "cell_type": "code", - "execution_count": 3, - "id": 
"0b25b25c", + "cell_type": "markdown", + "id": "8d19220c", "metadata": {}, - "outputs": [], "source": [ - "# import torch\n", - "# import torch.nn as nn\n", + "## 2. fastNLP 中的更多 metric 类型\n", "\n", - "# from fastNLP.modules.torch import LSTM, MLP # 没有 MLP\n", - "# from fastNLP import Embedding, CrossEntropyLoss\n", + "### 2.1 预定义的 metric 类型\n", "\n", + "在`fastNLP 0.8`中,除了前几篇`tutorial`中经常见到的**正确率`Accuracy`**,还有其他**预定义的评测标准`metric`**\n", "\n", - "# class ClsByModules(nn.Module):\n", - "# def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n", - "# nn.Module.__init__(self)\n", + "  包括**所有`metric`的基类`Metric`**、适配`Transformers`中相关模型的正确率`TransformersAccuracy`\n", "\n", - "# self.embedding = Embedding((vocab_size, embedding_dim))\n", - "# self.lstm = LSTM(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=True)\n", - "# self.mlp = MLP([hidden_dim * 2, output_dim], dropout=dropout)\n", - " \n", - "# self.loss_fn = CrossEntropyLoss()\n", + "    **适用于分类语境下的`F1`值`ClassifyFPreRecMetric`**(其中也包括召回率`Pre`、精确率`Rec`\n", "\n", - "# def forward(self, words):\n", - "# output = self.embedding(words)\n", - "# output, (hidden, cell) = self.lstm(output)\n", - "# output = self.mlp(torch.cat((hidden[-1], hidden[-2]), dim=1))\n", - "# return output\n", - " \n", - "# def train_step(self, words, target):\n", - "# pred = self(words)\n", - "# return {\"loss\": self.loss_fn(pred, target)}\n", - "\n", - "# def evaluate_step(self, words, target):\n", - "# pred = self(words)\n", - "# pred = torch.max(pred, dim=-1)[1]\n", - "# return {\"pred\": pred, \"target\": target}" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "9dbbf50d", - "metadata": {}, - "outputs": [], - "source": [ - "# model = ClsByModules(vocab_size=len(vocabulary), embedding_dim=100, output_dim=2)\n", - "\n", - "# from torch.optim import AdamW\n", + "    **适用于抽取语境下的`F1`值`SpanFPreRecMetric`**;相关基本信息内容见下表,之后是详细分析\n", "\n", - "# optimizers = 
AdamW(params=model.parameters(), lr=5e-5)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "7a93432f", - "metadata": {}, - "outputs": [], - "source": [ - "# from fastNLP import Trainer, Accuracy\n", - "\n", - "# trainer = Trainer(\n", - "# model=model,\n", - "# driver='torch',\n", - "# device=0, # 'cuda'\n", - "# n_epochs=10,\n", - "# optimizers=optimizers,\n", - "# train_dataloader=train_dataloader,\n", - "# evaluate_dataloaders=evaluate_dataloader,\n", - "# metrics={'acc': Accuracy()}\n", - "# )" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "31102e0f", - "metadata": {}, - "outputs": [], - "source": [ - "# trainer.run(num_eval_batch_per_dl=10)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "8bc4bfb2", - "metadata": {}, - "outputs": [], - "source": [ - "# trainer.evaluator.run()" + "|
代码名称
|
简要介绍
|
代码路径
|\n", + "|:--|:--|:--|\n", + "| `Metric` | 定义`metrics`时继承的基类 | `/core/metrics/metric.py` |\n", + "| `Accuracy` | 正确率,最为常用 | `/core/metrics/accuracy.py` |\n", + "| `TransformersAccuracy` | 正确率,为了兼容`Transformers`中相关模型 | `/core/metrics/accuracy.py` |\n", + "| `ClassifyFPreRecMetric` | 召回率、精确率、F1值,适用于**分类问题** | `/core/metrics/classify_f1_pre_rec_metric.py` |\n", + "| `SpanFPreRecMetric` | 召回率、精确率、F1值,适用于**抽取问题** | `/core/metrics/span_f1_pre_rec_metric.py` |" ] }, { "cell_type": "markdown", - "id": "d9443213", - "metadata": {}, + "id": "fdc083a3", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "source": [ - "## 2. fastNLP 中 models 模块的介绍\n", + "  如`tutorial-0`中所述,所有的`metric`都包含`get_metric`和`update`函数,其中\n", "\n", - "### 2.1 示例一:models 实现 CNN 分类\n", + "    **`update`函数更新单个`batch`的统计量**,**`get_metric`函数返回最终结果**,并打印显示\n", "\n", - "  本示例使用`fastNLP 0.8`中预定义模型`models`中的`CNNText`模型,实现`SST-2`文本二分类任务\n", "\n", - "模型使用方面,如上所述,这里使用**基于卷积神经网络`CNN`的预定义文本分类模型`CNNText`**,结构如下所示\n", + "### 2.1.1 Accuracy 与 TransformersAccuracy\n", "\n", - "  首先是内置的`100`维嵌入层、`dropout`层、紧接着是三个一维卷积,将`100`维嵌入特征,分别通过\n", + "`Accuracy`,正确率,预测正确的数据`right_num`在总数据`total_num`,中的占比(公式就不用列了\n", "\n", - "    **感受野为`1`、`3`、`5`的卷积算子变换至`30`维、`40`维、`50`维的卷积特征**,再将三者拼接\n", + "  `get_metric`函数打印格式为 **`{\"acc#xx\": float, 'total#xx': float, 'correct#xx': float}`**\n", "\n", - "  最终再次通过`dropout`层、线性变换层,映射至二元的输出值,对应两个分类结果上的几率`logits`\n", + "  一般在初始化时不需要传参,`fastNLP`会根据`update`函数的传入参数确定对应后台框架`backend`\n", "\n", - "```\n", - "CNNText(\n", - " (embed): Embedding(\n", - " (embed): Embedding(5194, 100)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (conv_pool): ConvMaxpool(\n", - " (convs): ModuleList(\n", - " (0): Conv1d(100, 30, kernel_size=(1,), stride=(1,), bias=False)\n", - " (1): Conv1d(100, 40, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n", - " (2): Conv1d(100, 50, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", - " )\n", - " )\n", - " (dropout): Dropout(p=0.1, 
inplace=False)\n", - " (fc): Linear(in_features=120, out_features=2, bias=True)\n", - ")\n", - "```\n", + "  **`update`函数的参数包括`pred`、`target`、`seq_len`**,**后者用来标记批次中每笔数据的长度**\n", "\n", - "数据使用方面,此处**使用`datasets`模块中的`load_dataset`函数**,以如下形式,指定`SST-2`数据集自动加载\n", + "`TransformersAccuracy`,继承自`Accuracy`,只是为了兼容`Transformers`框架中相关模型\n", "\n", - "  首次下载后会保存至`~/.cache/huggingface/modules/datasets_modules/datasets/glue/`目录下" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "1aa5cf6d", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using the latest cached version of the module from /remote-home/xrliu/.cache/huggingface/modules/datasets_modules/datasets/glue/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad (last modified on Thu May 26 15:30:15 2022) since it couldn't be found locally at glue., or remotely on the Hugging Face Hub.\n", - "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "70cde65067c64fdba1d5e798e2b8d631", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/3 [00:00\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Processing: 0%| | 0/6000 [00:00[17:45:59] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", - "\n" - ], - "text/plain": [ - "\u001b[2;36m[17:45:59]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. 
\u001b]8;id=147745;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=708408;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Output()"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.575,\n",
-       "  \"total#acc\": 160.0,\n",
-       "  \"correct#acc\": 92.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.575\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m92.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.75625,\n",
-       "  \"total#acc\": 160.0,\n",
-       "  \"correct#acc\": 121.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.75625\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m121.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.78125,\n",
-       "  \"total#acc\": 160.0,\n",
-       "  \"correct#acc\": 125.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.78125\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m125.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.8,\n",
-       "  \"total#acc\": 160.0,\n",
-       "  \"correct#acc\": 128.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.79375,\n",
-       "  \"total#acc\": 160.0,\n",
-       "  \"correct#acc\": 127.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.79375\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m127.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.80625,\n",
-       "  \"total#acc\": 160.0,\n",
-       "  \"correct#acc\": 129.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m129.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.81875,\n",
-       "  \"total#acc\": 160.0,\n",
-       "  \"correct#acc\": 131.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.825,\n",
-       "  \"total#acc\": 160.0,\n",
-       "  \"correct#acc\": 132.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.825\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m132.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.81875,\n",
-       "  \"total#acc\": 160.0,\n",
-       "  \"correct#acc\": 131.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.81875,\n",
-       "  \"total#acc\": 160.0,\n",
-       "  \"correct#acc\": 131.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "trainer.run(num_eval_batch_per_dl=10)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "f47a6a35", - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/plain": [
-       "{'acc#acc': 0.79, 'total#acc': 900.0, 'correct#acc': 711.0}"
-      ]
-     },
-     "execution_count": 14,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "trainer.evaluator.run()"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "7c811257",
-   "metadata": {},
-   "source": [
-    "  注:此处使用`gc`模块删除相关变量,释放内存,为接下来新的模型训练预留存储空间"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 15,
-   "id": "c1a2e2ca",
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "342"
-      ]
-     },
-     "execution_count": 15,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "import gc\n",
+    "      或者`Accuracy`,会发现评测结果显示很高,但这只是因为选择的评测方法要求太低\n",
+    "\n",
+    "    最后通过`CoNLL-2003`的词性标注`POS`任务简单演示下`ClassifyFPreRecMetric`相关的使用\n",
+    "\n",
+    "```python\n",
+    "from fastNLP import Vocabulary\n",
+    "from fastNLP import ClassifyFPreRecMetric\n",
     "\n",
-    "del model\n",
-    "del trainer\n",
-    "del dataset\n",
-    "del sst2data\n",
+    "tag_vocab = Vocabulary(padding=None, unknown=None)            # 记录序号与标签之间的映射\n",
+    "tag_vocab.add_word_lst(['\"', \"''\", '#', '$', '(', ')', ',', '.', ':', '``', \n",
+    "                        'CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', \n",
+    "                        'MD', 'NN', 'NNP', 'NNPS', 'NNS', 'NN|SYM', 'PDT', 'POS', 'PRP', 'PRP$', \n",
+    "                        'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', \n",
+    "                        'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP+', 'WRB', ])  # CoNLL-2003 中的 pos_tags\n",
+    "ignore_labels = ['\"', \"''\", '#', '$', '(', ')', ',', '.', ':', '``', ]\n",
     "\n",
-    "gc.collect()"
+    "FPreRec = ClassifyFPreRecMetric(tag_vocab=tag_vocab,          \n",
+    "                                ignore_labels=ignore_labels,  # 表示评测/优化中不考虑上述标签的正误/损失\n",
+    "                                only_gross=True,              # 默认为 True 表示输出所有类别的综合统计结果\n",
+    "                                f_type='micro')               # 默认为 'micro' 表示统计所有类别的 Rec-Pre-F1\n",
+    "metrics = {'F1': FPreRec}\n",
+    "```"
    ]
   },
   {
    "cell_type": "markdown",
-   "id": "6aec2a19",
+   "id": "8a22f522",
    "metadata": {},
    "source": [
-    "### 2.2  示例二:models 实现 BiLSTM 标注\n",
+    "### 2.2  自定义的 metric 类型\n",
+    "\n",
+    "如上文所述,`Metric`作为所有`metric`的基类,`Accuracy`等都是其子类,同样地,对于**自定义的`metric`类型**\n",
+    "\n",
+    "    也**需要继承自`Metric`类**,同时**内部自定义好`__init__`、`update`和`get_metric`函数**\n",
     "\n",
-    "  通过两个示例一的对比可以发现,得益于`models`对模型结构的封装,使用`models`明显更加便捷\n",
+    "  在`__init__`函数中,根据需求定义评测时需要用到的变量,此处沿用`Accuracy`中的`total_num`和`right_num`\n",
     "\n",
-    "    针对更加复杂的模型时,编码更加轻松;本示例将使用`models`中的`BiLSTMCRF`模型\n",
+    "  在`update`函数中,根据需求定义评测变量的更新方式,需要注意的是如`tutorial-0`中所述,**`update`的参数名**\n",
     "\n",
-    "  避免`CRF`和`Viterbi`算法代码书写的困难,轻松实现`CoNLL-2003`中的命名实体识别`NER`任务\n",
+    "    **需要与待评估模型在`evaluate_step`中的输出名称一致**,由此**和数据集中对应字段名称一致**,即**参数匹配**\n",
     "\n",
-    "模型使用方面,如上所述,这里使用**基于双向`LSTM`+条件随机场`CRF`的标注模型`BiLSTMCRF`**,结构如下所示\n",
+    "    在`fastNLP v0.8`中,`update`函数的默认输入参数:`pred`,对应预测值;`target`,对应真实值\n",
     "\n",
-    "  其中,隐藏层维度默认`100`维,因此对应双向`LSTM`输出`200`维,`dropout`层退学概率、`LSTM`层数可调\n",
+    "    此处刻意调整为:`pred`,对应预测值,和模型输出一致;`true`,对应真实值,数据集字段需要调整\n",
     "\n",
-    "```\n",
-    "BiLSTMCRF(\n",
-    "  (embed): Embedding(7590, 100)\n",
-    "  (lstm): LSTM(\n",
-    "    (lstm): LSTM(100, 100, batch_first=True, bidirectional=True)\n",
-    "  )\n",
-    "  (dropout): Dropout(p=0.1, inplace=False)\n",
-    "  (fc): Linear(in_features=200, out_features=9, bias=True)\n",
-    "  (crf): ConditionalRandomField()\n",
-    ")\n",
-    "```\n",
+    "  在`get_metric`函数中,根据需求定义评测指标最终的计算,此处直接计算准确率,该函数必须返回一个字典\n",
     "\n",
-    "数据使用方面,此处仍然**使用`datasets`模块中的`load_dataset`函数**,以如下形式,加载`CoNLL-2003`数据集\n",
+    "    其中,字串`'prefix'`表示该`metric`的名称,会对应显示到`trainer`的`progress bar`中\n",
     "\n",
-    "  首次下载后会保存至`~.cache/huggingface/datasets/conll2003/conll2003/1.0.0/`目录下"
+    "根据上述要求,这里简单定义了一个名为`MyMetric`的评测模块,用于分类问题的评测,以此展开一个实例展示"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
-   "id": "03e66686",
+   "execution_count": null,
+   "id": "08a872e9",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Reusing dataset conll2003 (/remote-home/xrliu/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/63f4ebd1bcb7148b1644497336fd74643d4ce70123334431a3c053b7ee4e96ee)\n"
-     ]
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "3ec9e0ce9a054339a2453420c2c9f28b",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "  0%|          | 0/3 [00:00[17:49:16] INFO     Running evaluator sanity check for 2 batches.              trainer.py:592\n",
-       "\n"
-      ],
-      "text/plain": [
-       "\u001b[2;36m[17:49:16]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO    \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches.              \u001b]8;id=766109;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=787419;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Output()"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Output()"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"f#F1\": 0.220374,\n",
-       "  \"pre#F1\": 0.25,\n",
-       "  \"rec#F1\": 0.197026\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.220374\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.25\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.197026\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"f#F1\": 0.442857,\n",
-       "  \"pre#F1\": 0.426117,\n",
-       "  \"rec#F1\": 0.460967\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.442857\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.426117\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.460967\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"f#F1\": 0.572954,\n",
-       "  \"pre#F1\": 0.549488,\n",
-       "  \"rec#F1\": 0.598513\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.572954\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.549488\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.598513\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"f#F1\": 0.665399,\n",
-       "  \"pre#F1\": 0.680934,\n",
-       "  \"rec#F1\": 0.650558\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.665399\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.680934\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.650558\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"f#F1\": 0.734694,\n",
-       "  \"pre#F1\": 0.733333,\n",
-       "  \"rec#F1\": 0.736059\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.734694\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.733333\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.736059\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"f#F1\": 0.742647,\n",
-       "  \"pre#F1\": 0.734545,\n",
-       "  \"rec#F1\": 0.750929\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.742647\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.734545\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.750929\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"f#F1\": 0.773585,\n",
-       "  \"pre#F1\": 0.785441,\n",
-       "  \"rec#F1\": 0.762082\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.773585\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.785441\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.762082\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"f#F1\": 0.770115,\n",
-       "  \"pre#F1\": 0.794466,\n",
-       "  \"rec#F1\": 0.747212\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.770115\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.794466\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.747212\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"f#F1\": 0.7603,\n",
-       "  \"pre#F1\": 0.766038,\n",
-       "  \"rec#F1\": 0.754647\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.7603\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.766038\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.754647\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"f#F1\": 0.743682,\n",
-       "  \"pre#F1\": 0.722807,\n",
-       "  \"rec#F1\": 0.765799\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.743682\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.722807\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.765799\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], "source": [ - "trainer.run(num_eval_batch_per_dl=10)" + "## 3. fastNLP 中 trainer 的补充介绍\n", + "\n", + "### 3.1 trainer 的内部结构\n", + "\n", + "在`tutorial-0`中,我们已经介绍了`trainer`的基本使用,从`tutorial-1`到`tutorial-4`,我们也已经\n", + "\n", + "  展示了很多关于`trainer`的使用案例,以下我们先补充介绍训练模块`trainer`的一些内部结构\n", + "\n", + "\n", + "\n", + "'accumulation_steps', 'add_callback_fn', 'backward', 'batch_idx_in_epoch', 'batch_step_fn',\n", + "'callback_manager', 'check_batch_step_fn', 'cur_epoch_idx', 'data_device', 'dataloader',\n", + "'device', 'driver', 'driver_name', 'epoch_evaluate', 'evaluate_batch_step_fn', 'evaluate_dataloaders',\n", + "'evaluate_every', 'evaluate_fn', 'evaluator', 'extract_loss_from_outputs', 'fp16',\n", + "'get_no_sync_context', 'global_forward_batches', 'has_checked_train_batch_loop',\n", + "'input_mapping', 'kwargs', 'larger_better', 'load_checkpoint', 'load_model', 'marker',\n", + "'metrics', 'model', 'model_device', 'monitor', 'move_data_to_device', 'n_epochs', 'num_batches_per_epoch',\n", + "'on', 'on_after_backward', 'on_after_optimizers_step', 'on_after_trainer_initialized',\n", + "'on_after_zero_grad', 'on_before_backward', 'on_before_optimizers_step', 'on_before_zero_grad',\n", + "'on_evaluate_begin', 'on_evaluate_end', 'on_exception', 'on_fetch_data_begin', 'on_fetch_data_end',\n", + "'on_load_checkpoint', 'on_load_model', 'on_sanity_check_begin', 'on_sanity_check_end',\n", + "'on_save_checkpoint', 'on_save_model', 'on_train_batch_begin', 'on_train_batch_end',\n", + "'on_train_begin', 'on_train_end', 'on_train_epoch_begin', 'on_train_epoch_end',\n", + "'optimizers', 'output_mapping', 'progress_bar', 'run', 'run_evaluate',\n", + "'save_checkpoint', 'save_model', 'start_batch_idx_in_epoch', 'state',\n", + "'step', 'step_evaluate', 'total_batches', 'train_batch_loop', 'train_dataloader', 'train_fn', 'train_step',\n", + "'trainer_state', 'zero_grad'\n", + "\n", + "  
run(num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, catch_KeyboardInterrupt=None)" ] }, { "cell_type": "code", - "execution_count": 22, - "id": "37871d6b", - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/plain": [
-       "{'f#F1': 0.75283, 'pre#F1': 0.727438, 'rec#F1': 0.780059}"
-      ]
-     },
-     "execution_count": 22,
-     "metadata": {},
-     "output_type": "execute_result"
+   "execution_count": null,
+   "id": "c348864c",
+   "metadata": {
+    "pycharm": {
+     "name": "#%%\n"
     }
-   ],
-   "source": [
-    "trainer.evaluator.run()"
-   ]
+   },
+   "outputs": [],
+   "source": []
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "96bae094",
-   "metadata": {},
+   "id": "43be274f",
+   "metadata": {
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   },
    "outputs": [],
    "source": []
   }
@@ -1951,6 +538,15 @@
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "version": "3.7.13"
+  },
+  "pycharm": {
+   "stem_cell": {
+    "cell_type": "raw",
+    "metadata": {
+     "collapsed": false
+    },
+    "source": []
+   }
   }
  },
  "nbformat": 4,