@@ -15,15 +15,15 @@ | |||||
"\n", | "\n", | ||||
"    1.3   trainer 内部初始化 evaluater\n", | "    1.3   trainer 内部初始化 evaluater\n", | ||||
"\n", | "\n", | ||||
"  2   使用 fastNLP 0.8 搭建 argmax 模型\n", | |||||
"  2   使用 fastNLP 搭建 argmax 模型\n", | |||||
"\n", | "\n", | ||||
"    2.1   trainer_step 和 evaluator_step\n", | "    2.1   trainer_step 和 evaluator_step\n", | ||||
"\n", | "\n", | ||||
"    2.2   trainer 和 evaluator 的参数匹配\n", | "    2.2   trainer 和 evaluator 的参数匹配\n", | ||||
"\n", | "\n", | ||||
"    2.3   一个实际案例:argmax 模型\n", | |||||
"    2.3   示例:argmax 模型的搭建\n", | |||||
"\n", | "\n", | ||||
"  3   使用 fastNLP 0.8 训练 argmax 模型\n", | |||||
"  3   使用 fastNLP 训练 argmax 模型\n", | |||||
" \n", | " \n", | ||||
"    3.1   trainer 外部初始化的 evaluator\n", | "    3.1   trainer 外部初始化的 evaluator\n", | ||||
"\n", | "\n", | ||||
@@ -248,7 +248,7 @@ | |||||
"id": "f62b7bb1", | "id": "f62b7bb1", | ||||
"metadata": {}, | "metadata": {}, | ||||
"source": [ | "source": [ | ||||
"### 2.3 一个实际案例:argmax 模型\n", | |||||
"### 2.3 示例:argmax 模型的搭建\n", | |||||
"\n", | "\n", | ||||
"下文将通过训练`argmax`模型,简单介绍如何`Trainer`模块的使用方式\n", | "下文将通过训练`argmax`模型,简单介绍如何`Trainer`模块的使用方式\n", | ||||
"\n", | "\n", | ||||
@@ -271,7 +271,7 @@ | |||||
"\n", | "\n", | ||||
"class ArgMaxModel(nn.Module):\n", | "class ArgMaxModel(nn.Module):\n", | ||||
" def __init__(self, num_labels, feature_dimension):\n", | " def __init__(self, num_labels, feature_dimension):\n", | ||||
" super(ArgMaxModel, self).__init__()\n", | |||||
" nn.Module.__init__(self)\n", | |||||
" self.num_labels = num_labels\n", | " self.num_labels = num_labels\n", | ||||
"\n", | "\n", | ||||
" self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10)\n", | " self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10)\n", | ||||
@@ -434,7 +434,7 @@ | |||||
"\n", | "\n", | ||||
"  通过`progress_bar`设定进度条格式,默认为`\"auto\"`,此外还有`\"rich\"`、`\"raw\"`和`None`\n", | "  通过`progress_bar`设定进度条格式,默认为`\"auto\"`,此外还有`\"rich\"`、`\"raw\"`和`None`\n", | ||||
"\n", | "\n", | ||||
"    但对于`\"auto\"`和`\"rich\"`格式,在notebook中,进度条在训练结束后会被丢弃\n", | |||||
"    但对于`\"auto\"`和`\"rich\"`格式,在`jupyter`中,进度条会在训练结束后会被丢弃\n", | |||||
"\n", | "\n", | ||||
"  通过`n_epochs`设定优化迭代轮数,默认为20;全部`Trainer`的全部变量与函数可以通过`dir(trainer)`查询" | "  通过`n_epochs`设定优化迭代轮数,默认为20;全部`Trainer`的全部变量与函数可以通过`dir(trainer)`查询" | ||||
] | ] | ||||
@@ -489,7 +489,7 @@ | |||||
"\n", | "\n", | ||||
"  其中,可以通过参数`num_train_batch_per_epoch`决定每个`epoch`运行多少个`batch`后停止,默认全部\n", | "  其中,可以通过参数`num_train_batch_per_epoch`决定每个`epoch`运行多少个`batch`后停止,默认全部\n", | ||||
"\n", | "\n", | ||||
"  此外,可以通过`inspect.getfullargspec(trainer.run)`查询`run`函数的全部参数列表" | |||||
"  `run`函数完成后在`jupyter`中没有输出保留,此外,通过`help(trainer.run)`可以查询`run`函数的详细内容" | |||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
@@ -590,7 +590,7 @@ | |||||
"\n", | "\n", | ||||
"  其中,可以通过参数`num_eval_batch_per_dl`决定每个`evaluate_dataloader`运行多少个`batch`停止,默认全部\n", | "  其中,可以通过参数`num_eval_batch_per_dl`决定每个`evaluate_dataloader`运行多少个`batch`停止,默认全部\n", | ||||
"\n", | "\n", | ||||
"  最终,输出形如`{'acc#acc': acc}`的字典,在notebook中,进度条在评测结束后会被丢弃" | |||||
"  最终,输出形如`{'acc#acc': acc}`的字典,在`jupyter`中,进度条会在评测结束后会被丢弃" | |||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
@@ -603,6 +603,20 @@ | |||||
} | } | ||||
}, | }, | ||||
"outputs": [ | "outputs": [ | ||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"Output()" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | { | ||||
"data": { | "data": { | ||||
"text/html": [ | "text/html": [ | ||||
@@ -616,11 +630,11 @@ | |||||
{ | { | ||||
"data": { | "data": { | ||||
"text/html": [ | "text/html": [ | ||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'acc#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.37</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'total#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'correct#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">37.0</span><span style=\"font-weight: bold\">}</span>\n", | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'acc#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.31</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'total#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'correct#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">31.0</span><span style=\"font-weight: bold\">}</span>\n", | |||||
"</pre>\n" | "</pre>\n" | ||||
], | ], | ||||
"text/plain": [ | "text/plain": [ | ||||
"\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.37\u001b[0m, \u001b[32m'total#acc'\u001b[0m: \u001b[1;36m100.0\u001b[0m, \u001b[32m'correct#acc'\u001b[0m: \u001b[1;36m37.0\u001b[0m\u001b[1m}\u001b[0m\n" | |||||
"\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.31\u001b[0m, \u001b[32m'total#acc'\u001b[0m: \u001b[1;36m100.0\u001b[0m, \u001b[32m'correct#acc'\u001b[0m: \u001b[1;36m31.0\u001b[0m\u001b[1m}\u001b[0m\n" | |||||
] | ] | ||||
}, | }, | ||||
"metadata": {}, | "metadata": {}, | ||||
@@ -629,7 +643,7 @@ | |||||
{ | { | ||||
"data": { | "data": { | ||||
"text/plain": [ | "text/plain": [ | ||||
"{'acc#acc': 0.37, 'total#acc': 100.0, 'correct#acc': 37.0}" | |||||
"{'acc#acc': 0.31, 'total#acc': 100.0, 'correct#acc': 31.0}" | |||||
] | ] | ||||
}, | }, | ||||
"execution_count": 9, | "execution_count": 9, | ||||
@@ -650,9 +664,9 @@ | |||||
"\n", | "\n", | ||||
"通过在初始化`trainer`实例时加入`evaluate_dataloaders`和`metrics`,可以实现在训练过程中进行评测\n", | "通过在初始化`trainer`实例时加入`evaluate_dataloaders`和`metrics`,可以实现在训练过程中进行评测\n", | ||||
"\n", | "\n", | ||||
"  通过`progress_bar`同时设定训练和评估进度条格式,在notebook中,在进度条训练结束后会被丢弃\n", | |||||
"  通过`progress_bar`同时设定训练和评估进度条格式,在`jupyter`中,在进度条训练结束后会被丢弃\n", | |||||
"\n", | "\n", | ||||
"  **通过`evaluate_every`设定评估频率**,可以为负数、正数或者函数:\n", | |||||
"  但是中间的评估结果仍会保留;**通过`evaluate_every`设定评估频率**,可以为负数、正数或者函数:\n", | |||||
"\n", | "\n", | ||||
"    **为负数时**,**表示每隔几个`epoch`评估一次**;**为正数时**,**则表示每隔几个`batch`评估一次**" | "    **为负数时**,**表示每隔几个`epoch`评估一次**;**为正数时**,**则表示每隔几个`batch`评估一次**" | ||||
] | ] | ||||
@@ -687,9 +701,9 @@ | |||||
"source": [ | "source": [ | ||||
"通过使用`Trainer`类的`run`函数,进行训练\n", | "通过使用`Trainer`类的`run`函数,进行训练\n", | ||||
"\n", | "\n", | ||||
"  还可以通过参数`num_eval_sanity_batch`决定每次训练前运行多少个`evaluate_batch`进行评测,默认为2\n", | |||||
"  还可以通过**参数`num_eval_sanity_batch`决定每次训练前运行多少个`evaluate_batch`进行评测**,**默认为`2`**\n", | |||||
"\n", | "\n", | ||||
"  之所以“先评测后训练”,是为了保证训练很长时间的数据,不会在评测阶段出问题,故作此试探性评测" | |||||
"  之所以“先评测后训练”,是为了保证训练很长时间的数据,不会在评测阶段出问题,故作此**试探性评测**" | |||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
@@ -702,6 +716,33 @@ | |||||
} | } | ||||
}, | }, | ||||
"outputs": [ | "outputs": [ | ||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[18:28:25] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Running evaluator sanity check for <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> batches. <a href=\"file://../fastNLP/core/controllers/trainer.py\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">trainer.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../fastNLP/core/controllers/trainer.py#592\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">592</span></a>\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\u001b[2;36m[18:28:25]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=549287;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=645362;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"Output()" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | { | ||||
"data": { | "data": { | ||||
"text/html": [ | "text/html": [ | ||||
@@ -712,6 +753,490 @@ | |||||
"metadata": {}, | "metadata": {}, | ||||
"output_type": "display_data" | "output_type": "display_data" | ||||
}, | }, | ||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"Output()" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.31</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">31.0</span>\n", | |||||
"<span style=\"font-weight: bold\">}</span>\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\u001b[1m{\u001b[0m\n", | |||||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.31\u001b[0m,\n", | |||||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", | |||||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m31.0\u001b[0m\n", | |||||
"\u001b[1m}\u001b[0m\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.33</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">33.0</span>\n", | |||||
"<span style=\"font-weight: bold\">}</span>\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\u001b[1m{\u001b[0m\n", | |||||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.33\u001b[0m,\n", | |||||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", | |||||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m33.0\u001b[0m\n", | |||||
"\u001b[1m}\u001b[0m\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.34</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">34.0</span>\n", | |||||
"<span style=\"font-weight: bold\">}</span>\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\u001b[1m{\u001b[0m\n", | |||||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.34\u001b[0m,\n", | |||||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", | |||||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m34.0\u001b[0m\n", | |||||
"\u001b[1m}\u001b[0m\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.36</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">36.0</span>\n", | |||||
"<span style=\"font-weight: bold\">}</span>\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\u001b[1m{\u001b[0m\n", | |||||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", | |||||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", | |||||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", | |||||
"\u001b[1m}\u001b[0m\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.36</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">36.0</span>\n", | |||||
"<span style=\"font-weight: bold\">}</span>\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\u001b[1m{\u001b[0m\n", | |||||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", | |||||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", | |||||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", | |||||
"\u001b[1m}\u001b[0m\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.36</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">36.0</span>\n", | |||||
"<span style=\"font-weight: bold\">}</span>\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\u001b[1m{\u001b[0m\n", | |||||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", | |||||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", | |||||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", | |||||
"\u001b[1m}\u001b[0m\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.36</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">36.0</span>\n", | |||||
"<span style=\"font-weight: bold\">}</span>\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\u001b[1m{\u001b[0m\n", | |||||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", | |||||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", | |||||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", | |||||
"\u001b[1m}\u001b[0m\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">8</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.36</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">36.0</span>\n", | |||||
"<span style=\"font-weight: bold\">}</span>\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\u001b[1m{\u001b[0m\n", | |||||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", | |||||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", | |||||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", | |||||
"\u001b[1m}\u001b[0m\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">9</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.37</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">37.0</span>\n", | |||||
"<span style=\"font-weight: bold\">}</span>\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\u001b[1m{\u001b[0m\n", | |||||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.37\u001b[0m,\n", | |||||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", | |||||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m37.0\u001b[0m\n", | |||||
"\u001b[1m}\u001b[0m\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">---------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.4</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">100.0</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">40.0</span>\n", | |||||
"<span style=\"font-weight: bold\">}</span>\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\u001b[1m{\u001b[0m\n", | |||||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.4\u001b[0m,\n", | |||||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", | |||||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m40.0\u001b[0m\n", | |||||
"\u001b[1m}\u001b[0m\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | { | ||||
"data": { | "data": { | ||||
"text/html": [ | "text/html": [ | ||||
@@ -746,6 +1271,20 @@ | |||||
"id": "c4e9c619", | "id": "c4e9c619", | ||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [ | "outputs": [ | ||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"Output()" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | { | ||||
"data": { | "data": { | ||||
"text/html": [ | "text/html": [ | ||||
@@ -759,7 +1298,7 @@ | |||||
{ | { | ||||
"data": { | "data": { | ||||
"text/plain": [ | "text/plain": [ | ||||
"{'acc#acc': 0.47, 'total#acc': 100.0, 'correct#acc': 47.0}" | |||||
"{'acc#acc': 0.4, 'total#acc': 100.0, 'correct#acc': 40.0}" | |||||
] | ] | ||||
}, | }, | ||||
"execution_count": 12, | "execution_count": 12, | ||||
@@ -773,9 +1312,222 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": null, | |||||
"execution_count": 13, | |||||
"id": "db784d5b", | "id": "db784d5b", | ||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [ | |||||
{ | |||||
"data": { | |||||
"text/plain": [ | |||||
"['__annotations__',\n", | |||||
" '__class__',\n", | |||||
" '__delattr__',\n", | |||||
" '__dict__',\n", | |||||
" '__dir__',\n", | |||||
" '__doc__',\n", | |||||
" '__eq__',\n", | |||||
" '__format__',\n", | |||||
" '__ge__',\n", | |||||
" '__getattribute__',\n", | |||||
" '__gt__',\n", | |||||
" '__hash__',\n", | |||||
" '__init__',\n", | |||||
" '__init_subclass__',\n", | |||||
" '__le__',\n", | |||||
" '__lt__',\n", | |||||
" '__module__',\n", | |||||
" '__ne__',\n", | |||||
" '__new__',\n", | |||||
" '__reduce__',\n", | |||||
" '__reduce_ex__',\n", | |||||
" '__repr__',\n", | |||||
" '__setattr__',\n", | |||||
" '__sizeof__',\n", | |||||
" '__str__',\n", | |||||
" '__subclasshook__',\n", | |||||
" '__weakref__',\n", | |||||
" '_check_callback_called_legality',\n", | |||||
" '_check_train_batch_loop_legality',\n", | |||||
" '_custom_callbacks',\n", | |||||
" '_driver',\n", | |||||
" '_evaluate_dataloaders',\n", | |||||
" '_fetch_matched_fn_callbacks',\n", | |||||
" '_set_num_eval_batch_per_dl',\n", | |||||
" '_train_batch_loop',\n", | |||||
" '_train_dataloader',\n", | |||||
" '_train_step',\n", | |||||
" '_train_step_signature_fn',\n", | |||||
" 'accumulation_steps',\n", | |||||
" 'add_callback_fn',\n", | |||||
" 'backward',\n", | |||||
" 'batch_idx_in_epoch',\n", | |||||
" 'batch_step_fn',\n", | |||||
" 'callback_manager',\n", | |||||
" 'check_batch_step_fn',\n", | |||||
" 'cur_epoch_idx',\n", | |||||
" 'data_device',\n", | |||||
" 'dataloader',\n", | |||||
" 'device',\n", | |||||
" 'driver',\n", | |||||
" 'driver_name',\n", | |||||
" 'epoch_evaluate',\n", | |||||
" 'evaluate_batch_step_fn',\n", | |||||
" 'evaluate_dataloaders',\n", | |||||
" 'evaluate_every',\n", | |||||
" 'evaluate_fn',\n", | |||||
" 'evaluator',\n", | |||||
" 'extract_loss_from_outputs',\n", | |||||
" 'fp16',\n", | |||||
" 'get_no_sync_context',\n", | |||||
" 'global_forward_batches',\n", | |||||
" 'has_checked_train_batch_loop',\n", | |||||
" 'input_mapping',\n", | |||||
" 'kwargs',\n", | |||||
" 'larger_better',\n", | |||||
" 'load_checkpoint',\n", | |||||
" 'load_model',\n", | |||||
" 'marker',\n", | |||||
" 'metrics',\n", | |||||
" 'model',\n", | |||||
" 'model_device',\n", | |||||
" 'monitor',\n", | |||||
" 'move_data_to_device',\n", | |||||
" 'n_epochs',\n", | |||||
" 'num_batches_per_epoch',\n", | |||||
" 'on',\n", | |||||
" 'on_after_backward',\n", | |||||
" 'on_after_optimizers_step',\n", | |||||
" 'on_after_trainer_initialized',\n", | |||||
" 'on_after_zero_grad',\n", | |||||
" 'on_before_backward',\n", | |||||
" 'on_before_optimizers_step',\n", | |||||
" 'on_before_zero_grad',\n", | |||||
" 'on_evaluate_begin',\n", | |||||
" 'on_evaluate_end',\n", | |||||
" 'on_exception',\n", | |||||
" 'on_fetch_data_begin',\n", | |||||
" 'on_fetch_data_end',\n", | |||||
" 'on_load_checkpoint',\n", | |||||
" 'on_load_model',\n", | |||||
" 'on_sanity_check_begin',\n", | |||||
" 'on_sanity_check_end',\n", | |||||
" 'on_save_checkpoint',\n", | |||||
" 'on_save_model',\n", | |||||
" 'on_train_batch_begin',\n", | |||||
" 'on_train_batch_end',\n", | |||||
" 'on_train_begin',\n", | |||||
" 'on_train_end',\n", | |||||
" 'on_train_epoch_begin',\n", | |||||
" 'on_train_epoch_end',\n", | |||||
" 'optimizers',\n", | |||||
" 'output_mapping',\n", | |||||
" 'progress_bar',\n", | |||||
" 'run',\n", | |||||
" 'run_evaluate',\n", | |||||
" 'save_checkpoint',\n", | |||||
" 'save_model',\n", | |||||
" 'start_batch_idx_in_epoch',\n", | |||||
" 'state',\n", | |||||
" 'step',\n", | |||||
" 'step_evaluate',\n", | |||||
" 'total_batches',\n", | |||||
" 'train_batch_loop',\n", | |||||
" 'train_dataloader',\n", | |||||
" 'train_fn',\n", | |||||
" 'train_step',\n", | |||||
" 'trainer_state',\n", | |||||
" 'zero_grad']" | |||||
] | |||||
}, | |||||
"execution_count": 13, | |||||
"metadata": {}, | |||||
"output_type": "execute_result" | |||||
} | |||||
], | |||||
"source": [ | |||||
"dir(trainer)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 14, | |||||
"id": "953533c4", | |||||
"metadata": {}, | |||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Help on method run in module fastNLP.core.controllers.trainer:\n", | |||||
"\n", | |||||
"run(num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, catch_KeyboardInterrupt=None) method of fastNLP.core.controllers.trainer.Trainer instance\n", | |||||
" 该函数是在 ``Trainer`` 初始化后用于真正开始训练的函数;\n", | |||||
" \n", | |||||
" 注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 resume_from 为 None,并且使用 ``CheckpointCallback``\n", | |||||
" 去保存断点重训的文件;\n", | |||||
" \n", | |||||
" :param num_train_batch_per_epoch: 每个 epoch 训练多少个 batch 后停止,*-1* 表示使用 train_dataloader 本身的长度;\n", | |||||
" :param num_eval_batch_per_dl: 每个 evaluate_dataloader 验证多少个 batch 停止,*-1* 表示使用 evaluate_dataloader 本身的长度;\n", | |||||
" :param num_eval_sanity_batch: 在训练之前运行多少个 evaluation batch 来检测一下 evaluation 的过程是否有错误。为 0 表示不检测;\n", | |||||
" :param resume_from: 从哪个路径下恢复 trainer 的状态,注意该值需要为一个文件夹,例如使用 ``CheckpointCallback`` 时帮助您创建的保存的子文件夹;\n", | |||||
" :param resume_training: 是否按照 checkpoint 中训练状态恢复。如果为 False,则只恢复 model 和 optimizers 的状态;该参数如果为 ``True``,\n", | |||||
" 在下一次断点重训的时候我们会精确到上次训练截止的具体的 sample 进行训练;否则我们只会恢复 model 和 optimizers 的状态,而 ``Trainer`` 中的\n", | |||||
" 其余状态都是保持初始化时的状态不会改变;\n", | |||||
" :param catch_KeyboardInterrupt: 是否捕获 KeyboardInterrupt;如果该参数为 ``True``,在训练时如果您使用 ``ctrl+c`` 来终止程序,\n", | |||||
" ``Trainer`` 不会抛出异常,但是会提前退出,然后 ``trainer.run()`` 之后的代码会继续运行。注意该参数在您使用分布式训练的 ``Driver``\n", | |||||
" 时无效,例如 ``TorchDDPDriver``;非分布式训练的 ``Driver`` 下该参数默认为 True;\n", | |||||
" \n", | |||||
" .. warning::\n", | |||||
" \n", | |||||
" 注意初始化的 ``Trainer`` 只能调用一次 ``run`` 函数,即之后的调用 ``run`` 函数实际不会运行,因为此时\n", | |||||
" ``trainer.cur_epoch_idx == trainer.n_epochs``;\n", | |||||
" \n", | |||||
" 这意味着如果您需要再次调用 ``run`` 函数,您需要重新再初始化一个 ``Trainer``;\n", | |||||
" \n", | |||||
" .. note::\n", | |||||
" \n", | |||||
" 您可以使用 ``num_train_batch_per_epoch`` 来简单地对您的训练过程进行验证,例如,当您指定 ``num_train_batch_per_epoch=10`` 后,\n", | |||||
" 每一个 epoch 下实际训练的 batch 的数量则会被修改为 10。您可以先使用该值来设定一个较小的训练长度,在验证整体的训练流程没有错误后,再将\n", | |||||
" 该值设定为 **-1** 开始真正的训练;\n", | |||||
" \n", | |||||
" ``num_eval_batch_per_dl`` 的意思和 ``num_train_batch_per_epoch`` 类似,即您可以通过设定 ``num_eval_batch_per_dl`` 来验证\n", | |||||
" 整体的验证流程是否正确;\n", | |||||
" \n", | |||||
" ``num_eval_sanity_batch`` 的作用可能会让人产生迷惑,其本质和 ``num_eval_batch_per_dl`` 作用一致,但是其只被 ``Trainer`` 使用;\n", | |||||
" 并且其只会在训练的一开始使用,意思为:我们在训练的开始时会先使用 ``Evaluator``(如果其不为 ``None``) 进行验证,此时验证的 batch 的\n", | |||||
" 数量只有 ``num_eval_sanity_batch`` 个;但是对于 ``num_eval_batch_per_dl`` 而言,其表示在实际的整体的训练过程中,每次 ``Evaluator``\n", | |||||
" 进行验证时会验证的 batch 的数量。\n", | |||||
" \n", | |||||
" 并且,在实际真正的训练中,``num_train_batch_per_epoch`` 和 ``num_eval_batch_per_dl`` 应当都被设置为 **-1**,但是 ``num_eval_sanity_batch``\n", | |||||
" 应当为一个很小的正整数,例如 2;\n", | |||||
" \n", | |||||
" .. note::\n", | |||||
" \n", | |||||
" 参数 ``resume_from`` 和 ``resume_training`` 的设立是为了支持断点重训功能;仅当 ``resume_from`` 不为 ``None`` 时,``resume_training`` 才有效;\n", | |||||
" \n", | |||||
" 断点重训的意思为将上一次训练过程中的 ``Trainer`` 的状态保存下来,包括模型和优化器的状态、当前训练过的 epoch 的数量、对于当前的 epoch\n", | |||||
" 已经训练过的 batch 的数量、callbacks 的状态等等;然后在下一次训练时直接加载这些状态,从而直接恢复到上一次训练过程的某一个具体时间点的状态开始训练;\n", | |||||
" \n", | |||||
" fastNLP 将断点重训分为了 **保存状态** 和 **恢复断点重训** 两部分:\n", | |||||
" \n", | |||||
" 1. 您需要使用 ``CheckpointCallback`` 来保存训练过程中的 ``Trainer`` 的状态;具体详见 :class:`~fastNLP.core.callbacks.CheckpointCallback`;\n", | |||||
" ``CheckpointCallback`` 会帮助您把 ``Trainer`` 的状态保存到一个具体的文件夹下,这个文件夹的名字由 ``CheckpointCallback`` 自己生成;\n", | |||||
" 2. 在第二次训练开始时,您需要找到您想要加载的 ``Trainer`` 状态所存放的文件夹,然后传入给参数 ``resume_from``;\n", | |||||
" \n", | |||||
" 需要注意的是 **保存状态** 和 **恢复断点重训** 是互不影响的。\n", | |||||
"\n" | |||||
] | |||||
} | |||||
], | |||||
"source": [ | |||||
"help(trainer.run)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"id": "1bc7cb4a", | |||||
"metadata": {}, | |||||
"outputs": [], | "outputs": [], | ||||
"source": [] | "source": [] | ||||
} | } | ||||
@@ -281,13 +281,13 @@ | |||||
"## 2. fastNLP 中的 tokenizer\n", | "## 2. fastNLP 中的 tokenizer\n", | ||||
"\n", | "\n", | ||||
"### 2.1 PreTrainTokenizer 的提出\n", | "### 2.1 PreTrainTokenizer 的提出\n", | ||||
"\n", | |||||
"<!-- \n", | |||||
"*词嵌入是什么,为什么不用了*\n", | "*词嵌入是什么,为什么不用了*\n", | ||||
"\n", | "\n", | ||||
"*什么是字节对编码,BPE的提出*\n", | "*什么是字节对编码,BPE的提出*\n", | ||||
"\n", | "\n", | ||||
"*以BERT模型为例,WordPiece的提出*\n", | "*以BERT模型为例,WordPiece的提出*\n", | ||||
"\n", | |||||
" -->\n", | |||||
"在`fastNLP 0.8`中,**使用`PreTrainedTokenizer`模块来为数据集中的词语进行词向量的标注**\n", | "在`fastNLP 0.8`中,**使用`PreTrainedTokenizer`模块来为数据集中的词语进行词向量的标注**\n", | ||||
"\n", | "\n", | ||||
"  需要注意的是,`PreTrainedTokenizer`模块的下载和导入**需要确保环境安装了`transformers`模块**\n", | "  需要注意的是,`PreTrainedTokenizer`模块的下载和导入**需要确保环境安装了`transformers`模块**\n", | ||||
@@ -5,32 +5,292 @@ | |||||
"id": "fdd7ff16", | "id": "fdd7ff16", | ||||
"metadata": {}, | "metadata": {}, | ||||
"source": [ | "source": [ | ||||
"# T4. trainer 和 evaluator 的深入介绍(一)\n", | |||||
"# T4. trainer 和 evaluator 的深入介绍\n", | |||||
"\n", | "\n", | ||||
"  1   fastNLP 结合 pytorch 搭建模型\n", | |||||
"  1   fastNLP 中的更多 metric 类型\n", | |||||
"\n", | |||||
"    1.1   预定义的 metric 类型\n", | |||||
"\n", | |||||
"    1.2   自定义的 metric 类型\n", | |||||
"\n", | |||||
"  2   fastNLP 中 trainer 的补充介绍\n", | |||||
" \n", | " \n", | ||||
"    1.1   \n", | |||||
"    2.1   trainer 的提出构想 \n", | |||||
"\n", | "\n", | ||||
"    1.2   \n", | |||||
"    2.2   trainer 的内部结构\n", | |||||
"\n", | "\n", | ||||
"  2   fastNLP 中的 driver 与 device\n", | |||||
"    2.3   实例:\n", | |||||
"\n", | "\n", | ||||
"    2.1   \n", | |||||
"  3   fastNLP 中的 driver 与 device\n", | |||||
"\n", | "\n", | ||||
"    2.2   \n", | |||||
"    3.1   driver 的提出构想\n", | |||||
"\n", | |||||
"    3.2   device 与多卡训练" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"id": "8d19220c", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## 1. fastNLP 中的更多 metric 类型\n", | |||||
"\n", | "\n", | ||||
"  3   fastNLP 中 trainer 的补充介绍\n", | |||||
"### 1.1 预定义的 metric 类型\n", | |||||
"\n", | "\n", | ||||
"    3.1   \n", | |||||
"在`fastNLP 0.8`中,除了前几篇`tutorial`中经常见到的**正确率`Accuracy`**,还有其他**预定义的评价标准`metric`**\n", | |||||
"\n", | |||||
"  包括**所有`metric`的基类`Metric`**、适配`Transformers`中相关模型的正确率`TransformersAccuracy`\n", | |||||
"\n", | |||||
"    **适用于分类语境下的`F1`值`ClassifyFPreRecMetric`**(其中也包括**召回率`Pre`**、**精确率`Rec`**\n", | |||||
"\n", | |||||
"    **适用于抽取语境下的`F1`值`SpanFPreRecMetric`**;相关基本信息内容见下表,之后是详细分析\n", | |||||
"\n", | |||||
"| <div align=\"center\">代码名称</div> | <div align=\"center\">简要介绍</div> | <div align=\"center\">代码路径</div> |\n", | |||||
"|:--|:--|:--|\n", | |||||
"| `Metric` | 定义`metrics`时继承的基类 | `/core/metrics/metric.py` |\n", | |||||
"| `Accuracy` | 正确率,最为常用 | `/core/metrics/accuracy.py` |\n", | |||||
"| `TransformersAccuracy` | 正确率,为了兼容`Transformers`中相关模型 | `/core/metrics/accuracy.py` |\n", | |||||
"| `ClassifyFPreRecMetric` | 召回率、精确率、F1值,适用于**分类问题** | `/core/metrics/classify_f1_pre_rec_metric.py` |\n", | |||||
"| `SpanFPreRecMetric` | 召回率、精确率、F1值,适用于**抽取问题** | `/core/metrics/span_f1_pre_rec_metric.py` |" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"id": "fdc083a3", | |||||
"metadata": { | |||||
"pycharm": { | |||||
"name": "#%%\n" | |||||
} | |||||
}, | |||||
"source": [ | |||||
"大概的描述一下,给出各个正确率的计算公式" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"id": "9775ea5e", | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"id": "8a22f522", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### 2.2 自定义的 metric 类型\n", | |||||
"\n", | "\n", | ||||
"    3.2   " | |||||
"在`fastNLP 0.8`中,  给一个案例,训练部分留到trainer部分" | |||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": null, | "execution_count": null, | ||||
"id": "d8caba1d", | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"id": "4e6247dd", | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"id": "08752c5a", | "id": "08752c5a", | ||||
"metadata": { | |||||
"pycharm": { | |||||
"name": "#%% md\n" | |||||
} | |||||
}, | |||||
"source": [ | |||||
"## 2. fastNLP 中 trainer 的补充介绍\n", | |||||
"\n", | |||||
"### 2.1 trainer 的提出构想\n", | |||||
"\n", | |||||
"在`fastNLP 0.8`中,  " | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"id": "977a6355", | |||||
"metadata": {}, | |||||
"outputs": [], | |||||
"source": [] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"id": "69203cdc", | |||||
"metadata": { | |||||
"pycharm": { | |||||
"name": "#%%\n" | |||||
} | |||||
}, | |||||
"outputs": [], | |||||
"source": [] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"id": "ab1cea7d", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### 2.2 trainer 的内部结构\n", | |||||
"\n", | |||||
"在`fastNLP 0.8`中,  \n", | |||||
"\n", | |||||
"'accumulation_steps', 'add_callback_fn', 'backward', 'batch_idx_in_epoch', 'batch_step_fn',\n", | |||||
"'callback_manager', 'check_batch_step_fn', 'cur_epoch_idx', 'data_device', 'dataloader',\n", | |||||
"'device', 'driver', 'driver_name', 'epoch_evaluate', 'evaluate_batch_step_fn', 'evaluate_dataloaders',\n", | |||||
"'evaluate_every', 'evaluate_fn', 'evaluator', 'extract_loss_from_outputs', 'fp16',\n", | |||||
"'get_no_sync_context', 'global_forward_batches', 'has_checked_train_batch_loop',\n", | |||||
"'input_mapping', 'kwargs', 'larger_better', 'load_checkpoint', 'load_model', 'marker',\n", | |||||
"'metrics', 'model', 'model_device', 'monitor', 'move_data_to_device', 'n_epochs', 'num_batches_per_epoch',\n", | |||||
"'on', 'on_after_backward', 'on_after_optimizers_step', 'on_after_trainer_initialized',\n", | |||||
"'on_after_zero_grad', 'on_before_backward', 'on_before_optimizers_step', 'on_before_zero_grad',\n", | |||||
"'on_evaluate_begin', 'on_evaluate_end', 'on_exception', 'on_fetch_data_begin', 'on_fetch_data_end',\n", | |||||
"'on_load_checkpoint', 'on_load_model', 'on_sanity_check_begin', 'on_sanity_check_end',\n", | |||||
"'on_save_checkpoint', 'on_save_model', 'on_train_batch_begin', 'on_train_batch_end',\n", | |||||
"'on_train_begin', 'on_train_end', 'on_train_epoch_begin', 'on_train_epoch_end',\n", | |||||
"'optimizers', 'output_mapping', 'progress_bar', 'run', 'run_evaluate',\n", | |||||
"'save_checkpoint', 'save_model', 'start_batch_idx_in_epoch', 'state',\n", | |||||
"'step', 'step_evaluate', 'total_batches', 'train_batch_loop', 'train_dataloader', 'train_fn', 'train_step',\n", | |||||
"'trainer_state', 'zero_grad'\n", | |||||
"\n", | |||||
"  run(num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, catch_KeyboardInterrupt=None)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"id": "b3c8342e", | |||||
"metadata": { | |||||
"pycharm": { | |||||
"name": "#%%\n" | |||||
} | |||||
}, | |||||
"outputs": [], | |||||
"source": [] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"id": "d28f2624", | |||||
"metadata": { | |||||
"pycharm": { | |||||
"name": "#%%\n" | |||||
} | |||||
}, | |||||
"outputs": [], | |||||
"source": [] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"id": "ce6322b4", | |||||
"metadata": {}, | "metadata": {}, | ||||
"source": [ | |||||
"### 2.3 实例:\n", | |||||
"\n", | |||||
"在`fastNLP 0.8`中,  " | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"id": "43be274f", | |||||
"metadata": { | |||||
"pycharm": { | |||||
"name": "#%%\n" | |||||
} | |||||
}, | |||||
"outputs": [], | |||||
"source": [] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"id": "c348864c", | |||||
"metadata": { | |||||
"pycharm": { | |||||
"name": "#%%\n" | |||||
} | |||||
}, | |||||
"outputs": [], | |||||
"source": [] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"id": "175d6ebb", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"## 3. fastNLP 中的 driver 与 device\n", | |||||
"\n", | |||||
"### 3.1 driver 的提出构想\n", | |||||
"\n", | |||||
"在`fastNLP 0.8`中,  " | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"id": "47100e7a", | |||||
"metadata": { | |||||
"pycharm": { | |||||
"name": "#%%\n" | |||||
} | |||||
}, | |||||
"outputs": [], | |||||
"source": [] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"id": "0204a223", | |||||
"metadata": { | |||||
"pycharm": { | |||||
"name": "#%%\n" | |||||
} | |||||
}, | |||||
"outputs": [], | |||||
"source": [] | |||||
}, | |||||
{ | |||||
"cell_type": "markdown", | |||||
"id": "6e723b87", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"### 3.2 device 与多卡训练\n", | |||||
"\n", | |||||
"在`fastNLP 0.8`中,  " | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"id": "5ad81ac7", | |||||
"metadata": { | |||||
"pycharm": { | |||||
"name": "#%%\n" | |||||
} | |||||
}, | |||||
"outputs": [], | |||||
"source": [] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"id": "cfb28b1b", | |||||
"metadata": { | |||||
"pycharm": { | |||||
"name": "#%%\n" | |||||
} | |||||
}, | |||||
"outputs": [], | "outputs": [], | ||||
"source": [] | "source": [] | ||||
} | } | ||||
@@ -52,6 +312,15 @@ | |||||
"nbconvert_exporter": "python", | "nbconvert_exporter": "python", | ||||
"pygments_lexer": "ipython3", | "pygments_lexer": "ipython3", | ||||
"version": "3.7.13" | "version": "3.7.13" | ||||
}, | |||||
"pycharm": { | |||||
"stem_cell": { | |||||
"cell_type": "raw", | |||||
"metadata": { | |||||
"collapsed": false | |||||
}, | |||||
"source": [] | |||||
} | |||||
} | } | ||||
}, | }, | ||||
"nbformat": 4, | "nbformat": 4, | ||||
@@ -5,21 +5,21 @@ | |||||
"id": "fdd7ff16", | "id": "fdd7ff16", | ||||
"metadata": {}, | "metadata": {}, | ||||
"source": [ | "source": [ | ||||
"# T6. trainer 和 evaluator 的深入介绍(二)\n", | |||||
"# T6. fastNLP 与 paddle 或 jittor 的结合\n", | |||||
"\n", | "\n", | ||||
"  1   fastNLP 中预定义模型 models\n", | |||||
"  1   fastNLP 结合 paddle 训练模型\n", | |||||
" \n", | " \n", | ||||
"    1.1   \n", | |||||
"    1.1   关于 paddle 的简单介绍\n", | |||||
"\n", | "\n", | ||||
"    1.2   \n", | |||||
"    1.2   使用 paddle 搭建并训练模型\n", | |||||
"\n", | "\n", | ||||
"  2   fastNLP 中预定义模型 modules\n", | |||||
" \n", | |||||
"    2.1   \n", | |||||
"  2   fastNLP 结合 jittor 训练模型\n", | |||||
"\n", | |||||
"    2.1   关于 jittor 的简单介绍\n", | |||||
"\n", | "\n", | ||||
"    2.2   \n", | |||||
"    2.2   使用 jittor 搭建并训练模型\n", | |||||
"\n", | "\n", | ||||
"  3   fastNLP 中的更多 metric 类型\n", | |||||
"  3   fastNLP 实现 paddle 与 pytorch 互转\n", | |||||
"\n", | "\n", | ||||
"    3.1   \n", | "    3.1   \n", | ||||
"\n", | "\n", | ||||
@@ -13,9 +13,9 @@ | |||||
"cell_type": "markdown", | "cell_type": "markdown", | ||||
"metadata": {}, | "metadata": {}, | ||||
"source": [ | "source": [ | ||||
"# E1. 使用 Bert + fine-tuning 完成 SST2 分类\n", | |||||
"# E1. 使用 Bert + fine-tuning 完成 SST-2 分类\n", | |||||
"\n", | "\n", | ||||
"  1   基础介绍:`GLUE`通用语言理解评估、`SST2`文本情感二分类数据集 \n", | |||||
"  1   基础介绍:`GLUE`通用语言理解评估、`SST-2`文本情感二分类数据集 \n", | |||||
"\n", | "\n", | ||||
"  2   准备工作:加载`tokenizer`、预处理`dataset`、`dataloader`使用\n", | "  2   准备工作:加载`tokenizer`、预处理`dataset`、`dataloader`使用\n", | ||||
"\n", | "\n", | ||||
@@ -63,7 +63,7 @@ | |||||
"\n", | "\n", | ||||
"import fastNLP\n", | "import fastNLP\n", | ||||
"from fastNLP import Trainer\n", | "from fastNLP import Trainer\n", | ||||
"from fastNLP.core.metrics import Accuracy\n", | |||||
"from fastNLP import Accuracy\n", | |||||
"\n", | "\n", | ||||
"print(transformers.__version__)" | "print(transformers.__version__)" | ||||
] | ] | ||||
@@ -72,11 +72,11 @@ | |||||
"cell_type": "markdown", | "cell_type": "markdown", | ||||
"metadata": {}, | "metadata": {}, | ||||
"source": [ | "source": [ | ||||
"### 1. 基础介绍:GLUE 通用语言理解评估、SST2 文本情感二分类数据集\n", | |||||
"### 1. 基础介绍:GLUE 通用语言理解评估、SST-2 文本情感二分类数据集\n", | |||||
"\n", | "\n", | ||||
"  本示例使用`GLUE`评估基准中的`SST2`数据集,通过`fine-tuning`方式\n", | |||||
"  本示例使用`GLUE`评估基准中的`SST-2`数据集,通过`fine-tuning`方式\n", | |||||
"\n", | "\n", | ||||
"    调整`distilbert-bert`分类模型,以下首先简单介绍下`GLUE`和`SST2`\n", | |||||
"    调整`distilbert-bert`分类模型,以下首先简单介绍下`GLUE`和`SST-2`\n", | |||||
"\n", | "\n", | ||||
"**`GLUE`**,**全称`General Language Understanding Evaluation`**,**通用语言理解评估**,\n", | "**`GLUE`**,**全称`General Language Understanding Evaluation`**,**通用语言理解评估**,\n", | ||||
"\n", | "\n", | ||||
@@ -92,7 +92,7 @@ | |||||
"\n", | "\n", | ||||
"  诸如`BERT`、`T5`等经典模型都会在此基准上验证效果,更多参考[GLUE论文](https://arxiv.org/pdf/1804.07461v3.pdf)\n", | "  诸如`BERT`、`T5`等经典模型都会在此基准上验证效果,更多参考[GLUE论文](https://arxiv.org/pdf/1804.07461v3.pdf)\n", | ||||
"\n", | "\n", | ||||
"    此处,我们使用`SST2`来训练`bert`,实现文本分类,其他任务描述见下图" | |||||
"    此处,我们使用`SST-2`来训练`bert`,实现文本分类,其他任务描述见下图" | |||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
@@ -116,9 +116,9 @@ | |||||
"\n", | "\n", | ||||
"  包含电影评论语句和对应的情感极性,1 对应`positive` 正面情感,0 对应`negative` 负面情感\n", | "  包含电影评论语句和对应的情感极性,1 对应`positive` 正面情感,0 对应`negative` 负面情感\n", | ||||
"\n", | "\n", | ||||
"  数据集包括三部分:训练集 67350 条,开发集 873 条,测试集 1821 条,更多参考[下载链接](https://gluebenchmark.com/tasks)\n", | |||||
"  数据集包括三部分:训练集 67350 条,验证集 873 条,测试集 1821 条,更多参考[下载链接](https://gluebenchmark.com/tasks)\n", | |||||
"\n", | "\n", | ||||
"对应到代码上,此处使用`datasets`模块中的`load_dataset`函数,指定`SST2`数据集,自动加载\n", | |||||
"对应到代码上,此处使用`datasets`模块中的`load_dataset`函数,指定`SST-2`数据集,自动加载\n", | |||||
"\n", | "\n", | ||||
"  首次下载后会保存至`~/.cache/huggingface/modules/datasets_modules/datasets/glue/`目录下" | "  首次下载后会保存至`~/.cache/huggingface/modules/datasets_modules/datasets/glue/`目录下" | ||||
] | ] | ||||
@@ -134,14 +134,13 @@ | |||||
"name": "stderr", | "name": "stderr", | ||||
"output_type": "stream", | "output_type": "stream", | ||||
"text": [ | "text": [ | ||||
"Using the latest cached version of the module from /remote-home/xrliu/.cache/huggingface/modules/datasets_modules/datasets/glue/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad (last modified on Thu May 26 15:30:15 2022) since it couldn't be found locally at glue., or remotely on the Hugging Face Hub.\n", | |||||
"Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" | "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" | ||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
"data": { | "data": { | ||||
"application/vnd.jupyter.widget-view+json": { | "application/vnd.jupyter.widget-view+json": { | ||||
"model_id": "adc9449171454f658285f220b70126e1", | |||||
"model_id": "c5915debacf9443986b5b3b34870b303", | |||||
"version_major": 2, | "version_major": 2, | ||||
"version_minor": 0 | "version_minor": 0 | ||||
}, | }, | ||||
@@ -163,7 +162,7 @@ | |||||
"cell_type": "markdown", | "cell_type": "markdown", | ||||
"metadata": {}, | "metadata": {}, | ||||
"source": [ | "source": [ | ||||
"  加载之后,根据`GLUE`中`SST2`数据集的格式,尝试打印部分数据,检查加载结果" | |||||
"  加载之后,根据`GLUE`中`SST-2`数据集的格式,尝试打印部分数据,检查加载结果" | |||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
@@ -287,7 +286,7 @@ | |||||
"\n", | "\n", | ||||
"  其中,**`__getitem__`函数各返回值引用的键值**,**必须和原始数据集中的属性对应**\n", | "  其中,**`__getitem__`函数各返回值引用的键值**,**必须和原始数据集中的属性对应**\n", | ||||
"\n", | "\n", | ||||
"  例如,`'label'`是`SST2`数据集中原有的内容(包括`'sentence'`和`'label'`\n", | |||||
"  例如,`'label'`是`SST-2`数据集中原有的内容(包括`'sentence'`和`'label'`\n", | |||||
"\n", | "\n", | ||||
"    `'input_ids'`和`'attention_mask'`则是`tokenizer`处理后添加的字段" | "    `'input_ids'`和`'attention_mask'`则是`tokenizer`处理后添加的字段" | ||||
] | ] | ||||
@@ -440,10 +439,10 @@ | |||||
"name": "stderr", | "name": "stderr", | ||||
"output_type": "stream", | "output_type": "stream", | ||||
"text": [ | "text": [ | ||||
"Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_transform.weight']\n", | |||||
"Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_layer_norm.bias']\n", | |||||
"- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", | "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", | ||||
"- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", | "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", | ||||
"Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'classifier.weight', 'classifier.bias', 'pre_classifier.weight']\n", | |||||
"Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.weight', 'classifier.bias', 'pre_classifier.bias']\n", | |||||
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" | ||||
] | ] | ||||
} | } | ||||
@@ -472,7 +471,7 @@ | |||||
"trainer = Trainer(\n", | "trainer = Trainer(\n", | ||||
" model=model,\n", | " model=model,\n", | ||||
" driver='torch',\n", | " driver='torch',\n", | ||||
" device=1, # 'cuda'\n", | |||||
" device=0, # 'cuda'\n", | |||||
" n_epochs=10,\n", | " n_epochs=10,\n", | ||||
" optimizers=optimizers,\n", | " optimizers=optimizers,\n", | ||||
" train_dataloader=dataloader_train,\n", | " train_dataloader=dataloader_train,\n", | ||||
@@ -495,6 +494,33 @@ | |||||
"execution_count": 13, | "execution_count": 13, | ||||
"metadata": {}, | "metadata": {}, | ||||
"outputs": [ | "outputs": [ | ||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[09:12:45] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Running evaluator sanity check for <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> batches. <a href=\"file://../fastNLP/core/controllers/trainer.py\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">trainer.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../fastNLP/core/controllers/trainer.py#592\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">592</span></a>\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\u001b[2;36m[09:12:45]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=408427;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=303634;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"Output()" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | { | ||||
"data": { | "data": { | ||||
"text/html": [ | "text/html": [ | ||||
@@ -505,6 +531,490 @@ | |||||
"metadata": {}, | "metadata": {}, | ||||
"output_type": "display_data" | "output_type": "display_data" | ||||
}, | }, | ||||
{ | |||||
"data": { | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"Output()" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.884375</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">283.0</span>\n", | |||||
"<span style=\"font-weight: bold\">}</span>\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\u001b[1m{\u001b[0m\n", | |||||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n", | |||||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", | |||||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n", | |||||
"\u001b[1m}\u001b[0m\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.878125</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">281.0</span>\n", | |||||
"<span style=\"font-weight: bold\">}</span>\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\u001b[1m{\u001b[0m\n", | |||||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n", | |||||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", | |||||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n", | |||||
"\u001b[1m}\u001b[0m\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.884375</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">283.0</span>\n", | |||||
"<span style=\"font-weight: bold\">}</span>\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\u001b[1m{\u001b[0m\n", | |||||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n", | |||||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", | |||||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n", | |||||
"\u001b[1m}\u001b[0m\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.9</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">288.0</span>\n", | |||||
"<span style=\"font-weight: bold\">}</span>\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\u001b[1m{\u001b[0m\n", | |||||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.9\u001b[0m,\n", | |||||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", | |||||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m288.0\u001b[0m\n", | |||||
"\u001b[1m}\u001b[0m\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.8875</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">284.0</span>\n", | |||||
"<span style=\"font-weight: bold\">}</span>\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\u001b[1m{\u001b[0m\n", | |||||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n", | |||||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", | |||||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n", | |||||
"\u001b[1m}\u001b[0m\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.88125</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">282.0</span>\n", | |||||
"<span style=\"font-weight: bold\">}</span>\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\u001b[1m{\u001b[0m\n", | |||||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.88125\u001b[0m,\n", | |||||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", | |||||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m282.0\u001b[0m\n", | |||||
"\u001b[1m}\u001b[0m\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.875</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">280.0</span>\n", | |||||
"<span style=\"font-weight: bold\">}</span>\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\u001b[1m{\u001b[0m\n", | |||||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.875\u001b[0m,\n", | |||||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", | |||||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m280.0\u001b[0m\n", | |||||
"\u001b[1m}\u001b[0m\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">8</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.865625</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">277.0</span>\n", | |||||
"<span style=\"font-weight: bold\">}</span>\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\u001b[1m{\u001b[0m\n", | |||||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.865625\u001b[0m,\n", | |||||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", | |||||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m277.0\u001b[0m\n", | |||||
"\u001b[1m}\u001b[0m\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">9</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.884375</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">283.0</span>\n", | |||||
"<span style=\"font-weight: bold\">}</span>\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\u001b[1m{\u001b[0m\n", | |||||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n", | |||||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", | |||||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n", | |||||
"\u001b[1m}\u001b[0m\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">---------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | |||||
"data": { | |||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"acc#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.878125</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"total#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320.0</span>,\n", | |||||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"correct#acc\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">281.0</span>\n", | |||||
"<span style=\"font-weight: bold\">}</span>\n", | |||||
"</pre>\n" | |||||
], | |||||
"text/plain": [ | |||||
"\u001b[1m{\u001b[0m\n", | |||||
" \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n", | |||||
" \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", | |||||
" \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n", | |||||
"\u001b[1m}\u001b[0m\n" | |||||
] | |||||
}, | |||||
"metadata": {}, | |||||
"output_type": "display_data" | |||||
}, | |||||
{ | { | ||||
"data": { | "data": { | ||||
"text/html": [ | "text/html": [ | ||||
@@ -540,10 +1050,14 @@ | |||||
"outputs": [ | "outputs": [ | ||||
{ | { | ||||
"data": { | "data": { | ||||
"text/html": [ | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |||||
], | |||||
"text/plain": [] | |||||
"application/vnd.jupyter.widget-view+json": { | |||||
"model_id": "", | |||||
"version_major": 2, | |||||
"version_minor": 0 | |||||
}, | |||||
"text/plain": [ | |||||
"Output()" | |||||
] | |||||
}, | }, | ||||
"metadata": {}, | "metadata": {}, | ||||
"output_type": "display_data" | "output_type": "display_data" | ||||
@@ -561,7 +1075,7 @@ | |||||
{ | { | ||||
"data": { | "data": { | ||||
"text/plain": [ | "text/plain": [ | ||||
"{'acc#acc': 0.87156, 'total#acc': 872.0, 'correct#acc': 760.0}" | |||||
"{'acc#acc': 0.884174, 'total#acc': 872.0, 'correct#acc': 771.0}" | |||||
] | ] | ||||
}, | }, | ||||
"execution_count": 14, | "execution_count": 14, | ||||
@@ -4,7 +4,7 @@ | |||||
"cell_type": "markdown", | "cell_type": "markdown", | ||||
"metadata": {}, | "metadata": {}, | ||||
"source": [ | "source": [ | ||||
"# E2. 使用 Bert + prompt 完成 SST2 分类\n", | |||||
"# E2. 使用 Bert + prompt 完成 SST-2 分类\n", | |||||
"\n", | "\n", | ||||
"  1   基础介绍:`prompt-based model`简介、与`fastNLP`的结合\n", | "  1   基础介绍:`prompt-based model`简介、与`fastNLP`的结合\n", | ||||
"\n", | "\n", | ||||
@@ -19,7 +19,7 @@ | |||||
"source": [ | "source": [ | ||||
"### 1. 基础介绍:prompt-based model 简介、与 fastNLP 的结合\n", | "### 1. 基础介绍:prompt-based model 简介、与 fastNLP 的结合\n", | ||||
"\n", | "\n", | ||||
"  本示例使用`GLUE`评估基准中的`SST2`数据集,通过`prompt-based tuning`方式\n", | |||||
"  本示例使用`GLUE`评估基准中的`SST-2`数据集,通过`prompt-based tuning`方式\n", | |||||
"\n", | "\n", | ||||
"    微调`bert-base-uncased`模型,实现文本情感的二分类,在此之前本示例\n", | "    微调`bert-base-uncased`模型,实现文本情感的二分类,在此之前本示例\n", | ||||
"\n", | "\n", | ||||
@@ -27,41 +27,53 @@ | |||||
"\n", | "\n", | ||||
"**`prompt`**,**提示词**,最早出自论文[Exploiting Cloze Questions for Few Shot TC and NLI](https://arxiv.org/pdf/2001.07676.pdf)中的**`PET`模型**\n", | "**`prompt`**,**提示词**,最早出自论文[Exploiting Cloze Questions for Few Shot TC and NLI](https://arxiv.org/pdf/2001.07676.pdf)中的**`PET`模型**\n", | ||||
"\n", | "\n", | ||||
"    全称 **`Pattern-Exploiting Training`**,虽然文中并没有提到**`prompt`的说法,但仍视为其开山之作\n", | |||||
"    全称 **`Pattern-Exploiting Training`**,虽然文中并没有提到`prompt`的说法,但仍被视为开山之作\n", | |||||
"\n", | "\n", | ||||
"  其大致思路包括,对于文本分类任务,假定输入文本为,后来被称`prompt`,后来被称`verbalizer`,\n", | |||||
"  其大致思路包括,对于文本分类任务,假定输入文本为`\" X . \"`,设计**输入模板`template`**,**后来被称为`prompt`**\n", | |||||
"\n", | "\n", | ||||
"  其主要贡献在于,\n", | |||||
"    将输入重构为`\" X . It is [MASK] . \"`,**诱导或刺激语言模型在`[MASK]`位置生成含有情感倾向的词汇**\n", | |||||
"\n", | |||||
"    接着将该词汇**输入分类器中**,**后来被称为`verbalizer`**,从而得到该语句对应的情感倾向,实现文本分类\n", | |||||
"\n", | |||||
"  其主要贡献在于,通过构造`prompt`,诱导/刺激预训练模型生成期望适应下游任务特征,适合少样本学习的需求\n", | |||||
"\n", | "\n", | ||||
"<img src=\"./figures/E2-fig-pet-model.png\" width=\"36%\" height=\"36%\" align=\"center\"></img>\n", | "<img src=\"./figures/E2-fig-pet-model.png\" width=\"36%\" height=\"36%\" align=\"center\"></img>\n", | ||||
"\n", | "\n", | ||||
"**`prompt-based tuning`**,**基于提示的微调**,\n", | |||||
"**`prompt-based tuning`**,**基于提示的微调**,将`prompt`应用于**参数高效微调**,**`parameter-efficient tuning`**\n", | |||||
"\n", | |||||
"  通过**设计模板调整模型输入**或者**调整模型内部状态**,**固定预训练模型**,**诱导/刺激模型**调整输出以适应\n", | |||||
"\n", | "\n", | ||||
"  xxxx,更多参考[prompt综述](https://arxiv.org/pdf/2107.13586.pdf)\n", | |||||
"  当前任务,极大降低了训练开销,也省去了`verbalizer`的构造,更多参考[prompt综述](https://arxiv.org/pdf/2107.13586.pdf)、[DeltaTuning综述](https://arxiv.org/pdf/2203.06904.pdf)\n", | |||||
"\n", | "\n", | ||||
"    以下列举些经典的`prompt-based tuning`案例,简单地介绍下`prompt-based tuning`的脉络\n", | "    以下列举些经典的`prompt-based tuning`案例,简单地介绍下`prompt-based tuning`的脉络\n", | ||||
"\n", | "\n", | ||||
"  案例一:**`P-Tuning v1`**,详细内容参考[P-Tuning-v1论文](https://arxiv.org/pdf/2103.10385.pdf)\n", | |||||
"  **案例一**:**`PrefixTuning`**,详细内容参考[PrefixTuning论文](https://arxiv.org/pdf/2101.00190.pdf)\n", | |||||
"\n", | |||||
"    其主要贡献在于,**提出连续的、非人工构造的、任务导向的`prompt`**,即**前缀`prefix`**,**调整**\n", | |||||
"\n", | |||||
"      **模型内部更新状态**,诱导模型在特定任务下生成期望目标,降低优化难度,提升微调效果\n", | |||||
"\n", | |||||
"    其主要研究对象,是`GPT2`和`BART`,主要面向生成任务`NLG`,如`table-to-text`和摘要\n", | |||||
"\n", | "\n", | ||||
"    其主要贡献在于,\n", | |||||
"  **案例二**:**`P-Tuning v1`**,详细内容参考[P-Tuning-v1论文](https://arxiv.org/pdf/2103.10385.pdf)\n", | |||||
"\n", | "\n", | ||||
"    其方法大致包括,\n", | |||||
"    其主要贡献在于,**通过连续的、非人工构造的`prompt`调整模型输入**,取代原先基于单词设计的\n", | |||||
"\n", | "\n", | ||||
"  案例二:**`PromptTuning`**,详细内容参考[PromptTuning论文](https://arxiv.org/pdf/2104.08691.pdf)\n", | |||||
"      但离散且不易于优化的`prompt`;同时也**证明了`GPT2`在语言理解任务上仍然是可以胜任的**\n", | |||||
"\n", | "\n", | ||||
"    其主要贡献在于,\n", | |||||
"    其主要研究对象,是`GPT2`,主要面向知识探测`knowledge probing`和自然语言理解`NLU`\n", | |||||
"\n", | "\n", | ||||
"    其方法大致包括,\n", | |||||
"  **案例三**:**`PromptTuning`**,详细内容参考[PromptTuning论文](https://arxiv.org/pdf/2104.08691.pdf)\n", | |||||
"\n", | "\n", | ||||
"  案例三:**`PrefixTuning`**,详细内容参考[PrefixTuning论文](https://arxiv.org/pdf/2101.00190.pdf)\n", | |||||
"    其主要贡献在于,通过连续的`prompt`调整模型输入,**证明了`prompt-based tuning`的效果**\n", | |||||
"\n", | "\n", | ||||
"    其主要贡献在于,\n", | |||||
"      **随模型参数量的增加而提升**,最终**在`10B`左右追上了全参数微调`fine-tuning`的效果**\n", | |||||
"\n", | "\n", | ||||
"    其方法大致包括,\n", | |||||
"    其主要面向自然语言理解`NLU`,通过为每个任务定义不同的`prompt`,从而支持多任务语境\n", | |||||
"\n", | "\n", | ||||
"通过上述介绍可以发现`prompt-based tuning`只是模型微调方式,独立于预训练模型基础`backbone`\n", | "通过上述介绍可以发现`prompt-based tuning`只是模型微调方式,独立于预训练模型基础`backbone`\n", | ||||
"\n", | "\n", | ||||
"  目前,加载预训练模型的主流方法是使用`transformers`模块,而实现微调的框架则\n", | |||||
"  目前,加载预训练模型的主流方法是使用**`transformers`模块**,而实现微调的框架则\n", | |||||
"\n", | "\n", | ||||
"    可以是`pytorch`、`paddle`、`jittor`等,而不同框架间又存在不兼容的问题\n", | "    可以是`pytorch`、`paddle`、`jittor`等,而不同框架间又存在不兼容的问题\n", | ||||
"\n", | "\n", | ||||
@@ -69,9 +81,9 @@ | |||||
"\n", | "\n", | ||||
"    **和`transformers`模块之间的桥接**(`transformers`模块基于`pytorch`实现)\n", | "    **和`transformers`模块之间的桥接**(`transformers`模块基于`pytorch`实现)\n", | ||||
"\n", | "\n", | ||||
"本示例仍使用了`tutorial-E1`的`SST2`数据集、`distilbert-base-uncased`模型(便于比较\n", | |||||
"本示例仍使用了`tutorial-E1`的`SST-2`数据集、`distilbert-base-uncased`模型(便于比较\n", | |||||
"\n", | "\n", | ||||
"  使用`pytorch`框架,通过将连续的`prompt`与`model`拼接,解决`SST2`二分类任务" | |||||
"  使用`pytorch`框架,通过将连续的`prompt`与`model`拼接,解决`SST-2`二分类任务" | |||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
@@ -246,7 +258,7 @@ | |||||
"source": [ | "source": [ | ||||
"### 3. 模型训练:加载 tokenizer、预处理 dataset、模型训练与分析\n", | "### 3. 模型训练:加载 tokenizer、预处理 dataset、模型训练与分析\n", | ||||
"\n", | "\n", | ||||
"  本示例沿用`tutorial-E1`中的数据集,即使用`GLUE`评估基准中的`SST2`数据集\n", | |||||
"  本示例沿用`tutorial-E1`中的数据集,即使用`GLUE`评估基准中的`SST-2`数据集\n", | |||||
"\n", | "\n", | ||||
"    以`bert-base-uncased`模型作为基准,基于`P-Tuning v2`方式微调\n", | "    以`bert-base-uncased`模型作为基准,基于`P-Tuning v2`方式微调\n", | ||||
"\n", | "\n", | ||||