From 8319706f0222877ee29a0c5f017979939d3aa6b3 Mon Sep 17 00:00:00 2001 From: lxr-tech <1838593642@qq.com> Date: Sat, 4 Jun 2022 21:15:44 +0800 Subject: [PATCH] finish tutorial-3456 lxr 220604 --- tutorials/fastnlp_tutorial_3.ipynb | 374 +++-- tutorials/fastnlp_tutorial_4.ipynb | 2285 +++++++++++++++++++++++++--- tutorials/fastnlp_tutorial_5.ipynb | 659 +++++++- tutorials/fastnlp_tutorial_6.ipynb | 1443 +++++++++++++++++- tutorials/fastnlp_tutorial_7.ipynb | 59 - tutorials/fastnlp_tutorial_8.ipynb | 59 - 6 files changed, 4377 insertions(+), 502 deletions(-) delete mode 100644 tutorials/fastnlp_tutorial_7.ipynb delete mode 100644 tutorials/fastnlp_tutorial_8.ipynb diff --git a/tutorials/fastnlp_tutorial_3.ipynb b/tutorials/fastnlp_tutorial_3.ipynb index 7566c02a..ff8151b9 100644 --- a/tutorials/fastnlp_tutorial_3.ipynb +++ b/tutorials/fastnlp_tutorial_3.ipynb @@ -17,7 +17,7 @@ "\n", "    2.1   collator 的概念与使用\n", "\n", - "    2.2   sampler 的概念与使用" + "    2.2   结合 datasets 框架" ] }, { @@ -71,8 +71,8 @@ "| `batch_size` | √ | √ | 指定`dataloader`的`batch`大小 | 默认`16` |\n", "| `shuffle` | √ | √ | 指定`dataloader`的数据是否打乱 | 默认`False` |\n", "| `collate_fn` | √ | √ | 指定`dataloader`的`batch`打包方法 | 视框架而定 |\n", - "| `sampler` | √ | √ | ? | 默认`None` |\n", - "| `batch_sampler` | √ | √ | ? | 默认`None` |\n", + "| `sampler` | √ | √ | 指定`dataloader`的`__len__`和`__iter__`函数的实现 | 默认`None` |\n", + "| `batch_sampler` | √ | √ | 指定`dataloader`的`__len__`和`__iter__`函数的实现 | 默认`None` |\n", "| `drop_last` | √ | √ | 指定`dataloader`划分`batch`时是否丢弃剩余的 | 默认`False` |\n", "| `cur_batch_indices` | | √ | 记录`dataloader`当前遍历批量序号 | |\n", "| `num_workers` | √ | √ | 指定`dataloader`开启子进程数量 | 默认`0` |\n", @@ -95,7 +95,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 1, "id": "aca72b49", "metadata": { "pycharm": { @@ -103,6 +103,26 @@ } }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[38;5;2m[i 0604 15:44:29.773860 92 log.cc:351] Load log_sync: 1\u001b[m\n" + ] + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { "application/vnd.jupyter.widget-view+json": { @@ -149,14 +169,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "+------------+------------------+-----------+------------------+--------------------+--------------------+\n", - "| SentenceId | Sentence | Sentiment | input_ids | token_type_ids | attention_mask |\n", - "+------------+------------------+-----------+------------------+--------------------+--------------------+\n", - "| 5 | A comedy-dram... | positive | [101, 1037, 4... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n", - "| 2 | This quiet , ... | positive | [101, 2023, 4... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n", - "| 1 | A series of e... | negative | [101, 1037, 2... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n", - "| 6 | The Importanc... | neutral | [101, 1996, 5... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n", - "+------------+------------------+-----------+------------------+--------------------+--------------------+\n" + "+------------+----------------+-----------+----------------+--------------------+--------------------+--------+\n", + "| SentenceId | Sentence | Sentiment | input_ids | token_type_ids | attention_mask | target |\n", + "+------------+----------------+-----------+----------------+--------------------+--------------------+--------+\n", + "| 1 | A series of... | negative | [101, 1037,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 1 |\n", + "| 4 | A positivel... | neutral | [101, 1037,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 2 |\n", + "| 3 | Even fans o... | negative | [101, 2130,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 1 |\n", + "| 5 | A comedy-dr... | positive | [101, 1037,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... 
| 0 |\n", + "+------------+----------------+-----------+----------------+--------------------+--------------------+--------+\n" ] } ], @@ -200,7 +220,9 @@ " \n", "pipe = PipeDemo(tokenizer='bert-base-uncased')\n", "\n", - "data_bundle = pipe.process_from_file('./data/test4dataset.tsv')" + "data_bundle = pipe.process_from_file('./data/test4dataset.tsv')\n", + "\n", + "print(data_bundle.get_dataset('train'))" ] }, { @@ -214,15 +236,65 @@ "\n", "  例如下方的`prepare_torch_dataloader`函数,指定必要参数,读取数据集,生成对应`dataloader`\n", "\n", - "  类型为`TorchDataLoader`,只能适用于`pytorch`框架,因此对应`trainer`初始化时`driver='torch'`" + "  类型为`TorchDataLoader`,只能适用于`pytorch`框架,因此对应`trainer`初始化时`driver='torch'`\n", + "\n", + "同时我们看还可以发现,在`fastNLP 0.8`中,**`batch`表示为字典`dict`类型**,**`key`值就是原先数据集中各个字段**\n", + "\n", + "  **除去经过`DataBundle.set_ignore`函数隐去的部分**,而`value`值为`pytorch`框架对应的`torch.Tensor`类型" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 2, "id": "5fd60e42", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " ['input_ids', 'token_type_ids', 'attention_mask', 'target']\n", + "{'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),\n", + " 'input_ids': tensor([[ 101, 1037, 4038, 1011, 3689, 1997, 3053, 8680, 19173, 15685,\n", + " 1999, 1037, 18006, 2836, 2011, 1996, 2516, 2839, 14996, 3054,\n", + " 15509, 5325, 1012, 102, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0],\n", + " [ 101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262,\n", + " 3351, 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036,\n", + " 2204, 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681,\n", + " 2572, 25581, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997,\n", + " 1037, 2466, 1012, 102],\n", + " [ 101, 2130, 4599, 1997, 19214, 6432, 1005, 1055, 2147, 1010,\n", + " 1045, 8343, 1010, 2052, 2031, 1037, 2524, 2051, 3564, 2083,\n", + " 2023, 2028, 1012, 102, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0],\n", + " [ 101, 1037, 13567, 26162, 5257, 1997, 3802, 7295, 9888, 1998,\n", + " 2035, 1996, 20014, 27611, 1010, 14583, 1010, 11703, 20175, 1998,\n", + " 4028, 1997, 1037, 8101, 2319, 10576, 2030, 1037, 28900, 7815,\n", + " 3850, 1012, 102, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0]]),\n", + " 'target': tensor([0, 1, 1, 2]),\n", + " 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}\n" 
+ ] + } + ], "source": [ "from fastNLP import prepare_torch_dataloader\n", "\n", @@ -230,28 +302,15 @@ "evaluate_dataset = data_bundle.get_dataset('dev')\n", "\n", "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n", - "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)" - ] - }, - { - "cell_type": "markdown", - "id": "7c53f181", - "metadata": {}, - "source": [ - "```python\n", - "trainer = Trainer(\n", - " model=model,\n", - " train_dataloader=train_dataloader,\n", - " optimizers=optimizer,\n", - "\t...\n", - "\tdriver='torch',\n", - "\tdevice='cuda',\n", - "\t...\n", - " evaluate_dataloaders=evaluate_dataloader, \n", - " metrics={'acc': Accuracy()},\n", - "\t...\n", - ")\n", - "```" + "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)\n", + "\n", + "print(type(train_dataloader))\n", + "\n", + "import pprint\n", + "\n", + "for batch in train_dataloader:\n", + " print(type(batch), type(batch['input_ids']), list(batch))\n", + " pprint.pprint(batch, width=1)" ] }, { @@ -259,27 +318,33 @@ "id": "9f457a6e", "metadata": {}, "source": [ - "之所以称`prepare_xx_dataloader`函数更方便,是因为其**导入对象不仅可也是`DataSet`类型**,**还可以**\n", + "之所以说`prepare_xx_dataloader`函数更方便,是因为其**导入对象不仅可也是`DataSet`类型**,**还可以**\n", "\n", "  **是`DataBundle`类型**,不过数据集名称需要是`'train'`、`'dev'`、`'test'`供`fastNLP`识别\n", "\n", - "  例如下方就是**直接通过`prepare_paddle_dataloader`函数生成基于`PaddleDataLoader`的字典**\n", - "\n", - "  在接下来`trainer`的初始化过程中,按如下方式使用即可,除了初始化时`driver='paddle'`外\n", - "\n", - "    这里也可以看出 **`evaluate_dataloaders`的妙处**,一次评测可以针对多个数据集" + "例如下方就是**直接通过`prepare_paddle_dataloader`函数生成基于`PaddleDataLoader`的字典**\n" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 3, "id": "7827557d", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "from fastNLP import prepare_paddle_dataloader\n", "\n", - "dl_bundle = prepare_paddle_dataloader(data_bundle, batch_size=16, shuffle=True)" + "dl_bundle = prepare_paddle_dataloader(data_bundle, batch_size=16, shuffle=True)\n", + "\n", + "print(type(dl_bundle['train']))" ] }, { @@ -287,6 +352,10 @@ "id": "d898cf40", "metadata": {}, "source": [ + "  而在接下来`trainer`的初始化过程中,按如下方式使用即可,除了初始化时`driver='paddle'`外\n", + "\n", + "  这里也可以看出`trainer`模块中,**`evaluate_dataloaders`的设计允许评测可以针对多个数据集**\n", + "\n", "```python\n", "trainer = Trainer(\n", " model=model,\n", @@ -312,31 +381,45 @@ "\n", "### 2.1 collator 的概念与使用\n", "\n", - "在`fastNLP 0.8`中,在数据加载模块`DataLoader`之前,还存在其他的一些模块,负责例如对文本数据\n", + "在`fastNLP 0.8`中,在数据加载模块`dataloader`内部,如之前表格所列举的,还存在其他的一些模块\n", "\n", - "  进行补零对齐,即 **核对器`collator`模块**,进行分词标注,即 **分词器`tokenizer`模块**\n", + "  例如,**实现序列的补零对齐的核对器`collator`模块**;注:`collate vt. 整理(文件或书等);核对,校勘`\n", "\n", - "  本节将对`fastNLP`中的核对器`collator`等展开介绍,分词器`tokenizer`将在下一节中详细介绍\n", + "在`fastNLP 0.8`中,虽然`dataloader`随框架不同,但`collator`模块却是统一的,主要属性、方法如下表所示\n", "\n", - "在`fastNLP 0.8`中,**核对器`collator`模块负责文本序列的补零对齐**,通过" + "|
名称|属性|方法|功能|内容|\n",
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0]),\n", + " 'input_ids': tensor([[ 101, 1037, 4038, 1011, 3689, 1997, 3053, 8680, 19173, 15685,\n", + " 1999, 1037, 18006, 2836, 2011, 1996, 2516, 2839, 14996, 3054,\n", + " 15509, 5325, 1012, 102, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0],\n", + " [ 101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262,\n", + " 3351, 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036,\n", + " 2204, 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681,\n", + " 2572, 25581, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997,\n", + " 1037, 2466, 1012, 102],\n", + " [ 101, 2130, 4599, 1997, 19214, 6432, 1005, 1055, 2147, 1010,\n", + " 1045, 8343, 1010, 2052, 2031, 1037, 2524, 2051, 3564, 2083,\n", + " 2023, 2028, 1012, 102, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0],\n", + " [ 101, 1037, 13567, 26162, 5257, 1997, 3802, 7295, 9888, 1998,\n", + " 2035, 1996, 20014, 27611, 1010, 14583, 1010, 11703, 20175, 1998,\n", + " 4028, 1997, 1037, 8101, 2319, 10576, 2030, 1037, 28900, 7815,\n", + " 3850, 1012, 102, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0]]),\n", + " 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}\n" + ] } - }, - "outputs": [], + ], "source": [ - "dataloader.batch_sampler" + "train_dataloader = prepare_torch_dataloader(train_dataset, collate_fn=collate_fn, shuffle=True)\n", + "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, collate_fn=collate_fn, shuffle=True)\n", + "\n", + "print(type(train_dataloader))\n", + "print(type(train_dataloader.collate_fn))\n", + "\n", + "for batch in train_dataloader:\n", + " pprint.pprint(batch, width=1)" ] }, { "cell_type": "markdown", - "id": "51bf0878", + "id": "0bd98365", "metadata": {}, "source": [ - "  " + "### 2.2 fastNLP 与 datasets 的结合\n", + "\n", + "从`tutorial-1`至`tutorial-3`,我们已经完成了对`fastNLP v0.8`数据读取、预处理、加载,整个流程的介绍\n", + "\n", + "  不过在实际使用中,我们往往也会采取更为简便的方法读取数据,例如使用`huggingface`的`datasets`模块\n", + "\n", + "**使用`datasets`模块中的`load_dataset`函数**,通过指定数据集两级的名称,示例中即是**`GLUE`标准中的`SST-2`数据集**\n", + "\n", + "  即可以快速从网上下载好`SST-2`数据集读入,之后以`pandas.DataFrame`作为中介,再转化成`fastNLP.DataSet`\n", + "\n", + "  之后的步骤就和其他关于`dataset`、`databundle`、`vocabulary`、`dataloader`中介绍的相关使用相同了" ] }, { "cell_type": "code", - "execution_count": null, - "id": "3fd2486f", - "metadata": { - "pycharm": { - "name": "#%%\n" + "execution_count": 7, + "id": "91879c30", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "639a0ad3c63944c6abef4e8ee1f7bf7c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/6000 [00:00[16:20:10] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[16:20:10]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=908530;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=864197;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+       "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+       ".get_parent()\n",
+       "  if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n",
+       "
\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+       "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+       ".get_parent()\n",
+       "  self.msg_id = ip.kernel._parent_header['header']['msg_id']\n",
+       "
\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.525,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 84.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.525\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m84.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.54375,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 87.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.54375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m87.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.55,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 88.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.55\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m88.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.625,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 100.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.65,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 104.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.65\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m104.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.69375,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 111.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.69375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m111.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.675,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 108.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.675\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m108.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.66875,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 107.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.66875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m107.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.675,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 108.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.675\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m108.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.68125,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 109.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.68125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m109.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "# trainer.run(num_eval_batch_per_dl=10)" + "trainer.run(num_eval_batch_per_dl=10)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "id": "8bc4bfb2", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'acc#acc': 0.712222, 'total#acc': 900.0, 'correct#acc': 641.0}"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
-    "# trainer.evaluator.run()"
+    "trainer.evaluator.run()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "07538876",
+   "metadata": {},
+   "source": [
+    "  注:此处使用`gc`模块删除相关变量,释放内存,为接下来新的模型训练预留存储空间,下同"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "1b52eafd",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "383"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import gc\n",
+    "\n",
+    "del model\n",
+    "del trainer\n",
+    "del dataset\n",
+    "del sst2data\n",
+    "\n",
+    "gc.collect()"
    ]
   },
   {
@@ -255,6 +1050,8 @@
     "\n",
     "  本示例使用`fastNLP 0.8`中预定义模型`models`中的`CNNText`模型,实现`SST-2`文本二分类任务\n",
     "\n",
+    "数据使用方面,此处沿用在上个示例中展示的`SST-2`数据集,数据加载过程相同且已经执行过了,因此简略\n",
+    "\n",
     "模型使用方面,如上所述,这里使用**基于卷积神经网络`CNN`的预定义文本分类模型`CNNText`**,结构如下所示\n",
     "\n",
     "  首先是内置的`100`维嵌入层、`dropout`层、紧接着是三个一维卷积,将`100`维嵌入特征,分别通过\n",
@@ -281,95 +1078,7 @@
     ")\n",
     "```\n",
     "\n",
-    "数据使用方面,此处**使用`datasets`模块中的`load_dataset`函数**,以如下形式,指定`SST-2`数据集自动加载\n",
-    "\n",
-    "  首次下载后会保存至`~/.cache/huggingface/modules/datasets_modules/datasets/glue/`目录下"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "1aa5cf6d",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from datasets import load_dataset\n",
-    "\n",
-    "sst2data = load_dataset('glue', 'sst2')"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "c476abe7",
-   "metadata": {},
-   "source": [
-    "紧接着,使用`tutorial-1`和`tutorial-2`中的知识,将数据集转化为`fastNLP`中的`DataSet`格式\n",
-    "\n",
-    "  **使用`apply_more`函数、`Vocabulary`模块的`from_/index_dataset`函数预处理数据**\n",
-    "\n",
-    "    并结合`delete_field`函数删除字段调整格式,`split`函数划分测试集和验证集\n",
-    "\n",
-    "  **仅保留`'words'`字段表示输入文本单词序号序列、`'target'`字段表示文本对应预测输出结果**\n",
-    "\n",
-    "    两者**对应到`CNNText`中`train_step`函数和`evaluate_step`函数的签名/输入参数**"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "357ea748",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import sys\n",
-    "sys.path.append('..')\n",
-    "\n",
-    "from fastNLP import DataSet\n",
-    "\n",
-    "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n",
-    "\n",
-    "dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'target': ins['label']}, \n",
-    "                   progress_bar=\"tqdm\")\n",
-    "dataset.delete_field('sentence')\n",
-    "dataset.delete_field('label')\n",
-    "dataset.delete_field('idx')\n",
-    "\n",
-    "from fastNLP import Vocabulary\n",
-    "\n",
-    "vocab = Vocabulary()\n",
-    "vocab.from_dataset(dataset, field_name='words')\n",
-    "vocab.index_dataset(dataset, field_name='words')\n",
-    "\n",
-    "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "96380c67",
-   "metadata": {},
-   "source": [
-    "然后,使用`tutorial-3`中的知识,**通过`prepare_torch_dataloader`处理数据集得到`dataloader`**"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "b9dd1273",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from fastNLP import prepare_torch_dataloader\n",
-    "\n",
-    "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
-    "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "96941b63",
-   "metadata": {},
-   "source": [
-    "接着,**从`fastNLP.models.torch`路径下导入`CNNText`**,初始化`CNNText`实例以及`optimizer`实例\n",
+    "对应到代码上,**从`fastNLP.models.torch`路径下导入`CNNText`**,初始化`CNNText`和`optimizer`实例\n",
     "\n",
     "  注意:初始化`CNNText`时,**二元组参数`embed`、分类数量`num_classes`是必须传入的**,其中\n",
     "\n",
@@ -378,7 +1087,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 10,
    "id": "f6e76e2e",
    "metadata": {},
    "outputs": [],
@@ -397,12 +1106,12 @@
    "id": "0cc5ca10",
    "metadata": {},
    "source": [
-    "最后,使用`trainer`模块,集成`model`、`optimizer`、`dataloader`、`metric`训练"
+    "  最后,使用`trainer`模块,集成`model`、`optimizer`、`dataloader`、`metric`训练"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 11,
    "id": "50a13ee5",
    "metadata": {},
    "outputs": [],
@@ -423,45 +1132,634 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 12,
    "id": "28903a7d",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "
[16:21:57] INFO     Running evaluator sanity check for 2 batches.              trainer.py:596\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[16:21:57]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=813103;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=271516;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.654444,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 589.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.654444\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m589.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.767778,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 691.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.767778\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m691.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.797778,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 718.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.797778\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m718.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.803333,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 723.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.803333\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m723.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.807778,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 727.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.807778\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m727.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.812222,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 731.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.812222\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m731.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.804444,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 724.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.804444\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m724.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.811111,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 730.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.811111\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m730.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.811111,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 730.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.811111\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m730.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.806667,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 726.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.806667\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m726.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "trainer.run()" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "id": "f47a6a35", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'acc#acc': 0.806667, 'total#acc': 900.0, 'correct#acc': 726.0}"
+      ]
+     },
+     "execution_count": 13,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
     "trainer.evaluator.run()"
    ]
   },
   {
    "cell_type": "markdown",
-   "id": "7c811257",
+   "id": "5b5c0446",
    "metadata": {},
    "source": [
-    "  注:此处使用`gc`模块删除相关变量,释放内存,为接下来新的模型训练预留存储空间"
+    "  注:此处使用`gc`模块删除相关变量,释放内存,为接下来新的模型训练预留存储空间,下同"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "id": "c1a2e2ca",
+   "execution_count": 14,
+   "id": "e9e70f88",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "344"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
     "import gc\n",
     "\n",
     "del model\n",
     "del trainer\n",
-    "del dataset\n",
-    "del sst2data\n",
     "\n",
     "gc.collect()"
    ]
@@ -502,10 +1800,32 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 15,
    "id": "03e66686",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Reusing dataset conll2003 (/remote-home/xrliu/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/63f4ebd1bcb7148b1644497336fd74643d4ce70123334431a3c053b7ee4e96ee)\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "593bc03ed5914953ab94268ff2f01710",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/3 [00:00[16:23:41] INFO     Running evaluator sanity check for 2 batches.              trainer.py:596\n",
+       "\n"
+      ],
+      "text/plain": [
+       "\u001b[2;36m[16:23:41]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO    \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches.              \u001b]8;id=565652;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=224849;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.169014,\n",
+       "  \"pre#F1\": 0.170732,\n",
+       "  \"rec#F1\": 0.167331\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.169014\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.170732\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.167331\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.361809,\n",
+       "  \"pre#F1\": 0.312139,\n",
+       "  \"rec#F1\": 0.430279\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.361809\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.312139\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.430279\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.525,\n",
+       "  \"pre#F1\": 0.475728,\n",
+       "  \"rec#F1\": 0.585657\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.525\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.475728\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.585657\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.627306,\n",
+       "  \"pre#F1\": 0.584192,\n",
+       "  \"rec#F1\": 0.677291\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.627306\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.584192\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.677291\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.710937,\n",
+       "  \"pre#F1\": 0.697318,\n",
+       "  \"rec#F1\": 0.7251\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.710937\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.697318\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.7251\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.739563,\n",
+       "  \"pre#F1\": 0.738095,\n",
+       "  \"rec#F1\": 0.741036\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.739563\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.738095\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.741036\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.748491,\n",
+       "  \"pre#F1\": 0.756098,\n",
+       "  \"rec#F1\": 0.741036\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.748491\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.756098\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.741036\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.716763,\n",
+       "  \"pre#F1\": 0.69403,\n",
+       "  \"rec#F1\": 0.741036\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.716763\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.69403\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.741036\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.768293,\n",
+       "  \"pre#F1\": 0.784232,\n",
+       "  \"rec#F1\": 0.752988\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.768293\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.784232\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.752988\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.757692,\n",
+       "  \"pre#F1\": 0.732342,\n",
+       "  \"rec#F1\": 0.784861\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.757692\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.732342\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.784861\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "trainer.run(num_eval_batch_per_dl=10)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "id": "37871d6b", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'f#F1': 0.766798, 'pre#F1': 0.741874, 'rec#F1': 0.793456}"
+      ]
+     },
+     "execution_count": 21,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
     "trainer.evaluator.run()"
    ]
diff --git a/tutorials/fastnlp_tutorial_5.ipynb b/tutorials/fastnlp_tutorial_5.ipynb
index 63410113..3f2bbfa6 100644
--- a/tutorials/fastnlp_tutorial_5.ipynb
+++ b/tutorials/fastnlp_tutorial_5.ipynb
@@ -296,7 +296,7 @@
     "\n",
     "    在`fastNLP v0.8`中,`update`函数的默认输入参数:`pred`,对应预测值;`target`,对应真实值\n",
     "\n",
-    "    此处刻意调整为:`pred`,对应预测值,和模型输出一致;`true`,对应真实值,数据集字段需要调整\n",
+    "    此处仍然沿用,因为接下来会需要使用`fastNLP`函数的与定义模型,其输入参数格式即使如此\n",
     "\n",
     "  在`get_metric`函数中,根据需求定义评测指标最终的计算,此处直接计算准确率,该函数必须返回一个字典\n",
     "\n",
@@ -307,10 +307,24 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 1,
    "id": "08a872e9",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "import sys\n", "sys.path.append('..')\n", @@ -320,16 +334,16 @@ "class MyMetric(Metric):\n", "\n", " def __init__(self):\n", - " MyMetric.__init__(self)\n", + " Metric.__init__(self)\n", " self.total_num = 0\n", " self.right_num = 0\n", "\n", - " def update(self, pred, true):\n", + " def update(self, pred, target):\n", " self.total_num += target.size(0)\n", " self.right_num += target.eq(pred).sum().item()\n", "\n", " def get_metric(self, reset=True):\n", - " acc = self.acc_count / self.total_num\n", + " acc = self.right_num / self.total_num\n", " if reset:\n", " self.total_num = 0\n", " self.right_num = 0\n", @@ -346,14 +360,36 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "5ad81ac7", "metadata": { "pycharm": { "name": "#%%\n" } }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ef923b90b19847f4916cccda5d33fc36", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00名称 |
参数
|
属性
|
功能
|
内容
|\n", "|:--|:--:|:--:|:--|:--|\n", "| **`model`** | √ | √ | 指定`trainer`控制的模型 | 视框架而定,如`torch.nn.Module` |\n", - "| **`driver`** | √ | | 指定`trainer`驱动的框架 | 包括`'torch'`、`'paddle'`、`'jittor'` |\n", - "| | | √ | 记录`trainer`驱动的框架 | `Driver`类型,在初始化阶段生成 |\n", "| `device` | √ | | 指定`trainer`运行的卡位 | 例如`'cpu'`、`'cuda'`、`0`、`[0, 1]`等 |\n", "| | | √ | 记录`trainer`运行的卡位 | `Device`类型,在初始化阶段生成 |\n", + "| **`driver`** | √ | | 指定`trainer`驱动的框架 | 包括`'torch'`、`'paddle'`、`'jittor'` |\n", + "| | | √ | 记录`trainer`驱动的框架 | `Driver`类型,在初始化阶段生成 |\n", "| `n_epochs` | √ | - | 指定`trainer`迭代的轮数 | 默认`20`,记录在`driver.n_epochs`中 |\n", "| **`optimizers`** | √ | √ | 指定`trainer`优化的方法 | 视框架而定,如`torch.optim.Adam` |\n", "| `metrics` | √ | √ | 指定`trainer`评测的方法 | 字典类型,如`{'acc': Metric()}` |\n", @@ -473,12 +522,34 @@ "| `fp16` | √ | √ | 指定`trainer`是否进行混合精度训练 | 布尔类型,默认`False` |" ] }, + { + "cell_type": "markdown", + "id": "9e13ee08", + "metadata": {}, + "source": [ + "其中,**`input_mapping`和`output_mapping`** 定义形式如下:输入字典形式的数据,根据参数匹配要求\n", + "\n", + "  调整数据格式,这里就回应了前文未在数据集预处理时调整格式的问题,**总之参数匹配一定要求**" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "de96c1d1", + "metadata": {}, + "outputs": [], + "source": [ + "def input_mapping(data):\n", + " data['target'] = data['label']\n", + " return data" + ] + }, { "cell_type": "markdown", "id": "2fc8b9f3", "metadata": {}, "source": [ - "  以及`trainer`模块内部的基础方法,相关进阶操作,如“`on`系列函数”、`callback`控制,请参考后续的`tutorial-7`\n", + "  而`trainer`模块的基础方法列表如下,相关进阶操作,如“`on`系列函数”、`callback`控制,请参考后续的`tutorial-7`\n", "\n", "|
名称
|
功能
|
主要参数
|\n", "|:--|:--|:--|\n", @@ -539,7 +610,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "926a9c50", "metadata": {}, "outputs": [], @@ -552,6 +623,7 @@ " device=0, # 'cuda'\n", " n_epochs=10,\n", " optimizers=optimizers,\n", + " input_mapping=input_mapping,\n", " train_dataloader=train_dataloader,\n", " evaluate_dataloaders=evaluate_dataloader,\n", " metrics={'suffix': MyMetric()}\n", @@ -580,14 +652,557 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "43be274f", "metadata": { "pycharm": { "name": "#%%\n" } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
[09:30:35] INFO     Running evaluator sanity check for 2 batches.              trainer.py:596\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[09:30:35]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=954293;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=366534;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+       "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+       ".get_parent()\n",
+       "  if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n",
+       "
\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+       "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+       ".get_parent()\n",
+       "  self.msg_id = ip.kernel._parent_header['header']['msg_id']\n",
+       "
\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.6875\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.6875\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.8125\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.80625\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.825\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.825\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.8125\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.80625\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.80625\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.8\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.80625\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.80625\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "trainer.run(num_eval_batch_per_dl=10)" ] diff --git a/tutorials/fastnlp_tutorial_6.ipynb b/tutorials/fastnlp_tutorial_6.ipynb index 552f73d9..63f7481e 100644 --- a/tutorials/fastnlp_tutorial_6.ipynb +++ b/tutorials/fastnlp_tutorial_6.ipynb @@ -19,15 +19,37 @@ "\n", "    2.2   使用 jittor 搭建并训练模型\n", "\n", - "  3   fastNLP 实现 paddle 与 pytorch 互转" + "" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "08752c5a", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6b13d42c39ba455eb370bf2caaa3a264", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/6000 [00:00 True\n" + ] + } + ], "source": [ "import sys\n", "sys.path.append('..')\n", @@ -74,72 +138,108 @@ "metadata": {}, "source": [ "## 1. fastNLP 结合 paddle 训练模型\n", - "\n", - "```python\n", - "import paddle\n", - "\n", - "lstm = paddle.nn.LSTM(16, 32, 2)\n", - "\n", - "x = paddle.randn((4, 23, 16))\n", - "h = paddle.randn((2, 4, 32))\n", - "c = paddle.randn((2, 4, 32))\n", - "\n", - "y, (h, c) = lstm(x, (h, c))\n", - "\n", - "print(y.shape) # [4, 23, 32]\n", - "print(h.shape) # [2, 4, 32]\n", - "print(c.shape) # [2, 4, 32]\n", - "```" + "\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "e31b3198", "metadata": {}, "outputs": [], "source": [ "import paddle\n", "import paddle.nn as nn\n", + "import paddle.nn.functional as F\n", "\n", "\n", "class ClsByPaddle(nn.Layer):\n", - " def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n", + " def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, dropout=0.5):\n", " nn.Layer.__init__(self)\n", " self.hidden_dim = hidden_dim\n", "\n", " self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)\n", - " # self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, \n", - " # num_layers=num_layers, direction='bidirectional', dropout=dropout)\n", - " self.mlp = nn.Sequential(('linear_1', nn.Linear(hidden_dim * 2, hidden_dim * 2)),\n", + " \n", + " self.conv1 = nn.Sequential(nn.Conv1D(embedding_dim, 30, 1, padding=0), nn.ReLU())\n", + " self.conv2 = nn.Sequential(nn.Conv1D(embedding_dim, 40, 3, padding=1), nn.ReLU())\n", + " self.conv3 = nn.Sequential(nn.Conv1D(embedding_dim, 50, 5, padding=2), nn.ReLU())\n", + "\n", + " self.mlp = nn.Sequential(('dropout', nn.Dropout(p=dropout)),\n", + " ('linear_1', nn.Linear(120, hidden_dim)),\n", " ('activate', nn.ReLU()),\n", - " ('linear_2', nn.Linear(hidden_dim * 2, output_dim)))\n", + " ('linear_2', nn.Linear(hidden_dim, output_dim)))\n", " \n", - " self.loss_fn = nn.CrossEntropyLoss()\n", + " self.loss_fn = nn.MSELoss()\n", "\n", " def forward(self, words):\n", - " output = self.embedding(words)\n", - " # output, (hidden, cell) = self.lstm(output)\n", - " hidden = 
paddle.randn((2, words.shape[0], self.hidden_dim))\n", - " output = self.mlp(paddle.concat((hidden[-1], hidden[-2]), axis=1))\n", + " output = self.embedding(words).transpose([0, 2, 1])\n", + " conv1, conv2, conv3 = self.conv1(output), self.conv2(output), self.conv3(output)\n", + "\n", + " pool1 = F.max_pool1d(conv1, conv1.shape[-1]).squeeze(2)\n", + " pool2 = F.max_pool1d(conv2, conv2.shape[-1]).squeeze(2)\n", + " pool3 = F.max_pool1d(conv3, conv3.shape[-1]).squeeze(2)\n", + "\n", + " pool = paddle.concat([pool1, pool2, pool3], axis=1)\n", + " output = self.mlp(pool)\n", " return output\n", " \n", " def train_step(self, words, target):\n", " pred = self(words)\n", - " return {\"loss\": self.loss_fn(pred, target)}\n", + " target = paddle.stack((1 - target, target), axis=1).cast(pred.dtype)\n", + " return {'loss': self.loss_fn(pred, target)}\n", "\n", " def evaluate_step(self, words, target):\n", " pred = self(words)\n", - " pred = paddle.max(pred, axis=-1)[1]\n", - " return {\"pred\": pred, \"target\": target}" + " pred = paddle.argmax(pred, axis=-1)\n", + " return {'pred': pred, 'target': target}" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "c63b030f", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "W0604 21:02:25.453869 19014 gpu_context.cc:278] Please NOTE: device: 0, GPU Compute Capability: 6.1, Driver API Version: 11.1, Runtime API Version: 10.2\n", + "W0604 21:02:26.061690 19014 gpu_context.cc:306] device: 0, cuDNN Version: 7.6.\n" + ] + }, + { + "data": { + "text/plain": [ + "ClsByPaddle(\n", + " (embedding): Embedding(8458, 100, sparse=False)\n", + " (conv1): Sequential(\n", + " (0): Conv1D(100, 30, kernel_size=[1], data_format=NCL)\n", + " (1): ReLU()\n", + " )\n", + " (conv2): Sequential(\n", + " (0): Conv1D(100, 40, kernel_size=[3], padding=1, data_format=NCL)\n", + " (1): ReLU()\n", + " )\n", + " (conv3): Sequential(\n", + " (0): Conv1D(100, 50, kernel_size=[5], padding=2, data_format=NCL)\n", + " (1): ReLU()\n", + " )\n", + " (mlp): Sequential(\n", + " (dropout): Dropout(p=0.5, axis=None, mode=upscale_in_train)\n", + " (linear_1): Linear(in_features=120, out_features=64, dtype=float32)\n", + " (activate): ReLU()\n", + " (linear_2): Linear(in_features=64, out_features=2, dtype=float32)\n", + " )\n", + " (loss_fn): MSELoss()\n", + ")" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "model = ClsByPaddle(vocab_size=len(vocab), embedding_dim=100, output_dim=2)\n", "\n", @@ -148,34 +248,34 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "2997c0aa", "metadata": {}, "outputs": [], "source": [ "from paddle.optimizer import AdamW\n", "\n", - "optimizers = AdamW(parameters=model.parameters(), learning_rate=1e-2)" + "optimizers = AdamW(parameters=model.parameters(), learning_rate=5e-4)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "ead35fb8", "metadata": {}, "outputs": [], "source": [ "from fastNLP import prepare_paddle_dataloader\n", "\n", - "# train_dataloader = prepare_paddle_dataloader(train_dataset, batch_size=16, shuffle=True)\n", - "# evaluate_dataloader = prepare_paddle_dataloader(evaluate_dataset, batch_size=16)\n", + "train_dataloader = prepare_paddle_dataloader(train_dataset, batch_size=16, shuffle=True)\n", + "evaluate_dataloader = prepare_paddle_dataloader(evaluate_dataset, batch_size=16)\n", "\n", - "dl_bundle = 
prepare_paddle_dataloader(data_bundle, batch_size=16, shuffle=True)" + "# dl_bundle = prepare_paddle_dataloader(data_bundle, batch_size=16, shuffle=True)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "25e8da83", "metadata": {}, "outputs": [], @@ -185,23 +285,631 @@ "trainer = Trainer(\n", " model=model,\n", " driver='paddle',\n", - " device='gpu', # 'cpu', 'gpu', 'gpu:x'\n", + " device='gpu', # 'cpu', 'gpu', 'gpu:x'\n", " n_epochs=10,\n", " optimizers=optimizers,\n", - " train_dataloader=dl_bundle['train'], # train_dataloader,\n", - " evaluate_dataloaders=dl_bundle['dev'], # evaluate_dataloader,\n", + " train_dataloader=train_dataloader, # dl_bundle['train'],\n", + " evaluate_dataloaders=evaluate_dataloader, # dl_bundle['dev'], \n", " metrics={'acc': Accuracy()}\n", ")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "id": "d63c5d74", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
[21:03:08] INFO     Running evaluator sanity check for 2 batches.              trainer.py:596\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[21:03:08]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=894986;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=567751;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+       "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+       ".get_parent()\n",
+       "  if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n",
+       "
\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+       "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+       ".get_parent()\n",
+       "  self.msg_id = ip.kernel._parent_header['header']['msg_id']\n",
+       "
\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/paddle/tensor/creation.py:\n",
+       "125: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To \n",
+       "silence this warning, use `object` by itself. Doing this will not modify any behavior and is \n",
+       "safe. \n",
+       "Deprecated in NumPy 1.20; for more details and guidance: \n",
+       "https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
+       "  if data.dtype == np.object:\n",
+       "
\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/paddle/tensor/creation.py:\n", + "125: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To \n", + "silence this warning, use `object` by itself. Doing this will not modify any behavior and is \n", + "safe. \n", + "Deprecated in NumPy 1.20; for more details and guidance: \n", + "https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", + " if data.dtype == np.object:\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.78125,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 125.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.78125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m125.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.7875,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 126.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m126.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.8,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 128.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.79375,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 127.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.79375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m127.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.81875,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 131.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.8,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 128.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.80625,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 129.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m129.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.79375,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 127.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.79375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m127.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.7875,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 126.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m126.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.8,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 128.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "trainer.run(num_eval_batch_per_dl=10) # 然后卡了?" + "trainer.run(num_eval_batch_per_dl=10) " ] }, { @@ -214,7 +922,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "id": "c600191d", "metadata": {}, "outputs": [], @@ -231,37 +939,61 @@ " self.hidden_dim = hidden_dim\n", "\n", " self.embedding = nn.Embedding(num=vocab_size, dim=embedding_dim)\n", - " self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, \n", + " self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True, # 默认 batch_first=False\n", " num_layers=num_layers, bidirectional=True, dropout=dropout)\n", - " self.mlp = nn.Sequential([nn.Linear(hidden_dim * 2, hidden_dim * 2),\n", + " self.mlp = nn.Sequential([nn.Dropout(p=dropout),\n", + " nn.Linear(hidden_dim * 2, hidden_dim * 2),\n", " nn.ReLU(),\n", - " nn.Linear(hidden_dim * 2, output_dim)])\n", + " nn.Linear(hidden_dim * 2, output_dim),\n", + " nn.Sigmoid(),])\n", "\n", - " self.loss_fn = nn.BCELoss()\n", + " self.loss_fn = nn.MSELoss()\n", "\n", " def execute(self, words):\n", " output = self.embedding(words)\n", " output, (hidden, cell) = self.lstm(output)\n", - " # hidden = jittor.randn((2, words.shape[0], self.hidden_dim))\n", - " output = self.mlp(jittor.concat((hidden[-1], hidden[-2]), axis=1))\n", + " output = self.mlp(jittor.concat((hidden[-1], hidden[-2]), dim=1))\n", " return output\n", " \n", " def train_step(self, words, target):\n", " pred = self(words)\n", - " return {\"loss\": self.loss_fn(pred, target)}\n", + " target = jittor.stack((1 - target, target), dim=1)\n", + " return {'loss': self.loss_fn(pred, target)}\n", "\n", " def evaluate_step(self, words, target):\n", " pred = self(words)\n", - " pred = jittor.max(pred, axis=-1)[1]\n", - " return {\"pred\": pred, \"target\": target}" + " pred = jittor.argmax(pred, dim=-1)[0]\n", + " return {'pred': pred, 'target': target}" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "id": "a94ed8c4", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "ClsByJittor(\n", + " embedding: Embedding(8458, 100)\n", + " lstm: LSTM(100, 64, 2, bias=True, batch_first=True, dropout=0.5, bidirectional=True, proj_size=0)\n", + " mlp: Sequential(\n", + " 0: Dropout(0.5, is_train=False)\n", + " 1: Linear(128, 128, float32[128,], None)\n", + " 2: relu()\n", + " 3: Linear(128, 2, float32[2,], None)\n", + " 4: Sigmoid()\n", + " )\n", + " loss_fn: MSELoss(mean)\n", + ")" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "model = ClsByJittor(vocab_size=len(vocab), embedding_dim=100, output_dim=2)\n", "\n", @@ -270,34 +1002,34 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "id": "6d15ebc1", "metadata": {}, "outputs": [], "source": [ "from jittor.optim import AdamW\n", "\n", - "optimizers = AdamW(params=model.parameters(), lr=1e-2)" + "optimizers = AdamW(params=model.parameters(), lr=5e-3)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "id": "95d8d09e", "metadata": {}, "outputs": [], "source": [ "from fastNLP import prepare_jittor_dataloader\n", "\n", - "# train_dataloader = prepare_jittor_dataloader(train_dataset, batch_size=16, shuffle=True)\n", - "# evaluate_dataloader = prepare_jittor_dataloader(evaluate_dataset, batch_size=16)\n", + "train_dataloader = 
prepare_jittor_dataloader(train_dataset, batch_size=16, shuffle=True)\n", + "evaluate_dataloader = prepare_jittor_dataloader(evaluate_dataset, batch_size=16)\n", "\n", - "dl_bundle = prepare_jittor_dataloader(data_bundle, batch_size=16, shuffle=True)" + "# dl_bundle = prepare_jittor_dataloader(data_bundle, batch_size=16, shuffle=True)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "id": "917eab81", "metadata": {}, "outputs": [], @@ -307,24 +1039,587 @@ "trainer = Trainer(\n", " model=model,\n", " driver='jittor',\n", - " device='gpu', # 'cpu', 'gpu', 'cuda'\n", + " device='gpu', # 'cpu', 'gpu', 'cuda'\n", " n_epochs=10,\n", " optimizers=optimizers,\n", - " train_dataloader=dl_bundle['train'], # train_dataloader,\n", - " evaluate_dataloaders=dl_bundle['dev'], # evaluate_dataloader,\n", + " train_dataloader=train_dataloader, # dl_bundle['train'],\n", + " evaluate_dataloaders=evaluate_dataloader, # dl_bundle['dev'],\n", " metrics={'acc': Accuracy()}\n", ")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "id": "f7c4ac5a", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
[21:05:51] INFO     Running evaluator sanity check for 2 batches.              trainer.py:596\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[21:05:51]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=69759;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=202322;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "Compiling Operators(5/6) used: 8.31s eta: 1.66s 6/6) used: 9.33s eta:    0s \n",
+      "\n",
+      "Compiling Operators(31/31) used: 7.31s eta:    0s \n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.61875,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 99\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.61875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m99\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.7,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 112\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m112\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.725,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 116\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.725\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m116\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.74375,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 119\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.74375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m119\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.75625,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 121\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.75625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m121\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.75625,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 121\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.75625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m121\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.73125,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 117\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.73125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m117\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.7625,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 122\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m122\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.74375,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 119\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.74375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m119\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.7625,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 122\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m122\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "trainer.run(num_eval_batch_per_dl=10)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3df5f425", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/tutorials/fastnlp_tutorial_7.ipynb b/tutorials/fastnlp_tutorial_7.ipynb deleted file mode 100644 index 0a7d6922..00000000 --- a/tutorials/fastnlp_tutorial_7.ipynb +++ /dev/null @@ -1,59 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "fdd7ff16", - "metadata": {}, - "source": [ - "# T7. callback 自定义训练过程\n", - "\n", - "  1   \n", - " \n", - "    1.1   \n", - "\n", - "    1.2   \n", - "\n", - "  2   \n", - "\n", - "    2.1   \n", - "\n", - "    2.2   \n", - "\n", - "  3   \n", - "\n", - "    3.1   \n", - "\n", - "    3.2   " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "08752c5a", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.13" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/tutorials/fastnlp_tutorial_8.ipynb b/tutorials/fastnlp_tutorial_8.ipynb deleted file mode 100644 index 0664bc41..00000000 --- a/tutorials/fastnlp_tutorial_8.ipynb +++ /dev/null @@ -1,59 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "fdd7ff16", - "metadata": {}, - "source": [ - "# T8. fastNLP 中的文件读取模块\n", - "\n", - "  1   fastNLP 中的 EmbedLoader 模块\n", - " \n", - "    1.1   \n", - "\n", - "    1.2   \n", - "\n", - "  2   fastNLP 中的 Loader 模块\n", - "\n", - "    2.1   \n", - "\n", - "    2.2   \n", - "\n", - "  3   fastNLP 中的 Pipe 模块\n", - "\n", - "    3.1   \n", - "\n", - "    3.2   " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "08752c5a", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.13" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}