{ "cells": [ { "cell_type": "markdown", "id": "fdd7ff16", "metadata": {}, "source": [ "# T5. trainer 和 evaluator 的深入介绍\n", "\n", " 1 fastNLP 中 driver 的补充介绍\n", " \n", " 1.1 trainer 和 driver 的构想 \n", "\n", " 1.2 device 与 多卡训练\n", "\n", " 2 fastNLP 中的更多 metric 类型\n", "\n", " 2.1 预定义的 metric 类型\n", "\n", " 2.2 自定义的 metric 类型\n", "\n", " 3 fastNLP 中 trainer 的补充介绍\n", "\n", " 3.1 trainer 的内部结构" ] }, { "cell_type": "markdown", "id": "08752c5a", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## 1. fastNLP 中 driver 的补充介绍\n", "\n", "### 1.1 trainer 和 driver 的构想\n", "\n", "在`fastNLP 0.8`中,模型训练最关键的模块便是**训练模块`trainer`、评测模块`evaluator`、驱动模块`driver`**,\n", "\n", " 在`tutorial 0`中,已经简单介绍过上述三个模块:**`driver`用来控制训练评测中的`model`的最终运行**\n", "\n", " **`evaluator`封装评测的`metric`**,**`trainer`封装训练的`optimizer`**,**也可以包括`evaluator`**\n", "\n", "之所以做出上述的划分,其根本目的在于要**达成对于多个`python`学习框架**,**例如`pytorch`、`paddle`、`jittor`的兼容**\n", "\n", " 对于训练环节,其伪代码如下方左边紫色一栏所示,由于**不同框架对模型、损失、张量的定义各有不同**,所以将训练环节\n", "\n", " 划分为**框架无关的循环控制、批量分发部分**,**由`trainer`模块负责**实现,对应的伪代码如下方中间蓝色一栏所示\n", "\n", " 以及**随框架不同的模型调用、数值优化部分**,**由`driver`模块负责**实现,对应的伪代码如下方右边红色一栏所示\n", "\n", "|
\n",
"\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import sys\n",
"sys.path.append('..')\n",
"\n",
"from fastNLP import Metric\n",
"\n",
"class MyMetric(Metric):\n",
"\n",
" def __init__(self):\n",
" Metric.__init__(self)\n",
" self.total_num = 0\n",
" self.right_num = 0\n",
"\n",
" def update(self, pred, target):\n",
" self.total_num += target.size(0)\n",
" self.right_num += target.eq(pred).sum().item()\n",
"\n",
" def get_metric(self, reset=True):\n",
" acc = self.right_num / self.total_num\n",
" if reset:\n",
" self.total_num = 0\n",
" self.right_num = 0\n",
" return {'prefix': acc}"
]
},
{
"cell_type": "markdown",
"id": "0155f447",
"metadata": {},
"source": [
" 数据使用方面,此处仍然使用`datasets`模块中的`load_dataset`函数,加载`SST-2`二分类数据集"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "5ad81ac7",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ef923b90b19847f4916cccda5d33fc36",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from datasets import load_dataset\n",
"\n",
"sst2data = load_dataset('glue', 'sst2')"
]
},
{
"cell_type": "markdown",
"id": "e9d81760",
"metadata": {},
"source": [
" 在数据预处理中,需要注意的是,这里原本应该根据`metric`和`model`的输入参数格式,调整\n",
"\n",
" 数据集中表示预测目标的字段,调整为`target`,在后文中会揭晓为什么,以及如何补救"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "cfb28b1b",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Processing: 0%| | 0/6000 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from fastNLP import DataSet\n",
"\n",
"dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n",
"\n",
"dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split()}, progress_bar=\"tqdm\")\n",
"dataset.delete_field('sentence')\n",
"dataset.delete_field('idx')\n",
"\n",
"from fastNLP import Vocabulary\n",
"\n",
"vocab = Vocabulary()\n",
"vocab.from_dataset(dataset, field_name='words')\n",
"vocab.index_dataset(dataset, field_name='words')\n",
"\n",
"train_dataset, evaluate_dataset = dataset.split(ratio=0.85)\n",
"\n",
"from fastNLP import prepare_torch_dataloader\n",
"\n",
"train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
"evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)"
]
},
{
"cell_type": "markdown",
"id": "af3f8c63",
"metadata": {},
"source": [
" 模型使用方面,此处仍然使用`tutorial-4`中介绍过的预定义`CNNText`模型,实现`SST-2`二分类"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2fd210c5",
"metadata": {},
"outputs": [],
"source": [
"from fastNLP.models.torch import CNNText\n",
"\n",
"model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n",
"\n",
"from torch.optim import AdamW\n",
"\n",
"optimizers = AdamW(params=model.parameters(), lr=5e-4)"
]
},
{
"cell_type": "markdown",
"id": "6e723b87",
"metadata": {},
"source": [
"## 3. fastNLP 中 trainer 的补充介绍\n",
"\n",
"### 3.1 trainer 的内部结构\n",
"\n",
"在`tutorial-0`中,我们已经介绍了`trainer`的基本使用,从`tutorial-1`到`tutorial-4`,我们也已经展示了\n",
"\n",
" 很多`trainer`的使用案例,这里通过表格,相对完整地介绍`trainer`模块的属性和初始化参数(标粗为必选参数\n",
"\n",
"| [09:30:35] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n", "\n" ], "text/plain": [ "\u001b[2;36m[09:30:35]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=954293;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=366534;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
"output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
".get_parent()\n",
" if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n",
"\n"
],
"text/plain": [
"/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
"output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
".get_parent()\n",
" if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
"output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
".get_parent()\n",
" self.msg_id = ip.kernel._parent_header['header']['msg_id']\n",
"\n"
],
"text/plain": [
"/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
"output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
".get_parent()\n",
" self.msg_id = ip.kernel._parent_header['header']['msg_id']\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"prefix#suffix\": 0.6875\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.6875\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
"\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"prefix#suffix\": 0.8125\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
"\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"prefix#suffix\": 0.80625\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
"\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"prefix#suffix\": 0.825\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.825\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
"\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"prefix#suffix\": 0.8125\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
"\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"prefix#suffix\": 0.80625\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
"\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"prefix#suffix\": 0.80625\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
"\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"prefix#suffix\": 0.8\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
"\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"prefix#suffix\": 0.80625\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
"\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", "\n" ], "text/plain": [ "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{\n", " \"prefix#suffix\": 0.80625\n", "}\n", "\n" ], "text/plain": [ "\u001b[1m{\u001b[0m\n", " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", "\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
"\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"trainer.run(num_eval_batch_per_dl=10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1abfa0a",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}