From 29fb454c2e2e78dcfdd7850b86534b9ae1af91b8 Mon Sep 17 00:00:00 2001 From: lxr-tech <1838593642@qq.com> Date: Fri, 29 Apr 2022 22:22:28 +0800 Subject: [PATCH] modify fastnlp_tutorial_0.py --- tutorials/fastnlp_tutorial_0.ipynb | 1009 ++++++++++++++++++++++++++++ 1 file changed, 1009 insertions(+) create mode 100644 tutorials/fastnlp_tutorial_0.ipynb diff --git a/tutorials/fastnlp_tutorial_0.ipynb b/tutorials/fastnlp_tutorial_0.ipynb new file mode 100644 index 00000000..01913ac0 --- /dev/null +++ b/tutorials/fastnlp_tutorial_0.ipynb @@ -0,0 +1,1009 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "aec0fde7", + "metadata": {}, + "source": [ + "# T0. trainer 和 evaluator 的基本使用\n", + "\n", + " 1 trainer 和 evaluator 的基本关系\n", + " \n", + " 1.1 trainer 和 evaluater 的初始化\n", + "\n", + " 1.2 driver 的含义与使用要求\n", + "\n", + " 1.3 trainer 内部初始化 evaluater\n", + "\n", + " 2 使用 trainer 训练模型\n", + "\n", + " 2.1 argmax 模型实例\n", + "\n", + " 2.2 trainer 的参数匹配\n", + "\n", + " 2.3 trainer 的实际使用 \n", + "\n", + " 3 使用 evaluator 评测模型\n", + " \n", + " 3.1 trainer 外部初始化的 evaluator\n", + "\n", + " 3.2 trainer 内部初始化的 evaluator " + ] + }, + { + "cell_type": "markdown", + "id": "09ea669a", + "metadata": {}, + "source": [ + "## 1. trainer 和 evaluator 的基本关系\n", + "\n", + "### 1.1 trainer 和 evaluator 的初始化\n", + "\n", + "在`fastNLP 0.8`中,**`Trainer`模块和`Evaluator`模块分别表示“训练器”和“评测器”**\n", + "\n", + " 对应于之前的`fastNLP`版本中的`Trainer`模块和`Tester`模块,其定义方法如下所示\n", + "\n", + "在`fastNLP 0.8`中,需要注意,在同个`python`脚本中先使用`Trainer`训练,然后使用`Evaluator`评测\n", + "\n", + " 非常关键的问题在于**如何正确设置二者的`driver`**。这就引入了另一个问题:什么是 `driver`?\n", + "\n", + "\n", + "```python\n", + "trainer = Trainer(\n", + " model=model,\n", + " train_dataloader=train_dataloader,\n", + " optimizers=optimizer,\n", + "\t...\n", + "\tdriver=\"torch\",\n", + "\tdevice=0,\n", + "\t...\n", + ")\n", + "...\n", + "evaluator = Evaluator(\n", + " model=model,\n", + " dataloaders=evaluate_dataloader,\n", + " metrics={'acc': Accuracy()} \n", + " ...\n", + " driver=trainer.driver,\n", + "\tdevice=None,\n", + " ...\n", + ")\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "3c11fe1a", + "metadata": {}, + "source": [ + "### 1.2 driver 的含义与使用要求\n", + "\n", + "在`fastNLP 0.8`中,**`driver`**这一概念被用来表示**控制具体训练的各个步骤的最终执行部分**\n", + "\n", + " 例如神经网络前向、后向传播的具体执行、网络参数的优化和数据在设备间的迁移等\n", + "\n", + "在`fastNLP 0.8`中,**`Trainer`和`Evaluator`都依赖于具体的`driver`来完成整体的工作流程**\n", + "\n", + " 具体`driver`与`Trainer`以及`Evaluator`之间的关系请参考`fastNLP 0.8`的框架设计\n", + "\n", + "注:在同一脚本中,`Trainer`和`Evaluator`使用的`driver`应当保持一致\n", + "\n", + " 一个不能违背的原则在于:**不要将多卡的`driver`前使用单卡的`driver`**(???),这样使用可能会带来很多意想不到的错误。" + ] + }, + { + "cell_type": "markdown", + "id": "2cac4a1a", + "metadata": {}, + "source": [ + "### 1.3 Trainer 内部初始化 Evaluator\n", + "\n", + "在`fastNLP 0.8`中,如果在**初始化`Trainer`时**,**传入参数`evaluator_dataloaders`和`metrics`**\n", + "\n", + " 则在`Trainer`内部,也会初始化单独的`Evaluator`来帮助训练过程中对验证集的评测\n", + "\n", + "```python\n", + "trainer = Trainer(\n", + " model=model,\n", + " train_dataloader=train_dataloader,\n", + " optimizers=optimizer,\n", + "\t...\n", + "\tdriver=\"torch\",\n", + "\tdevice=0,\n", + "\t...\n", + " evaluate_dataloaders=evaluate_dataloader,\n", + " metrics={'acc': Accuracy()},\n", + "\t...\n", + ")\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "0c9c7dda", + "metadata": {}, + "source": [ + "## 2. 使用 trainer 训练模型" + ] + }, + { + "cell_type": "markdown", + "id": "524ac200", + "metadata": {}, + "source": [ + "### 2.1 argmax 模型实例\n", + "\n", + "本节将通过训练`argmax`模型,简单介绍如何`Trainer`模块的使用方式\n", + "\n", + " 使用`pytorch`定义`argmax`模型,输入一组固定维度的向量,输出其中数值最大的数的索引\n", + "\n", + " 除了添加`pytorch`要求的`forward`方法外,还需要添加 **`train_step`** 和 **`evaluate_step`** 这两个方法" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5314482b", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "\n", + "class ArgMaxModel(nn.Module):\n", + " def __init__(self, num_labels, feature_dimension):\n", + " super(ArgMaxModel, self).__init__()\n", + " self.num_labels = num_labels\n", + "\n", + " self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10)\n", + " self.ac1 = nn.ReLU()\n", + " self.linear2 = nn.Linear(in_features=10, out_features=10)\n", + " self.ac2 = nn.ReLU()\n", + " self.output = nn.Linear(in_features=10, out_features=num_labels)\n", + " self.loss_fn = nn.CrossEntropyLoss()\n", + "\n", + " def forward(self, x):\n", + " x = self.ac1(self.linear1(x))\n", + " x = self.ac2(self.linear2(x))\n", + " x = self.output(x)\n", + " return x\n", + "\n", + " def train_step(self, x, y):\n", + " x = self(x)\n", + " return {\"loss\": self.loss_fn(x, y)}\n", + "\n", + " def evaluate_step(self, x, y):\n", + " x = self(x)\n", + " x = torch.max(x, dim=-1)[1]\n", + " return {\"pred\": x, \"target\": y}" + ] + }, + { + "cell_type": "markdown", + "id": "ca897322", + "metadata": {}, + "source": [ + "在`fastNLP 0.8`中,**函数`train_step`是`Trainer`中参数`train_fn`的默认值**\n", + "\n", + " 由于,在`Trainer`训练时,**`Trainer`通过参数`_train_fn_`对应的模型方法获得当前数据批次的损失值**\n", + "\n", + " 因此,在`Trainer`训练时,`Trainer`首先会寻找模型是否定义了`train_step`这一方法\n", + "\n", + " 如果没有找到,那么`Trainer`会默认使用模型的`forward`函数来进行训练的前向传播过程\n", + "\n", + "注:在`fastNLP 0.8`中,`Trainer`要求模型通过`train_step`来返回一个字典,将损失值作为`loss`的键值\n", + "\n", + " 此外,这里也可以通过传入`Trainer`的参数`output_mapping`来实现高度化的定制,具体请见这一note(???)\n", + "\n", + "同样,在`fastNLP 0.8`中,**函数`evaluate_step`是`Evaluator`中参数`evaluate_fn`的默认值**\n", + "\n", + " 在`Evaluator`测试时,**`Evaluator`通过参数`evaluate_fn`对应的模型方法获得当前数据批次的评测结果**\n", + "\n", + " 从用户角度,模型通过`evaluate_step`方法来返回一个字典,内容与传入`Evaluator`的`metrics`一致\n", + "\n", + "" + ] + }, + { + "cell_type": "markdown", + "id": "fb3272eb", + "metadata": {}, + "source": [ + "### 2.2 trainer 的参数匹配\n", + "\n", + "`fastNLP 0.8`中的参数匹配涉及到两个方面,一是在模型训练或者评测的前向传播过程中,如果从`dataloader`中出来一个`batch`的数据是一个字典,那么我们会查看模型的`train_step`和`evaluate_step`方法的参数签名,然后对于每一个参数,我们会根据其名字从 batch 这一字典中选择出对应的数据传入进去。例如在接下来的定义`Dataset`的部分,注意`ArgMaxDatset`的`__getitem__`方法,您可以通过在`Trainer`和`Evaluator`中设置参数 `model_wo_auto_param_call`来关闭这一行为。当您关闭了这一行为后,我们会将`batch`直接传给您的`train_step`、`evaluate_step`或者 `forward`函数。\n", + "\n", + "二是在传入`Trainer`或者`Evaluator metrics`后,我们会在需要评测的时间点主动调用`metrics`来对`evaluate_dataloaders`进行评测,这一功能主要就是通过对`metrics`的`update`方法和一个`batch`的数据进行参数评测实现的。首先需要明确的是一个 metric 的计算通常分为 `update` 和 `get_metric`两步,其中`update`表示更新一个`batch`的评测数据,`get_metric` 表示根据已经得到的评测数据计算出最终的评测值,例如对于 `Accuracy`来说,其在`update`的时候会更新一个`batch`计算正确的数量 right_num 和计算错误的数量 total_num,最终在 `get_metric` 时返回评测值`right_num / total_num`。\n", + "\n", + "因为`fastNLP 0.8`的`metrics`是自动计算的(只需要传给`Trainer`或者`Evaluator`),因此其一定依赖于参数匹配。对于从`evaluate_dataloader`中生成的一个`batch`的数据,我们会查看传给 `Trainer`(最终是传给`Evaluator`)和`Evaluator`的每一个`metric`,然后查看其`update`函数的函数签名,然后根据每一个参数的名字从`batch`字典中选择出对应的数据传入进去。" + ] + }, + { + "cell_type": "markdown", + "id": "f62b7bb1", + "metadata": {}, + "source": [ + "### 2.3 trainer的实际使用\n", + "\n", + "接下来我们创建用于训练的 dataset,其接受三个参数:数据维度、数据量和随机数种子,生成指定数量的维度为 `feature_dimension` 向量,而每一个向量的标签就是该向量中最大值的索引。" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "fe612e61", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [], + "source": [ + "from torch.utils.data import Dataset\n", + "\n", + "class ArgMaxDatset(Dataset):\n", + " def __init__(self, feature_dimension, data_num=1000, seed=0):\n", + " self.num_labels = feature_dimension\n", + " self.feature_dimension = feature_dimension\n", + " self.data_num = data_num\n", + " self.seed = seed\n", + "\n", + " g = torch.Generator()\n", + " g.manual_seed(1000)\n", + " self.x = torch.randint(low=-100, high=100, size=[data_num, feature_dimension], generator=g).float()\n", + " self.y = torch.max(self.x, dim=-1)[1]\n", + "\n", + " def __len__(self):\n", + " return self.data_num\n", + "\n", + " def __getitem__(self, item):\n", + " return {\"x\": self.x[item], \"y\": self.y[item]}" + ] + }, + { + "cell_type": "markdown", + "id": "2cb96332", + "metadata": {}, + "source": [ + "现在准备好数据和模型。" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "76172ef8", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [], + "source": [ + "from torch.utils.data import DataLoader\n", + "\n", + "train_dataset = ArgMaxDatset(feature_dimension=10, data_num=1000)\n", + "evaluate_dataset = ArgMaxDatset(feature_dimension=10, data_num=100)\n", + "\n", + "train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)\n", + "evaluate_dataloader = DataLoader(evaluate_dataset, batch_size=8)\n", + "\n", + "# num_labels 设置为 10,与 feature_dimension 保持一致,因为我们是预测十个位置中哪一个的概率最大。\n", + "model = ArgMaxModel(num_labels=10, feature_dimension=10)" + ] + }, + { + "cell_type": "markdown", + "id": "4e7d25ee", + "metadata": {}, + "source": [ + "将优化器也定义好。" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "dc28a2d9", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [], + "source": [ + "from torch.optim import SGD\n", + "\n", + "optimizer = SGD(model.parameters(), lr=0.001)" + ] + }, + { + "cell_type": "markdown", + "id": "4f1fba81", + "metadata": {}, + "source": [ + "现在万事俱备,开始使用 Trainer 进行训练!" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "b51b7a2d", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "['__annotations__',\n", + " '__class__',\n", + " '__delattr__',\n", + " '__dict__',\n", + " '__dir__',\n", + " '__doc__',\n", + " '__eq__',\n", + " '__format__',\n", + " '__ge__',\n", + " '__getattribute__',\n", + " '__gt__',\n", + " '__hash__',\n", + " '__init__',\n", + " '__init_subclass__',\n", + " '__le__',\n", + " '__lt__',\n", + " '__module__',\n", + " '__ne__',\n", + " '__new__',\n", + " '__reduce__',\n", + " '__reduce_ex__',\n", + " '__repr__',\n", + " '__setattr__',\n", + " '__sizeof__',\n", + " '__str__',\n", + " '__subclasshook__',\n", + " '__weakref__',\n", + " '_check_callback_called_legality',\n", + " '_check_train_batch_loop_legality',\n", + " '_custom_callbacks',\n", + " '_driver',\n", + " '_evaluate_dataloaders',\n", + " '_fetch_matched_fn_callbacks',\n", + " '_set_num_eval_batch_per_dl',\n", + " '_train_batch_loop',\n", + " '_train_dataloader',\n", + " '_train_step',\n", + " '_train_step_signature_fn',\n", + " 'accumulation_steps',\n", + " 'add_callback_fn',\n", + " 'backward',\n", + " 'batch_idx_in_epoch',\n", + " 'batch_step_fn',\n", + " 'callback_manager',\n", + " 'check_batch_step_fn',\n", + " 'cur_epoch_idx',\n", + " 'data_device',\n", + " 'dataloader',\n", + " 'device',\n", + " 'driver',\n", + " 'driver_name',\n", + " 'epoch_validate',\n", + " 'evaluate_batch_step_fn',\n", + " 'evaluate_dataloaders',\n", + " 'evaluate_every',\n", + " 'evaluate_fn',\n", + " 'evaluator',\n", + " 'extract_loss_from_outputs',\n", + " 'fp16',\n", + " 'get_no_sync_context',\n", + " 'global_forward_batches',\n", + " 'has_checked_train_batch_loop',\n", + " 'input_mapping',\n", + " 'kwargs',\n", + " 'larger_better',\n", + " 'load',\n", + " 'load_model',\n", + " 'marker',\n", + " 'metrics',\n", + " 'model',\n", + " 'model_device',\n", + " 'monitor',\n", + " 'move_data_to_device',\n", + " 'n_epochs',\n", + " 'num_batches_per_epoch',\n", + " 'on',\n", + " 'on_after_backward',\n", + " 'on_after_optimizers_step',\n", + " 'on_after_trainer_initialized',\n", + " 'on_after_zero_grad',\n", + " 'on_before_backward',\n", + " 'on_before_optimizers_step',\n", + " 'on_before_zero_grad',\n", + " 'on_exception',\n", + " 'on_fetch_data_begin',\n", + " 'on_fetch_data_end',\n", + " 'on_load_checkpoint',\n", + " 'on_load_model',\n", + " 'on_sanity_check_begin',\n", + " 'on_sanity_check_end',\n", + " 'on_save_checkpoint',\n", + " 'on_save_model',\n", + " 'on_train_batch_begin',\n", + " 'on_train_batch_end',\n", + " 'on_train_begin',\n", + " 'on_train_end',\n", + " 'on_train_epoch_begin',\n", + " 'on_train_epoch_end',\n", + " 'on_validate_begin',\n", + " 'on_validate_end',\n", + " 'optimizers',\n", + " 'output_mapping',\n", + " 'run',\n", + " 'save',\n", + " 'save_model',\n", + " 'set_grad_to_none',\n", + " 'state',\n", + " 'step',\n", + " 'step_validate',\n", + " 'total_batches',\n", + " 'train_batch_loop',\n", + " 'train_dataloader',\n", + " 'train_fn',\n", + " 'train_step',\n", + " 'trainer_state',\n", + " 'zero_grad']" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from fastNLP import Trainer\n", + "\n", + "# 定义一个 Trainer\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver=\"torch\", # 使用 pytorch 进行训练\n", + " device=0, # 使用 GPU:0\n", + " train_dataloader=train_dataloader,\n", + " optimizers=optimizer,\n", + " n_epochs=10, # 训练 40 个 epoch\n", + " progress_bar=\"rich\"\n", + ")\n", + "dir(trainer)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "f8fe9c32", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "FullArgSpec(args=['self', 'num_train_batch_per_epoch', 'num_eval_batch_per_dl', 'num_eval_sanity_batch', 'resume_from', 'resume_training', 'catch_KeyboardInterrupt'], varargs=None, varkw=None, defaults=(-1, -1, 2, None, True, None), kwonlyargs=[], kwonlydefaults=None, annotations={'num_train_batch_per_epoch':
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run()" + ] + }, + { + "cell_type": "markdown", + "id": "eb8ca6cf", + "metadata": {}, + "source": [ + "## 3. 使用 evaluator 评测模型" + ] + }, + { + "cell_type": "markdown", + "id": "c16c5fa4", + "metadata": {}, + "source": [ + "模型训练好了我们开始使用 Evaluator 进行评测,查看效果怎么样吧。" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "1c6b6b36", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [], + "source": [ + "from fastNLP import Evaluator\n", + "from fastNLP.core.metrics import Accuracy\n", + "\n", + "evaluator = Evaluator(\n", + " model=model,\n", + " driver=trainer.driver, # 使用 trainer 已经启动的 driver;\n", + " device=None,\n", + " dataloaders=evaluate_dataloader,\n", + " metrics={'acc': Accuracy()} # 注意这里一定得是一个字典;\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "257061df", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "['__annotations__',\n", + " '__class__',\n", + " '__delattr__',\n", + " '__dict__',\n", + " '__dir__',\n", + " '__doc__',\n", + " '__eq__',\n", + " '__format__',\n", + " '__ge__',\n", + " '__getattribute__',\n", + " '__gt__',\n", + " '__hash__',\n", + " '__init__',\n", + " '__init_subclass__',\n", + " '__le__',\n", + " '__lt__',\n", + " '__module__',\n", + " '__ne__',\n", + " '__new__',\n", + " '__reduce__',\n", + " '__reduce_ex__',\n", + " '__repr__',\n", + " '__setattr__',\n", + " '__sizeof__',\n", + " '__str__',\n", + " '__subclasshook__',\n", + " '__weakref__',\n", + " '_dist_sampler',\n", + " '_evaluate_batch_loop',\n", + " '_evaluate_step',\n", + " '_evaluate_step_signature_fn',\n", + " '_metric_wrapper',\n", + " '_metrics',\n", + " 'dataloaders',\n", + " 'device',\n", + " 'driver',\n", + " 'evaluate_batch_loop',\n", + " 'evaluate_batch_step_fn',\n", + " 'evaluate_fn',\n", + " 'evaluate_step',\n", + " 'finally_progress_bar',\n", + " 'get_dataloader_metric',\n", + " 'input_mapping',\n", + " 'metrics',\n", + " 'metrics_wrapper',\n", + " 'model',\n", + " 'model_use_eval_mode',\n", + " 'move_data_to_device',\n", + " 'output_mapping',\n", + " 'progress_bar',\n", + " 'remove_progress_bar',\n", + " 'reset',\n", + " 'run',\n", + " 'separator',\n", + " 'start_progress_bar',\n", + " 'update',\n", + " 'update_progress_bar',\n", + " 'verbose']" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dir(evaluator)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "f7cb0165", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{'acc#acc': 0.3}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.3\u001b[0m\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'acc#acc': 0.3}" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "evaluator.run()" + ] + }, + { + "cell_type": "markdown", + "id": "dd9f68fa", + "metadata": {}, + "source": [ + "## 4. 在 trainer 中加入 metric 来自动评测;" + ] + }, + { + "cell_type": "markdown", + "id": "ca97c9a4", + "metadata": {}, + "source": [ + "现在我们尝试在训练过程中进行评测。" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "183c7d19", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [], + "source": [ + "# 重新定义一个 Trainer\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver=trainer.driver, # 因为我们是在同一脚本中,因此这里的 driver 同样需要重用;\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=evaluate_dataloader,\n", + " metrics={'acc': Accuracy()},\n", + " optimizers=optimizer,\n", + " n_epochs=10, # 训练 40 个 epoch;\n", + " evaluate_every=-1, # 表示每一个 epoch 的结束会进行 evaluate;\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "714cc404", + "metadata": {}, + "source": [ + "再次训练。" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "2e4daa2c", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "eabda5eb", + "metadata": {}, + "outputs": [], + "source": [ + "evaluator = Evaluator(\n", + " model=model,\n", + " driver=trainer.driver, # 使用 trainer 已经启动的 driver;\n", + " dataloaders=evaluate_dataloader,\n", + " metrics={'acc': Accuracy()} # 注意这里一定得是一个字典;\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "a310d157", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{'acc#acc': 0.5}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.5\u001b[0m\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'acc#acc': 0.5}" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "evaluator.run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1ef78f0", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}