diff --git a/tutorials/fastnlp_tutorial_2.ipynb b/tutorials/fastnlp_tutorial_2.ipynb index 64c4bc8b..33f74b8b 100644 --- a/tutorials/fastnlp_tutorial_2.ipynb +++ b/tutorials/fastnlp_tutorial_2.ipynb @@ -801,24 +801,24 @@ "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n", "\n", "# 接着,导入数据,先生成为 dataset 形式,再变成 dataset-dict,并转为 databundle 形式\n", - "datasets = DataSet.from_pandas(pd.read_csv('./data/test4dataset.tsv'))\n", + "datasets = DataSet.from_pandas(pd.read_csv('./data/test4dataset.tsv', sep='\\t'))\n", "train_ds, test_ds = datasets.split(ratio=0.7)\n", "data_bundle = DataBundle(datasets={'train': train_ds, 'test': test_ds})\n", "\n", "# 然后,通过 tokenizer.encode_plus 函数,进行文本分词标注、修改并补充数据包内容\n", "encode = partial(tokenizer.encode_plus, max_length=100, truncation=True,\n", " return_attention_mask=True)\n", - "data_bundle.apply_field_more(encode, field_name='text', progress_bar='tqdm')\n", + "data_bundle.apply_field_more(encode, field_name='Sentence', progress_bar='tqdm')\n", "\n", "# 在修改好 'text' 字段的文本信息后,接着处理 'label' 字段的预测信息\n", "target_vocab = Vocabulary(padding=None, unknown=None)\n", - "target_vocab.from_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='label')\n", - "target_vocab.index_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='label',\n", + "target_vocab.from_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment')\n", + "target_vocab.index_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment',\n", " new_field_name='target')\n", "\n", "# 最后,通过 data_bundle 的其他一些函数,完成善后内容\n", "data_bundle.set_pad('input_ids', pad_val=tokenizer.pad_token_id)\n", - "data_bundle.set_ignore('label', 'text') \n", + "data_bundle.set_ignore('SentenceId', 'Sentiment', 'Sentence') \n", "```" ] }, diff --git a/tutorials/fastnlp_tutorial_3.ipynb b/tutorials/fastnlp_tutorial_3.ipynb index 172e1232..7566c02a 100644 --- a/tutorials/fastnlp_tutorial_3.ipynb +++ 
b/tutorials/fastnlp_tutorial_3.ipynb @@ -9,9 +9,9 @@ "\n", " 1 fastNLP 中的 dataloader\n", " \n", - " 1.1 dataloader 的职责描述\n", + " 1.1 dataloader 的基本介绍\n", "\n", - " 1.2 dataloader 的基本使用\n", + " 1.2 dataloader 的函数创建\n", "\n", " 2 fastNLP 中 dataloader 的延伸\n", "\n", @@ -27,32 +27,143 @@ "source": [ "## 1. fastNLP 中的 dataloader\n", "\n", - "### 1.1 dataloader 的职责描述\n", + "### 1.1 dataloader 的基本介绍\n", "\n", - "在`fastNLP 0.8`中,在数据加载模块`DataLoader`之前" + "在`fastNLP 0.8`的开发中,最关键的开发目标就是**实现`fastNLP`对当前主流机器学习框架**,例如\n", + "\n", + " **较为火热的`pytorch`**,以及**国产的`paddle`和`jittor`的兼容**,扩大受众的同时,也是助力国产\n", + "\n", + "本着分而治之的思想,我们可以将`fastNLP 0.8`对`pytorch`、`paddle`、`jittor`框架的兼容,划分为\n", + "\n", + " **对数据预处理**、**批量`batch`的划分与补齐**、**模型训练**、**模型评测**,**四个部分的兼容**\n", + "\n", + " 针对数据预处理,我们已经在`tutorial-1`中介绍了`dataset`和`vocabulary`的使用\n", + "\n", + " 而结合`tutorial-0`,我们可以发现**数据预处理环节本质上是框架无关的**\n", + "\n", + " 因为在不同框架下,读取的原始数据格式都差异不大,彼此也很容易转换\n", + "\n", + "只有涉及到张量、模型,不同框架才展现出其各自的特色:**`pytorch`中的`tensor`和`nn.Module`**\n", + "\n", + " **在`paddle`中称为`tensor`和`nn.Layer`**,**在`jittor`中则称为`Var`和`Module`**\n", + "\n", + " 因此,**模型训练、模型评测**,**是兼容的重难点**,我们将会在`tutorial-5`中详细介绍\n", + "\n", + " 针对批量`batch`的处理,作为`fastNLP 0.8`中框架无关部分向框架相关部分的过渡\n", + "\n", + " 就是`dataloader`模块的职责,这也是本篇教程`tutorial-3`讲解的重点\n", + "\n", + "**`dataloader`模块的职责**,详细划分可以包含以下三部分,**采样划分、补零对齐、框架匹配**\n", + "\n", + " 第一,确定`batch`大小,确定采样方式,划分后通过迭代器即可得到`batch`序列\n", + "\n", + " 第二,对于序列处理,这也是`fastNLP`主要针对的,将同个`batch`内的数据对齐\n", + "\n", + " 第三,**`batch`内数据格式要匹配框架**,**但`batch`结构需保持一致**,**参数匹配机制**\n", + "\n", + " 对此,`fastNLP 0.8`给出了 **`TorchDataLoader`、`PaddleDataLoader`和`JittorDataLoader`**\n", + "\n", + " 分别针对并匹配不同框架,但彼此之间参数名、属性、方法仍然类似,前两者大致如下表所示\n", + "\n", + "|
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Processing: 0%| | 0/6000 [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "import sys\n", "sys.path.append('..')\n", @@ -404,7 +353,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "b9dd1273", "metadata": {}, "outputs": [], @@ -429,7 +378,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "id": "f6e76e2e", "metadata": {}, "outputs": [], @@ -453,7 +402,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "50a13ee5", "metadata": {}, "outputs": [], @@ -474,600 +423,20 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "id": "28903a7d", "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
[17:45:59] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", - "\n" - ], - "text/plain": [ - "\u001b[2;36m[17:45:59]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=147745;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=708408;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.575,\n", - " \"total#acc\": 160.0,\n", - " \"correct#acc\": 92.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.575\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m92.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.75625,\n", - " \"total#acc\": 160.0,\n", - " \"correct#acc\": 121.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.75625\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m121.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.78125,\n", - " \"total#acc\": 160.0,\n", - " \"correct#acc\": 125.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.78125\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m125.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.8,\n", - " \"total#acc\": 160.0,\n", - " \"correct#acc\": 128.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.79375,\n", - " \"total#acc\": 160.0,\n", - " \"correct#acc\": 127.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.79375\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m127.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.80625,\n", - " \"total#acc\": 160.0,\n", - " \"correct#acc\": 129.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m129.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.81875,\n", - " \"total#acc\": 160.0,\n", - " \"correct#acc\": 131.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.825,\n", - " \"total#acc\": 160.0,\n", - " \"correct#acc\": 132.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.825\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m132.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.81875,\n", - " \"total#acc\": 160.0,\n", - " \"correct#acc\": 131.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.81875,\n", - " \"total#acc\": 160.0,\n", - " \"correct#acc\": 131.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ - "trainer.run(num_eval_batch_per_dl=10)" + "trainer.run()" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "id": "f47a6a35", "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "{'acc#acc': 0.79, 'total#acc': 900.0, 'correct#acc': 711.0}" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "trainer.evaluator.run()" ] @@ -1082,21 +451,10 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "id": "c1a2e2ca", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "342" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "import gc\n", "\n", @@ -1144,32 +502,10 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "id": "03e66686", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Reusing dataset conll2003 (/remote-home/xrliu/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/63f4ebd1bcb7148b1644497336fd74643d4ce70123334431a3c053b7ee4e96ee)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "3ec9e0ce9a054339a2453420c2c9f28b", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/3 [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "from datasets import 
load_dataset\n", "\n", @@ -1196,25 +532,10 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "id": "1f88cad4", "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Processing: 0%| | 0/4000 [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "import sys\n", "sys.path.append('..')\n", @@ -1252,7 +573,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "id": "7802a072", "metadata": {}, "outputs": [], @@ -1277,7 +598,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "id": "4e12c09f", "metadata": {}, "outputs": [], @@ -1306,7 +627,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "id": "cbd6c205", "metadata": {}, "outputs": [], @@ -1327,600 +648,20 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "id": "0f8eff34", "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
[17:49:16] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", - "\n" - ], - "text/plain": [ - "\u001b[2;36m[17:49:16]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=766109;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=787419;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"f#F1\": 0.220374,\n", - " \"pre#F1\": 0.25,\n", - " \"rec#F1\": 0.197026\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.220374\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.25\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.197026\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"f#F1\": 0.442857,\n", - " \"pre#F1\": 0.426117,\n", - " \"rec#F1\": 0.460967\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.442857\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.426117\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.460967\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"f#F1\": 0.572954,\n", - " \"pre#F1\": 0.549488,\n", - " \"rec#F1\": 0.598513\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.572954\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.549488\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.598513\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"f#F1\": 0.665399,\n", - " \"pre#F1\": 0.680934,\n", - " \"rec#F1\": 0.650558\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.665399\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.680934\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.650558\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"f#F1\": 0.734694,\n", - " \"pre#F1\": 0.733333,\n", - " \"rec#F1\": 0.736059\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.734694\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.733333\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.736059\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"f#F1\": 0.742647,\n", - " \"pre#F1\": 0.734545,\n", - " \"rec#F1\": 0.750929\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.742647\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.734545\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.750929\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"f#F1\": 0.773585,\n", - " \"pre#F1\": 0.785441,\n", - " \"rec#F1\": 0.762082\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.773585\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.785441\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.762082\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"f#F1\": 0.770115,\n", - " \"pre#F1\": 0.794466,\n", - " \"rec#F1\": 0.747212\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.770115\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.794466\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.747212\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"f#F1\": 0.7603,\n", - " \"pre#F1\": 0.766038,\n", - " \"rec#F1\": 0.754647\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.7603\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.766038\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.754647\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"f#F1\": 0.743682,\n", - " \"pre#F1\": 0.722807,\n", - " \"rec#F1\": 0.765799\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.743682\u001b[0m,\n", - " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.722807\u001b[0m,\n", - " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.765799\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "trainer.run(num_eval_batch_per_dl=10)" ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "id": "37871d6b", "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "{'f#F1': 0.75283, 'pre#F1': 0.727438, 'rec#F1': 0.780059}" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "trainer.evaluator.run()" ] diff --git a/tutorials/fastnlp_tutorial_5.ipynb b/tutorials/fastnlp_tutorial_5.ipynb index 0669a60a..63410113 100644 --- a/tutorials/fastnlp_tutorial_5.ipynb +++ b/tutorials/fastnlp_tutorial_5.ipynb @@ -312,6 +312,9 @@ "metadata": {}, "outputs": [], "source": [ + "import sys\n", + "sys.path.append('..')\n", + "\n", "from fastNLP import Metric\n", "\n", "class MyMetric(Metric):\n", @@ -333,33 +336,6 @@ " return {'prefix': acc}" ] }, - { - "cell_type": "markdown", - "id": "af3f8c63", - "metadata": {}, - "source": [ - " 模型使用方面,此处仍然使用`tutorial-4`中介绍过的预定义`CNNText`模型,实现`SST-2`二分类" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2fd210c5", - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "sys.path.append('..')\n", - "\n", - "from fastNLP.models.torch import CNNText\n", - "\n", - "model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n", - "\n", - "from torch.optim import AdamW\n", - "\n", - "optimizers = AdamW(params=model.parameters(), lr=5e-4)" - ] - }, { "cell_type": "markdown", "id": "0155f447", @@ -389,9 +365,9 @@ "id": 
"e9d81760", "metadata": {}, "source": [ - "接着是数据预处理,需要注意的是,由于`MyMetric`的`update`函数中,输入参数名称为`pred`和`true`\n", + " 在数据预处理中,需要注意的是,由于`MyMetric`的`update`函数中,输入参数名称为`pred`和`true`\n", "\n", - " 对应地,需要将数据集中表示预测目标的字段,调整为`true`(预定义的`metric`,应调整为`target`" + " 对应地,需要将数据集中表示预测目标的字段,调整为`true`(预定义的`metric`,应调整为`target`" ] }, { @@ -429,14 +405,136 @@ "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)" ] }, + { + "cell_type": "markdown", + "id": "af3f8c63", + "metadata": {}, + "source": [ + " 模型使用方面,此处仍然使用`tutorial-4`中介绍过的预定义`CNNText`模型,实现`SST-2`二分类" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2fd210c5", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP.models.torch import CNNText\n", + "\n", + "model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n", + "\n", + "from torch.optim import AdamW\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=5e-4)" + ] + }, + { + "cell_type": "markdown", + "id": "6e723b87", + "metadata": {}, + "source": [ + "## 3. fastNLP 中 trainer 的补充介绍\n", + "\n", + "### 3.1 trainer 的内部结构\n", + "\n", + "在`tutorial-0`中,我们已经介绍了`trainer`的基本使用,从`tutorial-1`到`tutorial-4`,我们也已经展示了\n", + "\n", + " 很多`trainer`的使用案例,这里通过表格,相对完整地介绍`trainer`模块的属性和初始化参数(标粗为必选参数\n", + "\n", + "|