From 82b06767f55307143f6ed0e1a1400a8ee339a3cf Mon Sep 17 00:00:00 2001 From: yhcc Date: Fri, 3 Jun 2022 22:35:46 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dn=5Fbatches=E5=AF=BC?= =?UTF-8?q?=E8=87=B4=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/controllers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 7c6bba53..41fca6ba 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -448,7 +448,7 @@ class Trainer(TrainerEventTrigger): # 初始化 state,包括提供给用户的接口和我们自己使用的接口; self.state = State() self.trainer_state = TrainerState( - n_epochs=n_epochs if n_batches!=-1 else None, + n_epochs=n_epochs if n_batches==-1 else None, cur_epoch_idx=0, global_forward_batches=0, batch_idx_in_epoch=0, From 3797f91434930d869eff2849d79f2d13e7409642 Mon Sep 17 00:00:00 2001 From: lxr-tech <1838593642@qq.com> Date: Sat, 4 Jun 2022 00:03:40 +0800 Subject: [PATCH 2/2] update tutorial-3456 lxr 220603 --- tutorials/fastnlp_tutorial_2.ipynb | 10 +- tutorials/fastnlp_tutorial_3.ipynb | 226 ++++- tutorials/fastnlp_tutorial_4.ipynb | 1323 +--------------------------- tutorials/fastnlp_tutorial_5.ipynb | 218 +++-- tutorials/fastnlp_tutorial_6.ipynb | 302 ++++++- 5 files changed, 676 insertions(+), 1403 deletions(-) diff --git a/tutorials/fastnlp_tutorial_2.ipynb b/tutorials/fastnlp_tutorial_2.ipynb index 64c4bc8b..33f74b8b 100644 --- a/tutorials/fastnlp_tutorial_2.ipynb +++ b/tutorials/fastnlp_tutorial_2.ipynb @@ -801,24 +801,24 @@ "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n", "\n", "# 接着,导入数据,先生成为 dataset 形式,再变成 dataset-dict,并转为 databundle 形式\n", - "datasets = DataSet.from_pandas(pd.read_csv('./data/test4dataset.tsv'))\n", + "datasets = DataSet.from_pandas(pd.read_csv('./data/test4dataset.tsv', sep='\\t'))\n", "train_ds, test_ds = datasets.split(ratio=0.7)\n", "data_bundle = DataBundle(datasets={'train': train_ds, 'test': test_ds})\n", "\n", "# 然后,通过 tokenizer.encode_plus 函数,进行文本分词标注、修改并补充数据包内容\n", "encode = partial(tokenizer.encode_plus, max_length=100, truncation=True,\n", " return_attention_mask=True)\n", - "data_bundle.apply_field_more(encode, field_name='text', progress_bar='tqdm')\n", + "data_bundle.apply_field_more(encode, field_name='Sentence', progress_bar='tqdm')\n", "\n", "# 在修改好 'text' 字段的文本信息后,接着处理 'label' 字段的预测信息\n", "target_vocab = Vocabulary(padding=None, unknown=None)\n", - "target_vocab.from_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='label')\n", - "target_vocab.index_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='label',\n", + "target_vocab.from_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment')\n", + "target_vocab.index_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment',\n", " new_field_name='target')\n", "\n", "# 最后,通过 data_bundle 的其他一些函数,完成善后内容\n", "data_bundle.set_pad('input_ids', pad_val=tokenizer.pad_token_id)\n", - "data_bundle.set_ignore('label', 'text') \n", + "data_bundle.set_ignore('SentenceId', 'Sentiment', 'Sentence') \n", "```" ] }, diff --git a/tutorials/fastnlp_tutorial_3.ipynb b/tutorials/fastnlp_tutorial_3.ipynb index 172e1232..7566c02a 100644 --- a/tutorials/fastnlp_tutorial_3.ipynb +++ b/tutorials/fastnlp_tutorial_3.ipynb @@ -9,9 +9,9 @@ "\n", "  1   fastNLP 中的 dataloader\n", " \n", - "    1.1   dataloader 的职责描述\n", + "    1.1   
dataloader 的基本介绍\n", "\n", - "    1.2   dataloader 的基本使用\n", + "    1.2   dataloader 的函数创建\n", "\n", "  2   fastNLP 中 dataloader 的延伸\n", "\n", @@ -27,32 +27,143 @@ "source": [ "## 1. fastNLP 中的 dataloader\n", "\n", - "### 1.1 dataloader 的职责描述\n", + "### 1.1 dataloader 的基本介绍\n", "\n", - "在`fastNLP 0.8`中,在数据加载模块`DataLoader`之前" + "在`fastNLP 0.8`的开发中,最关键的开发目标就是**实现`fastNLP`对当前主流机器学习框架**,例如\n", + "\n", + "  **较为火热的`pytorch`**,以及**国产的`paddle`和`jittor`的兼容**,扩大受众的同时,也是助力国产\n", + "\n", + "本着分而治之的思想,我们可以将`fastNLP 0.8`对`pytorch`、`paddle`、`jittor`框架的兼容,划分为\n", + "\n", + "    **对数据预处理**、**批量`batch`的划分与补齐**、**模型训练**、**模型评测**,**四个部分的兼容**\n", + "\n", + "  针对数据预处理,我们已经在`tutorial-1`中介绍了`dataset`和`vocabulary`的使用\n", + "\n", + "    而结合`tutorial-0`,我们可以发现**数据预处理环节本质上是框架无关的**\n", + "\n", + "    因为在不同框架下,读取的原始数据格式都差异不大,彼此也很容易转换\n", + "\n", + "只有涉及到张量、模型,不同框架才展现出其各自的特色:**`pytorch`中的`tensor`和`nn.Module`**\n", + "\n", + "    **在`paddle`中称为`tensor`和`nn.Layer`**,**在`jittor`中则称为`Var`和`Module`**\n", + "\n", + "    因此,**模型训练、模型评测**,**是兼容的重难点**,我们将会在`tutorial-5`中详细介绍\n", + "\n", + "  针对批量`batch`的处理,作为`fastNLP 0.8`中框架无关部分想框架相关部分的过渡\n", + "\n", + "    就是`dataloader`模块的职责,这也是本篇教程`tutorial-3`讲解的重点\n", + "\n", + "**`dataloader`模块的职责**,详细划分可以包含以下三部分,**采样划分、补零对齐、框架匹配**\n", + "\n", + "    第一,确定`batch`大小,确定采样方式,划分后通过迭代器即可得到`batch`序列\n", + "\n", + "    第二,对于序列处理,这也是`fastNLP`主要针对的,将同个`batch`内的数据对齐\n", + "\n", + "    第三,**`batch`内数据格式要匹配框架**,**但`batch`结构需保持一致**,**参数匹配机制**\n", + "\n", + "  对此,`fastNLP 0.8`给出了 **`TorchDataLoader`、`PaddleDataLoader`和`JittorDataLoader`**\n", + "\n", + "    分别针对并匹配不同框架,但彼此之间参数名、属性、方法仍然类似,前两者大致如下表所示\n", + "\n", + "|
名称 | 参数 | 属性 | 功能 | 内容
|\n", + "|:--|:--:|:--:|:--|:--|\n", + "| **`dataset`** | √ | √ | 指定`dataloader`的数据内容 | |\n", + "| `batch_size` | √ | √ | 指定`dataloader`的`batch`大小 | 默认`16` |\n", + "| `shuffle` | √ | √ | 指定`dataloader`的数据是否打乱 | 默认`False` |\n", + "| `collate_fn` | √ | √ | 指定`dataloader`的`batch`打包方法 | 视框架而定 |\n", + "| `sampler` | √ | √ | ? | 默认`None` |\n", + "| `batch_sampler` | √ | √ | ? | 默认`None` |\n", + "| `drop_last` | √ | √ | 指定`dataloader`划分`batch`时是否丢弃剩余的 | 默认`False` |\n", + "| `cur_batch_indices` | | √ | 记录`dataloader`当前遍历批量序号 | |\n", + "| `num_workers` | √ | √ | 指定`dataloader`开启子进程数量 | 默认`0` |\n", + "| `worker_init_fn` | √ | √ | 指定`dataloader`子进程初始方法 | 默认`None` |\n", + "| `generator` | √ | √ | 指定`dataloader`子进程随机种子 | 默认`None` |\n", + "| `prefetch_factor` | | √ | 指定为每个`worker`装载的`sampler`数量 | 默认`2` |" ] }, { "cell_type": "markdown", - "id": "eb8fb51c", + "id": "60a8a224", "metadata": {}, "source": [ - "### 1.2 dataloader 的基本使用\n", + "  论及`dataloader`的函数,其中,`get_batch_indices`用来获取当前遍历到的`batch`序号,其他函数\n", "\n", - "在`fastNLP 0.8`中,在数据加载模块`DataLoader`之前," + "    包括`set_ignore`、`set_pad`和`databundle`类似,请参考`tutorial-2`,此处不做更多介绍\n", + "\n", + "    以下是`tutorial-2`中已经介绍过的数据预处理流程,接下来是对相关数据进行`dataloader`处理" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "aca72b49", "metadata": { "pycharm": { "name": "#%%\n" } }, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/4 [00:00\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Processing: 0%| | 0/6000 [00:00[17:45:59] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", - "\n" - ], - "text/plain": [ - "\u001b[2;36m[17:45:59]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=147745;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=708408;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Output()"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.575,\n",
-       "  \"total#acc\": 160.0,\n",
-       "  \"correct#acc\": 92.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.575\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m92.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.75625,\n",
-       "  \"total#acc\": 160.0,\n",
-       "  \"correct#acc\": 121.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.75625\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m121.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.78125,\n",
-       "  \"total#acc\": 160.0,\n",
-       "  \"correct#acc\": 125.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.78125\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m125.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.8,\n",
-       "  \"total#acc\": 160.0,\n",
-       "  \"correct#acc\": 128.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.79375,\n",
-       "  \"total#acc\": 160.0,\n",
-       "  \"correct#acc\": 127.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.79375\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m127.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.80625,\n",
-       "  \"total#acc\": 160.0,\n",
-       "  \"correct#acc\": 129.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m129.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.81875,\n",
-       "  \"total#acc\": 160.0,\n",
-       "  \"correct#acc\": 131.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.825,\n",
-       "  \"total#acc\": 160.0,\n",
-       "  \"correct#acc\": 132.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.825\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m132.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.81875,\n",
-       "  \"total#acc\": 160.0,\n",
-       "  \"correct#acc\": 131.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.81875,\n",
-       "  \"total#acc\": 160.0,\n",
-       "  \"correct#acc\": 131.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ - "trainer.run(num_eval_batch_per_dl=10)" + "trainer.run()" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "id": "f47a6a35", "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/plain": [
-       "{'acc#acc': 0.79, 'total#acc': 900.0, 'correct#acc': 711.0}"
-      ]
-     },
-     "execution_count": 14,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
     "trainer.evaluator.run()"
    ]
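上方两个单元格分别完成训练与整体评测:`trainer.run()` 负责训练并在评测时打印指标,`trainer.evaluator.run()` 则返回指标字典(键名与上面被删除的输出一致)。以下为一个最小示意,假设 `trainer` 已按本教程构建、且 `metrics={'acc': Accuracy()}`:

```python
# 最小示意:假设 trainer 已按本教程初始化,且 metrics={'acc': Accuracy()}
trainer.run()                      # 训练;每轮评测的结果会随进度条打印

results = trainer.evaluator.run()  # 返回形如 {'acc#acc': 0.79, 'total#acc': 900.0, 'correct#acc': 711.0} 的字典
print(results['acc#acc'])          # 键名由 metric 返回的字段名与 metric 名以 '#' 拼接而成
```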
@@ -1082,21 +451,10 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": null,
    "id": "c1a2e2ca",
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "342"
-      ]
-     },
-     "execution_count": 15,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
     "import gc\n",
     "\n",
@@ -1144,32 +502,10 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": null,
    "id": "03e66686",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Reusing dataset conll2003 (/remote-home/xrliu/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/63f4ebd1bcb7148b1644497336fd74643d4ce70123334431a3c053b7ee4e96ee)\n"
-     ]
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "3ec9e0ce9a054339a2453420c2c9f28b",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "  0%|          | 0/3 [00:00[17:49:16] INFO     Running evaluator sanity check for 2 batches.              trainer.py:592\n",
-       "\n"
-      ],
-      "text/plain": [
-       "\u001b[2;36m[17:49:16]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO    \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches.              \u001b]8;id=766109;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=787419;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Output()"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Output()"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"f#F1\": 0.220374,\n",
-       "  \"pre#F1\": 0.25,\n",
-       "  \"rec#F1\": 0.197026\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.220374\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.25\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.197026\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"f#F1\": 0.442857,\n",
-       "  \"pre#F1\": 0.426117,\n",
-       "  \"rec#F1\": 0.460967\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.442857\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.426117\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.460967\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"f#F1\": 0.572954,\n",
-       "  \"pre#F1\": 0.549488,\n",
-       "  \"rec#F1\": 0.598513\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.572954\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.549488\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.598513\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"f#F1\": 0.665399,\n",
-       "  \"pre#F1\": 0.680934,\n",
-       "  \"rec#F1\": 0.650558\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.665399\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.680934\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.650558\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"f#F1\": 0.734694,\n",
-       "  \"pre#F1\": 0.733333,\n",
-       "  \"rec#F1\": 0.736059\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.734694\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.733333\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.736059\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"f#F1\": 0.742647,\n",
-       "  \"pre#F1\": 0.734545,\n",
-       "  \"rec#F1\": 0.750929\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.742647\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.734545\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.750929\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"f#F1\": 0.773585,\n",
-       "  \"pre#F1\": 0.785441,\n",
-       "  \"rec#F1\": 0.762082\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.773585\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.785441\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.762082\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"f#F1\": 0.770115,\n",
-       "  \"pre#F1\": 0.794466,\n",
-       "  \"rec#F1\": 0.747212\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.770115\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.794466\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.747212\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"f#F1\": 0.7603,\n",
-       "  \"pre#F1\": 0.766038,\n",
-       "  \"rec#F1\": 0.754647\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.7603\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.766038\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.754647\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"f#F1\": 0.743682,\n",
-       "  \"pre#F1\": 0.722807,\n",
-       "  \"rec#F1\": 0.765799\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.743682\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.722807\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.765799\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "trainer.run(num_eval_batch_per_dl=10)" ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "id": "37871d6b", "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/plain": [
-       "{'f#F1': 0.75283, 'pre#F1': 0.727438, 'rec#F1': 0.780059}"
-      ]
-     },
-     "execution_count": 22,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
     "trainer.evaluator.run()"
    ]
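序列标注部分同样遵循先 `run` 再 `evaluator.run` 的流程:`trainer.run(num_eval_batch_per_dl=10)` 表示每次评测只取每个 dataloader 的 10 个批量(该参数在 tutorial-5 中有表格说明),返回字典中的 `f#F1`、`pre#F1`、`rec#F1` 分别对应 F 值、精确率与召回率。以下为一个读取结果的示意,假设 `trainer` 为上文 conll2003 任务已构建好的实例:

```python
# 示意:假设 trainer 为上文 conll2003 NER 任务已构建好的实例
trainer.run(num_eval_batch_per_dl=10)   # 每次评测只使用每个 dataloader 的 10 个批量

results = trainer.evaluator.run()       # 例如 {'f#F1': 0.75283, 'pre#F1': 0.727438, 'rec#F1': 0.780059}
f1, precision, recall = results['f#F1'], results['pre#F1'], results['rec#F1']
```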
diff --git a/tutorials/fastnlp_tutorial_5.ipynb b/tutorials/fastnlp_tutorial_5.ipynb
index 0669a60a..63410113 100644
--- a/tutorials/fastnlp_tutorial_5.ipynb
+++ b/tutorials/fastnlp_tutorial_5.ipynb
@@ -312,6 +312,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "import sys\n",
+    "sys.path.append('..')\n",
+    "\n",
     "from fastNLP import Metric\n",
     "\n",
     "class MyMetric(Metric):\n",
@@ -333,33 +336,6 @@
     "        return {'prefix': acc}"
    ]
   },
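`MyMetric.get_metric` 返回的键为 `'prefix'`,教程后文会说明,它将与 `metrics` 字典中的键(如 `'suffix'`)拼接成 `'prefix#suffix'` 显示。下面给出将自定义 metric 传入 `Trainer` 的最小示意,其中 `model`、`optimizers` 与两个 dataloader 均假设已按教程后文准备好:

```python
# 最小示意:metrics 字典的键 'suffix' 会与 get_metric 返回的键 'prefix' 拼接为 'prefix#suffix'
# model、optimizers、train_dataloader、evaluate_dataloader 假设均已按教程后文构建好
from fastNLP import Trainer

trainer = Trainer(
    model=model,
    driver='torch',
    device=0,                      # 卡位写法仅作示意,也可以是 'cpu'、'cuda' 等
    n_epochs=10,
    optimizers=optimizers,
    train_dataloader=train_dataloader,
    evaluate_dataloaders=evaluate_dataloader,
    metrics={'suffix': MyMetric()}
)
```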
-  {
-   "cell_type": "markdown",
-   "id": "af3f8c63",
-   "metadata": {},
-   "source": [
-    "  模型使用方面,此处仍然使用`tutorial-4`中介绍过的预定义`CNNText`模型,实现`SST-2`二分类"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "2fd210c5",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import sys\n",
-    "sys.path.append('..')\n",
-    "\n",
-    "from fastNLP.models.torch import CNNText\n",
-    "\n",
-    "model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n",
-    "\n",
-    "from torch.optim import AdamW\n",
-    "\n",
-    "optimizers = AdamW(params=model.parameters(), lr=5e-4)"
-   ]
-  },
   {
    "cell_type": "markdown",
    "id": "0155f447",
@@ -389,9 +365,9 @@
    "id": "e9d81760",
    "metadata": {},
    "source": [
-    "接着是数据预处理,需要注意的是,由于`MyMetric`的`update`函数中,输入参数名称为`pred`和`true`\n",
+    "  在数据预处理中,需要注意的是,由于`MyMetric`的`update`函数中,输入参数名称为`pred`和`true`\n",
     "\n",
-    "  对应地,需要将数据集中表示预测目标的字段,调整为`true`(预定义的`metric`,应调整为`target`"
+    "    对应地,需要将数据集中表示预测目标的字段,调整为`true`(预定义的`metric`,应调整为`target`"
    ]
   },
   {
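上面提到,需要把数据集中表示预测目标的字段改名为 `true`,以匹配 `MyMetric.update` 的参数名;具体写法可以参照教程其他部分使用的 `apply_more`。以下为一个示意,字段名假设与教程中的 `SST-2` 数据一致:

```python
# 示意:将预测目标字段命名为 'true',以匹配 MyMetric.update(pred, true) 的参数名
# (若使用预定义 metric,则应命名为 'target');字段名假设与教程中的 SST-2 数据一致
dataset.apply_more(lambda ins: {'words': ins['sentence'].lower().split(), 'true': ins['label']},
                   progress_bar='tqdm')
```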
@@ -429,14 +405,136 @@
     "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "af3f8c63",
+   "metadata": {},
+   "source": [
+    "  模型使用方面,此处仍然使用`tutorial-4`中介绍过的预定义`CNNText`模型,实现`SST-2`二分类"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2fd210c5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from fastNLP.models.torch import CNNText\n",
+    "\n",
+    "model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n",
+    "\n",
+    "from torch.optim import AdamW\n",
+    "\n",
+    "optimizers = AdamW(params=model.parameters(), lr=5e-4)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6e723b87",
+   "metadata": {},
+   "source": [
+    "## 3. fastNLP 中 trainer 的补充介绍\n",
+    "\n",
+    "### 3.1  trainer 的内部结构\n",
+    "\n",
+    "在`tutorial-0`中,我们已经介绍了`trainer`的基本使用,从`tutorial-1`到`tutorial-4`,我们也已经展示了\n",
+    "\n",
+    "  很多`trainer`的使用案例,这里通过表格,相对完整地介绍`trainer`模块的属性和初始化参数(标粗为必选参数\n",
+    "\n",
+    "| 
名称
|
参数
|
属性
|
功能
|
内容
|\n", + "|:--|:--:|:--:|:--|:--|\n", + "| **`model`** | √ | √ | 指定`trainer`控制的模型 | 视框架而定,如`torch.nn.Module` |\n", + "| **`driver`** | √ | | 指定`trainer`驱动的框架 | 包括`'torch'`、`'paddle'`、`'jittor'` |\n", + "| | | √ | 记录`trainer`驱动的框架 | `Driver`类型,在初始化阶段生成 |\n", + "| `device` | √ | | 指定`trainer`运行的卡位 | 例如`'cpu'`、`'cuda'`、`0`、`[0, 1]`等 |\n", + "| | | √ | 记录`trainer`运行的卡位 | `Device`类型,在初始化阶段生成 |\n", + "| `n_epochs` | √ | - | 指定`trainer`迭代的轮数 | 默认`20`,记录在`driver.n_epochs`中 |\n", + "| **`optimizers`** | √ | √ | 指定`trainer`优化的方法 | 视框架而定,如`torch.optim.Adam` |\n", + "| `metrics` | √ | √ | 指定`trainer`评测的方法 | 字典类型,如`{'acc': Metric()}` |\n", + "| `evaluator` | | √ | 内置的`trainer`评测模块 | `Evaluator`类型,在初始化阶段生成 |\n", + "| `input_mapping` | √ | √ | 调整`dataloader`的参数不匹配 | 函数类型,输出字典匹配`forward`输入参数 |\n", + "| `output_mapping` | √ | √ | 调整`forward`输出的参数不匹配 | 函数类型,输出字典匹配`xx_step`输入参数 |\n", + "| **`train_dataloader`** | √ | √ | 指定`trainer`训练的数据 | `DataLoader`类型,生成视框架而定 |\n", + "| `evaluate_dataloaders` | √ | √ | 指定`trainer`评测的数据 | `DataLoader`类型,生成视框架而定 |\n", + "| `train_fn` | √ | √ | 指定`trainer`获取某个批次的损失值 | 函数类型,默认为`model.train_step` |\n", + "| `evaluate_fn` | √ | √ | 指定`trainer`获取某个批次的评估量 | 函数类型,默认为`model.evaluate_step` |\n", + "| `batch_step_fn` | √ | √ | 指定`trainer`训练时前向传输一个批次的方式 | 函数类型,默认为`TrainBatchLoop.batch_step_fn` |\n", + "| `evaluate_batch_step_fn` | √ | √ | 指定`trainer`评测时前向传输一个批次的方式 | 函数类型,默认为`EvaluateBatchLoop.batch_step_fn` |\n", + "| `accumulation_steps` | √ | √ | 指定`trainer`训练时反向传播的频率 | 默认为`1`,即每个批次都反向传播 |\n", + "| `evaluate_every` | √ | √ | 指定`evaluator`评测时计算的频率 | 默认`-1`表示每个循环一次,相反`1`表示每个批次一次 |\n", + "| `progress_bar` | √ | √ | 指定`trainer`训练和评测时的进度条样式 | 包括`'auto'`、`'tqdm'`、`'raw'`、`'rich'` |\n", + "| `callbacks` | √ | | 指定`trainer`训练时需要触发的函数 | `Callback`列表类型,详见`tutorial-7` |\n", + "| `callback_manager` | | √ | 记录与管理`callbacks`相关内容 | `CallbackManager`类型,详见`tutorial-7` |\n", + "| `monitor` | √ | √ | 辅助部分的`callbacks`相关内容 | 字符串/函数类型,详见`tutorial-7` |\n", + "| `marker` | √ | √ | 标记`trainer`实例,辅助`callbacks`相关内容 | 字符串型,详见`tutorial-7` |\n", + "| `trainer_state` | | √ | 记录`trainer`状态,辅助`callbacks`相关内容 | `TrainerState`类型,详见`tutorial-7` |\n", + "| `state` | | √ | 记录`trainer`状态,辅助`callbacks`相关内容 | `State`类型,详见`tutorial-7` |\n", + "| `fp16` | √ | √ | 指定`trainer`是否进行混合精度训练 | 布尔类型,默认`False` |" + ] + }, + { + "cell_type": "markdown", + "id": "2fc8b9f3", + "metadata": {}, + "source": [ + "  以及`trainer`模块内部的基础方法,相关进阶操作,如“`on`系列函数”、`callback`控制,请参考后续的`tutorial-7`\n", + "\n", + "|
名称 | 功能 | 主要参数
|\n", + "|:--|:--|:--|\n", + "| `run` | 控制`trainer`中模型的训练和评测 | 详见后文 |\n", + "| `train_step` | 实现`trainer`训练中一个批数据的前向传播过程 | 输入`batch` |\n", + "| `backward` | 实现`trainer`训练中一次损失的反向传播过程 | 输入`output` |\n", + "| `zero_grad` | 实现`trainer`训练中`optimizers`的梯度置零 | 无输入 |\n", + "| `step` | 实现`trainer`训练中`optimizers`的参数更新 | 无输入 |\n", + "| `epoch_evaluate` | 实现`trainer`训练中每个循环的评测,实际是否执行取决于评测频率 | 无输入 |\n", + "| `step_evaluate` | 实现`trainer`训练中每个批次的评测,实际是否执行取决于评测频率 | 无输入 |\n", + "| `save_model` | 保存`trainer`中的模型参数/状态字典至`fastnlp_model.pkl.tar` | `folder`指明路径,`only_state_dict`指明是否只保存状态字典,默认`False` |\n", + "| `load_model` | 加载`trainer`中的模型参数/状态字典自`fastnlp_model.pkl.tar` | `folder`指明路径,`only_state_dict`指明是否只加载状态字典,默认`True` |\n", + "| `save_checkpoint` |
保存`trainer`中模型参数/状态字典 以及 `callback`、`sampler`和`optimizer`的状态至`fastnlp_model/checkpoint.pkl.tar`
| `folder`指明路径,`only_state_dict`指明是否只保存状态字典,默认`True` |\n", + "| `load_checkpoint` |
加载`trainer`中模型参数/状态字典 以及 `callback`、`sampler`和`optimizer`的状态自`fastnlp_model/checkpoint.pkl.tar`
|
`folder`指明路径,`only_state_dict`指明是否只加载状态字典,默认`True`;`resume_training`指明是否精确恢复到上次训练停止的批量,默认`True`
|\n", + "| `add_callback_fn` | 在`trainer`初始化后添加`callback`函数 | 输入`event`指明回调时机,`fn`指明回调函数 |\n", + "| `on` | 函数修饰器,将一个函数转变为`callback`函数 | 详见`tutorial-7` |\n", + "\n", + "" + ] + }, { "cell_type": "markdown", "id": "1e21df35", "metadata": {}, "source": [ - "然后就是初始化`trainer`实例,其中`metrics`变量输入的键值对,字串`'suffix'`和之前定义的字串`'prefix'`\n", + "紧接着,初始化`trainer`实例,继续完成`SST-2`分类,其中`metrics`输入的键值对,字串`'suffix'`和之前定义的\n", "\n", - "  将拼接在一起显示到`trainer`的`progress bar`中,故完整的输出形式为`{'prefix#suffix': float}`" + "  字串`'prefix'`将拼接在一起显示到`progress bar`中,故完整的输出形式为`{'prefix#suffix': float}`" ] }, { @@ -462,61 +560,43 @@ }, { "cell_type": "markdown", - "id": "6e723b87", - "metadata": {}, + "id": "b1b2e8b7", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "source": [ - "## 3. fastNLP 中 trainer 的补充介绍\n", - "\n", - "### 3.1 trainer 的内部结构\n", - "\n", - "在`tutorial-0`中,我们已经介绍了`trainer`的基本使用,从`tutorial-1`到`tutorial-4`,我们也已经\n", - "\n", - "  展示了很多关于`trainer`的使用案例,以下我们先补充介绍训练模块`trainer`的一些内部结构\n", - "\n", - "\n", - "\n", - "'accumulation_steps', 'add_callback_fn', 'backward', 'batch_idx_in_epoch', 'batch_step_fn',\n", - "'callback_manager', 'check_batch_step_fn', 'cur_epoch_idx', 'data_device', 'dataloader',\n", - "'device', 'driver', 'driver_name', 'epoch_evaluate', 'evaluate_batch_step_fn', 'evaluate_dataloaders',\n", - "'evaluate_every', 'evaluate_fn', 'evaluator', 'extract_loss_from_outputs', 'fp16',\n", - "'get_no_sync_context', 'global_forward_batches', 'has_checked_train_batch_loop',\n", - "'input_mapping', 'kwargs', 'larger_better', 'load_checkpoint', 'load_model', 'marker',\n", - "'metrics', 'model', 'model_device', 'monitor', 'move_data_to_device', 'n_epochs', 'num_batches_per_epoch',\n", - "'on', 'on_after_backward', 'on_after_optimizers_step', 'on_after_trainer_initialized',\n", - "'on_after_zero_grad', 'on_before_backward', 'on_before_optimizers_step', 'on_before_zero_grad',\n", - "'on_evaluate_begin', 'on_evaluate_end', 'on_exception', 'on_fetch_data_begin', 'on_fetch_data_end',\n", - "'on_load_checkpoint', 'on_load_model', 'on_sanity_check_begin', 'on_sanity_check_end',\n", - "'on_save_checkpoint', 'on_save_model', 'on_train_batch_begin', 'on_train_batch_end',\n", - "'on_train_begin', 'on_train_end', 'on_train_epoch_begin', 'on_train_epoch_end',\n", - "'optimizers', 'output_mapping', 'progress_bar', 'run', 'run_evaluate',\n", - "'save_checkpoint', 'save_model', 'start_batch_idx_in_epoch', 'state',\n", - "'step', 'step_evaluate', 'total_batches', 'train_batch_loop', 'train_dataloader', 'train_fn', 'train_step',\n", - "'trainer_state', 'zero_grad'\n", + "最后就是`run`函数的使用,关于其参数,这里也以表格形式列出,由此就解答了`num_eval_batch_per_dl=10`的含义\n", "\n", - "  run(num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, catch_KeyboardInterrupt=None)" + "|
名称 | 功能 | 默认值
|\n", + "|:--|:--|:--|\n", + "| `num_train_batch_per_epoch` | 指定`trainer`训练时,每个循环计算批量数目 | 整数类型,默认`-1`,表示训练时,每个循环计算所有批量 |\n", + "| `num_eval_batch_per_dl` | 指定`trainer`评测时,每个循环计算批量数目 | 整数类型,默认`-1`,表示评测时,每个循环计算所有批量 |\n", + "| `num_eval_sanity_batch` | 指定`trainer`训练开始前,试探性评测批量数目 | 整数类型,默认`2`,表示训练开始前评估两个批量 |\n", + "| `resume_from` | 指定`trainer`恢复状态的路径,需要是文件夹 | 字符串型,默认`None`,使用可参考`CheckpointCallback` |\n", + "| `resume_training` | 指定`trainer`恢复状态的程度 | 布尔类型,默认`True`恢复所有状态,`False`仅恢复`model`和`optimizers`状态 |" ] }, { "cell_type": "code", "execution_count": null, - "id": "c348864c", + "id": "43be274f", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], - "source": [] + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] }, { "cell_type": "code", "execution_count": null, - "id": "43be274f", - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "id": "f1abfa0a", + "metadata": {}, "outputs": [], "source": [] } diff --git a/tutorials/fastnlp_tutorial_6.ipynb b/tutorials/fastnlp_tutorial_6.ipynb index 2052189e..552f73d9 100644 --- a/tutorials/fastnlp_tutorial_6.ipynb +++ b/tutorials/fastnlp_tutorial_6.ipynb @@ -19,20 +19,312 @@ "\n", "    2.2   使用 jittor 搭建并训练模型\n", "\n", - "  3   fastNLP 实现 paddle 与 pytorch 互转\n", + "  3   fastNLP 实现 paddle 与 pytorch 互转" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "08752c5a", + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "\n", + "sst2data = load_dataset('glue', 'sst2')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7e8cc210", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append('..')\n", + "\n", + "from fastNLP import DataSet\n", + "\n", + "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n", "\n", - "    3.1   \n", + "dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'target': ins['label']}, \n", + " progress_bar=\"tqdm\")\n", + "dataset.delete_field('sentence')\n", + "dataset.delete_field('label')\n", + "dataset.delete_field('idx')\n", "\n", - "    3.2   " + "from fastNLP import Vocabulary\n", + "\n", + "vocab = Vocabulary()\n", + "vocab.from_dataset(dataset, field_name='words')\n", + "vocab.index_dataset(dataset, field_name='words')\n", + "\n", + "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)\n", + "print(type(train_dataset), isinstance(train_dataset, DataSet))\n", + "\n", + "from fastNLP.io import DataBundle\n", + "\n", + "data_bundle = DataBundle(datasets={'train': train_dataset, 'dev': evaluate_dataset})" + ] + }, + { + "cell_type": "markdown", + "id": "57a3272f", + "metadata": {}, + "source": [ + "## 1. 
fastNLP 结合 paddle 训练模型\n", + "\n", + "```python\n", + "import paddle\n", + "\n", + "lstm = paddle.nn.LSTM(16, 32, 2)\n", + "\n", + "x = paddle.randn((4, 23, 16))\n", + "h = paddle.randn((2, 4, 32))\n", + "c = paddle.randn((2, 4, 32))\n", + "\n", + "y, (h, c) = lstm(x, (h, c))\n", + "\n", + "print(y.shape) # [4, 23, 32]\n", + "print(h.shape) # [2, 4, 32]\n", + "print(c.shape) # [2, 4, 32]\n", + "```" ] }, { "cell_type": "code", "execution_count": null, - "id": "08752c5a", + "id": "e31b3198", + "metadata": {}, + "outputs": [], + "source": [ + "import paddle\n", + "import paddle.nn as nn\n", + "\n", + "\n", + "class ClsByPaddle(nn.Layer):\n", + " def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n", + " nn.Layer.__init__(self)\n", + " self.hidden_dim = hidden_dim\n", + "\n", + " self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)\n", + " # self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, \n", + " # num_layers=num_layers, direction='bidirectional', dropout=dropout)\n", + " self.mlp = nn.Sequential(('linear_1', nn.Linear(hidden_dim * 2, hidden_dim * 2)),\n", + " ('activate', nn.ReLU()),\n", + " ('linear_2', nn.Linear(hidden_dim * 2, output_dim)))\n", + " \n", + " self.loss_fn = nn.CrossEntropyLoss()\n", + "\n", + " def forward(self, words):\n", + " output = self.embedding(words)\n", + " # output, (hidden, cell) = self.lstm(output)\n", + " hidden = paddle.randn((2, words.shape[0], self.hidden_dim))\n", + " output = self.mlp(paddle.concat((hidden[-1], hidden[-2]), axis=1))\n", + " return output\n", + " \n", + " def train_step(self, words, target):\n", + " pred = self(words)\n", + " return {\"loss\": self.loss_fn(pred, target)}\n", + "\n", + " def evaluate_step(self, words, target):\n", + " pred = self(words)\n", + " pred = paddle.max(pred, axis=-1)[1]\n", + " return {\"pred\": pred, \"target\": target}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c63b030f", + "metadata": {}, + "outputs": [], + "source": [ + "model = ClsByPaddle(vocab_size=len(vocab), embedding_dim=100, output_dim=2)\n", + "\n", + "model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2997c0aa", + "metadata": {}, + "outputs": [], + "source": [ + "from paddle.optimizer import AdamW\n", + "\n", + "optimizers = AdamW(parameters=model.parameters(), learning_rate=1e-2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ead35fb8", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import prepare_paddle_dataloader\n", + "\n", + "# train_dataloader = prepare_paddle_dataloader(train_dataset, batch_size=16, shuffle=True)\n", + "# evaluate_dataloader = prepare_paddle_dataloader(evaluate_dataset, batch_size=16)\n", + "\n", + "dl_bundle = prepare_paddle_dataloader(data_bundle, batch_size=16, shuffle=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25e8da83", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "from fastNLP import Trainer, Accuracy\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver='paddle',\n", + " device='gpu', # 'cpu', 'gpu', 'gpu:x'\n", + " n_epochs=10,\n", + " optimizers=optimizers,\n", + " train_dataloader=dl_bundle['train'], # train_dataloader,\n", + " evaluate_dataloaders=dl_bundle['dev'], # evaluate_dataloader,\n", + " metrics={'acc': Accuracy()}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d63c5d74", + "metadata": {}, + "outputs": 
[], + "source": [ + "trainer.run(num_eval_batch_per_dl=10) # 然后卡了?" + ] + }, + { + "cell_type": "markdown", + "id": "cb9a0b3c", + "metadata": {}, + "source": [ + "## 2. fastNLP 结合 jittor 训练模型" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c600191d", + "metadata": {}, + "outputs": [], + "source": [ + "import jittor\n", + "import jittor.nn as nn\n", + "\n", + "from jittor import Module\n", + "\n", + "\n", + "class ClsByJittor(Module):\n", + " def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n", + " Module.__init__(self)\n", + " self.hidden_dim = hidden_dim\n", + "\n", + " self.embedding = nn.Embedding(num=vocab_size, dim=embedding_dim)\n", + " self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, \n", + " num_layers=num_layers, bidirectional=True, dropout=dropout)\n", + " self.mlp = nn.Sequential([nn.Linear(hidden_dim * 2, hidden_dim * 2),\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_dim * 2, output_dim)])\n", + "\n", + " self.loss_fn = nn.BCELoss()\n", + "\n", + " def execute(self, words):\n", + " output = self.embedding(words)\n", + " output, (hidden, cell) = self.lstm(output)\n", + " # hidden = jittor.randn((2, words.shape[0], self.hidden_dim))\n", + " output = self.mlp(jittor.concat((hidden[-1], hidden[-2]), axis=1))\n", + " return output\n", + " \n", + " def train_step(self, words, target):\n", + " pred = self(words)\n", + " return {\"loss\": self.loss_fn(pred, target)}\n", + "\n", + " def evaluate_step(self, words, target):\n", + " pred = self(words)\n", + " pred = jittor.max(pred, axis=-1)[1]\n", + " return {\"pred\": pred, \"target\": target}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a94ed8c4", + "metadata": {}, + "outputs": [], + "source": [ + "model = ClsByJittor(vocab_size=len(vocab), embedding_dim=100, output_dim=2)\n", + "\n", + "model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d15ebc1", + "metadata": {}, + "outputs": [], + "source": [ + "from jittor.optim import AdamW\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=1e-2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "95d8d09e", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import prepare_jittor_dataloader\n", + "\n", + "# train_dataloader = prepare_jittor_dataloader(train_dataset, batch_size=16, shuffle=True)\n", + "# evaluate_dataloader = prepare_jittor_dataloader(evaluate_dataset, batch_size=16)\n", + "\n", + "dl_bundle = prepare_jittor_dataloader(data_bundle, batch_size=16, shuffle=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "917eab81", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import Trainer, Accuracy\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver='jittor',\n", + " device='gpu', # 'cpu', 'gpu', 'cuda'\n", + " n_epochs=10,\n", + " optimizers=optimizers,\n", + " train_dataloader=dl_bundle['train'], # train_dataloader,\n", + " evaluate_dataloaders=dl_bundle['dev'], # evaluate_dataloader,\n", + " metrics={'acc': Accuracy()}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f7c4ac5a", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] } ], "metadata": {