@@ -17,7 +17,7 @@ | |||
"\n", | |||
"    2.1   collator 的概念与使用\n", | |||
"\n", | |||
"    2.2   sampler 的概念与使用" | |||
"    2.2   结合 datasets 框架" | |||
] | |||
}, | |||
{ | |||
@@ -71,8 +71,8 @@ | |||
"| `batch_size` | √ | √ | 指定`dataloader`的`batch`大小 | 默认`16` |\n", | |||
"| `shuffle` | √ | √ | 指定`dataloader`的数据是否打乱 | 默认`False` |\n", | |||
"| `collate_fn` | √ | √ | 指定`dataloader`的`batch`打包方法 | 视框架而定 |\n", | |||
"| `sampler` | √ | √ | ? | 默认`None` |\n", | |||
"| `batch_sampler` | √ | √ | ? | 默认`None` |\n", | |||
"| `sampler` | √ | √ | 指定`dataloader`的`__len__`和`__iter__`函数的实现 | 默认`None` |\n", | |||
"| `batch_sampler` | √ | √ | 指定`dataloader`的`__len__`和`__iter__`函数的实现 | 默认`None` |\n", | |||
"| `drop_last` | √ | √ | 指定`dataloader`划分`batch`时是否丢弃剩余的 | 默认`False` |\n", | |||
"| `cur_batch_indices` | | √ | 记录`dataloader`当前遍历批量序号 | |\n", | |||
"| `num_workers` | √ | √ | 指定`dataloader`开启子进程数量 | 默认`0` |\n", | |||
@@ -95,7 +95,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 5, | |||
"execution_count": 1, | |||
"id": "aca72b49", | |||
"metadata": { | |||
"pycharm": { | |||
@@ -103,6 +103,26 @@ | |||
} | |||
}, | |||
"outputs": [ | |||
{ | |||
"name": "stderr", | |||
"output_type": "stream", | |||
"text": [ | |||
"\u001b[38;5;2m[i 0604 15:44:29.773860 92 log.cc:351] Load log_sync: 1\u001b[m\n" | |||
] | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"application/vnd.jupyter.widget-view+json": { | |||
@@ -149,14 +169,14 @@ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"+------------+------------------+-----------+------------------+--------------------+--------------------+\n", | |||
"| SentenceId | Sentence | Sentiment | input_ids | token_type_ids | attention_mask |\n", | |||
"+------------+------------------+-----------+------------------+--------------------+--------------------+\n", | |||
"| 5 | A comedy-dram... | positive | [101, 1037, 4... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n", | |||
"| 2 | This quiet , ... | positive | [101, 2023, 4... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n", | |||
"| 1 | A series of e... | negative | [101, 1037, 2... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n", | |||
"| 6 | The Importanc... | neutral | [101, 1996, 5... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n", | |||
"+------------+------------------+-----------+------------------+--------------------+--------------------+\n" | |||
"+------------+----------------+-----------+----------------+--------------------+--------------------+--------+\n", | |||
"| SentenceId | Sentence | Sentiment | input_ids | token_type_ids | attention_mask | target |\n", | |||
"+------------+----------------+-----------+----------------+--------------------+--------------------+--------+\n", | |||
"| 1 | A series of... | negative | [101, 1037,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 1 |\n", | |||
"| 4 | A positivel... | neutral | [101, 1037,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 2 |\n", | |||
"| 3 | Even fans o... | negative | [101, 2130,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 1 |\n", | |||
"| 5 | A comedy-dr... | positive | [101, 1037,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 0 |\n", | |||
"+------------+----------------+-----------+----------------+--------------------+--------------------+--------+\n" | |||
] | |||
} | |||
], | |||
@@ -200,7 +220,9 @@ | |||
" \n", | |||
"pipe = PipeDemo(tokenizer='bert-base-uncased')\n", | |||
"\n", | |||
"data_bundle = pipe.process_from_file('./data/test4dataset.tsv')" | |||
"data_bundle = pipe.process_from_file('./data/test4dataset.tsv')\n", | |||
"\n", | |||
"print(data_bundle.get_dataset('train'))" | |||
] | |||
}, | |||
{ | |||
@@ -214,15 +236,65 @@ | |||
"\n", | |||
"  例如下方的`prepare_torch_dataloader`函数,指定必要参数,读取数据集,生成对应`dataloader`\n", | |||
"\n", | |||
"  类型为`TorchDataLoader`,只能适用于`pytorch`框架,因此对应`trainer`初始化时`driver='torch'`" | |||
"  类型为`TorchDataLoader`,只能适用于`pytorch`框架,因此对应`trainer`初始化时`driver='torch'`\n", | |||
"\n", | |||
"同时我们还可以发现,在`fastNLP 0.8`中,**`batch`表示为字典`dict`类型**,**`key`值就是原先数据集中各个字段**\n", | |||
"\n", | |||
"  **除去经过`DataBundle.set_ignore`函数隐去的部分**,而`value`值为`pytorch`框架对应的`torch.Tensor`类型" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 7, | |||
"execution_count": 2, | |||
"id": "5fd60e42", | |||
"metadata": {}, | |||
"outputs": [], | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"<class 'fastNLP.core.dataloaders.torch_dataloader.fdl.TorchDataLoader'>\n", | |||
"<class 'dict'> <class 'torch.Tensor'> ['input_ids', 'token_type_ids', 'attention_mask', 'target']\n", | |||
"{'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", | |||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", | |||
" [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", | |||
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", | |||
" [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", | |||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", | |||
" [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", | |||
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),\n", | |||
" 'input_ids': tensor([[ 101, 1037, 4038, 1011, 3689, 1997, 3053, 8680, 19173, 15685,\n", | |||
" 1999, 1037, 18006, 2836, 2011, 1996, 2516, 2839, 14996, 3054,\n", | |||
" 15509, 5325, 1012, 102, 0, 0, 0, 0, 0, 0,\n", | |||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||
" 0, 0, 0, 0],\n", | |||
" [ 101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262,\n", | |||
" 3351, 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036,\n", | |||
" 2204, 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681,\n", | |||
" 2572, 25581, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997,\n", | |||
" 1037, 2466, 1012, 102],\n", | |||
" [ 101, 2130, 4599, 1997, 19214, 6432, 1005, 1055, 2147, 1010,\n", | |||
" 1045, 8343, 1010, 2052, 2031, 1037, 2524, 2051, 3564, 2083,\n", | |||
" 2023, 2028, 1012, 102, 0, 0, 0, 0, 0, 0,\n", | |||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||
" 0, 0, 0, 0],\n", | |||
" [ 101, 1037, 13567, 26162, 5257, 1997, 3802, 7295, 9888, 1998,\n", | |||
" 2035, 1996, 20014, 27611, 1010, 14583, 1010, 11703, 20175, 1998,\n", | |||
" 4028, 1997, 1037, 8101, 2319, 10576, 2030, 1037, 28900, 7815,\n", | |||
" 3850, 1012, 102, 0, 0, 0, 0, 0, 0, 0,\n", | |||
" 0, 0, 0, 0]]),\n", | |||
" 'target': tensor([0, 1, 1, 2]),\n", | |||
" 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", | |||
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", | |||
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", | |||
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}\n" | |||
] | |||
} | |||
], | |||
"source": [ | |||
"from fastNLP import prepare_torch_dataloader\n", | |||
"\n", | |||
@@ -230,28 +302,15 @@ | |||
"evaluate_dataset = data_bundle.get_dataset('dev')\n", | |||
"\n", | |||
"train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n", | |||
"evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"id": "7c53f181", | |||
"metadata": {}, | |||
"source": [ | |||
"```python\n", | |||
"trainer = Trainer(\n", | |||
" model=model,\n", | |||
" train_dataloader=train_dataloader,\n", | |||
" optimizers=optimizer,\n", | |||
"\t...\n", | |||
"\tdriver='torch',\n", | |||
"\tdevice='cuda',\n", | |||
"\t...\n", | |||
" evaluate_dataloaders=evaluate_dataloader, \n", | |||
" metrics={'acc': Accuracy()},\n", | |||
"\t...\n", | |||
")\n", | |||
"```" | |||
"evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)\n", | |||
"\n", | |||
"print(type(train_dataloader))\n", | |||
"\n", | |||
"import pprint\n", | |||
"\n", | |||
"for batch in train_dataloader:\n", | |||
" print(type(batch), type(batch['input_ids']), list(batch))\n", | |||
" pprint.pprint(batch, width=1)" | |||
] | |||
}, | |||
{ | |||
@@ -259,27 +318,33 @@ | |||
"id": "9f457a6e", | |||
"metadata": {}, | |||
"source": [ | |||
"之所以称`prepare_xx_dataloader`函数更方便,是因为其**导入对象不仅可也是`DataSet`类型**,**还可以**\n", | |||
"之所以说`prepare_xx_dataloader`函数更方便,是因为其**导入对象不仅可以是`DataSet`类型**,**还可以**\n", | |||
"\n", | |||
"  **是`DataBundle`类型**,不过数据集名称需要是`'train'`、`'dev'`、`'test'`供`fastNLP`识别\n", | |||
"\n", | |||
"  例如下方就是**直接通过`prepare_paddle_dataloader`函数生成基于`PaddleDataLoader`的字典**\n", | |||
"\n", | |||
"  在接下来`trainer`的初始化过程中,按如下方式使用即可,除了初始化时`driver='paddle'`外\n", | |||
"\n", | |||
"    这里也可以看出 **`evaluate_dataloaders`的妙处**,一次评测可以针对多个数据集" | |||
"例如下方就是**直接通过`prepare_paddle_dataloader`函数生成基于`PaddleDataLoader`的字典**\n" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 6, | |||
"execution_count": 3, | |||
"id": "7827557d", | |||
"metadata": {}, | |||
"outputs": [], | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"<class 'fastNLP.core.dataloaders.paddle_dataloader.fdl.PaddleDataLoader'>\n" | |||
] | |||
} | |||
], | |||
"source": [ | |||
"from fastNLP import prepare_paddle_dataloader\n", | |||
"\n", | |||
"dl_bundle = prepare_paddle_dataloader(data_bundle, batch_size=16, shuffle=True)" | |||
"dl_bundle = prepare_paddle_dataloader(data_bundle, batch_size=16, shuffle=True)\n", | |||
"\n", | |||
"print(type(dl_bundle['train']))" | |||
] | |||
}, | |||
{ | |||
@@ -287,6 +352,10 @@ | |||
"id": "d898cf40", | |||
"metadata": {}, | |||
"source": [ | |||
"  而在接下来`trainer`的初始化过程中,按如下方式使用即可,除了初始化时`driver='paddle'`外\n", | |||
"\n", | |||
"  这里也可以看出`trainer`模块中,**`evaluate_dataloaders`的设计允许评测可以针对多个数据集**\n", | |||
"\n", | |||
"```python\n", | |||
"trainer = Trainer(\n", | |||
" model=model,\n", | |||
@@ -312,31 +381,45 @@ | |||
"\n", | |||
"### 2.1 collator 的概念与使用\n", | |||
"\n", | |||
"在`fastNLP 0.8`中,在数据加载模块`DataLoader`之前,还存在其他的一些模块,负责例如对文本数据\n", | |||
"在`fastNLP 0.8`中,在数据加载模块`dataloader`内部,如之前表格所列举的,还存在其他的一些模块\n", | |||
"\n", | |||
"  进行补零对齐,即 **核对器`collator`模块**,进行分词标注,即 **分词器`tokenizer`模块**\n", | |||
"  例如,**实现序列的补零对齐的核对器`collator`模块**;注:`collate vt. 整理(文件或书等);核对,校勘`\n", | |||
"\n", | |||
"  本节将对`fastNLP`中的核对器`collator`等展开介绍,分词器`tokenizer`将在下一节中详细介绍\n", | |||
"在`fastNLP 0.8`中,虽然`dataloader`随框架不同,但`collator`模块却是统一的,主要属性、方法如下表所示\n", | |||
"\n", | |||
"在`fastNLP 0.8`中,**核对器`collator`模块负责文本序列的补零对齐**,通过" | |||
"| <div align=\"center\">名称</div> | <div align=\"center\">属性</div> | <div align=\"center\">方法</div> | <div align=\"center\">功能</div> | <div align=\"center\">内容</div> |\n", | |||
"|:--|:--:|:--:|:--|:--|\n", | |||
"| `backend` | √ | | 记录`collator`对应框架 | 字符串型,如`'torch'` |\n", | |||
"| `padders` | √ | | 记录各字段对应的`padder`,每个负责具体补零对齐  | 字典类型 |\n", | |||
"| `ignore_fields` | √ | | 记录`dataloader`采样`batch`时不予考虑的字段 | 集合类型 |\n", | |||
"| `input_fields` | √ | | 记录`collator`每个字段的补零值、数据类型等 | 字典类型 |\n", | |||
"| `set_backend` | | √ | 设置`collator`对应框架 | 字符串型,如`'torch'` |\n", | |||
"| `set_ignore` | | √ | 设置`dataloader`采样`batch`时不予考虑的字段 | 字符串型,表示`field_name`  |\n", | |||
"| `set_pad` | | √ | 设置`collator`每个字段的补零值、数据类型等 | |" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"id": "651baef6", | |||
"execution_count": 4, | |||
"id": "d0795b3e", | |||
"metadata": { | |||
"pycharm": { | |||
"name": "#%%\n" | |||
} | |||
}, | |||
"outputs": [], | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"<class 'function'>\n" | |||
] | |||
} | |||
], | |||
"source": [ | |||
"from fastNLP import prepare_torch_dataloader\n", | |||
"train_dataloader.collate_fn\n", | |||
"\n", | |||
"dl_bundle = prepare_torch_dataloader(data_bundle, train_batch_size=2)\n", | |||
"\n", | |||
"print(type(dl_bundle), type(dl_bundle['train']))" | |||
"print(type(train_dataloader.collate_fn))" | |||
] | |||
}, | |||
{ | |||
@@ -344,80 +427,165 @@ | |||
"id": "5f816ef5", | |||
"metadata": {}, | |||
"source": [ | |||
"  " | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"id": "726ba357", | |||
"metadata": { | |||
"pycharm": { | |||
"name": "#%%\n" | |||
} | |||
}, | |||
"outputs": [], | |||
"source": [ | |||
"dataloader = prepare_torch_dataloader(datasets['train'], train_batch_size=2)\n", | |||
"print(type(dataloader))\n", | |||
"print(dir(dataloader))" | |||
"此外,还可以**手动定义`dataloader`中的`collate_fn`**,而不是使用`fastNLP 0.8`中自带的`collator`模块\n", | |||
"\n", | |||
"  该函数的定义可以大致如下,需要注意的是,**定义`collate_fn`之前需要了解`batch`作为字典的格式**\n", | |||
"\n", | |||
"  该函数通过`collate_fn`参数传入`dataloader`,**在`batch`分发**(**而不是`batch`划分**)**时调用**" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"id": "d0795b3e", | |||
"metadata": { | |||
"pycharm": { | |||
"name": "#%%\n" | |||
} | |||
}, | |||
"execution_count": 5, | |||
"id": "ff8e405e", | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"dataloader.collate_fn" | |||
"import torch\n", | |||
"\n", | |||
"def collate_fn(batch):\n", | |||
" input_ids, atten_mask, labels = [], [], []\n", | |||
" max_length = [0] * 3\n", | |||
" for each_item in batch:\n", | |||
" input_ids.append(each_item['input_ids'])\n", | |||
" max_length[0] = max(len(each_item['input_ids']), max_length[0])\n", | |||
" atten_mask.append(each_item['token_type_ids'])\n", | |||
" max_length[1] = max(len(each_item['token_type_ids']), max_length[1])\n", | |||
" labels.append(each_item['attention_mask'])\n", | |||
" max_length[2] = max(len(each_item['attention_mask']), max_length[2])\n", | |||
"\n", | |||
" for i in range(3):\n", | |||
" each = (input_ids, atten_mask, labels)[i]\n", | |||
" for item in each:\n", | |||
" item.extend([0] * (max_length[i] - len(item)))\n", | |||
" return {'input_ids': torch.cat([torch.tensor([item]) for item in input_ids], dim=0),\n", | |||
" 'token_type_ids': torch.cat([torch.tensor([item]) for item in atten_mask], dim=0),\n", | |||
" 'attention_mask': torch.cat([torch.tensor(item) for item in labels], dim=0)}" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"id": "f9bbd9a7", | |||
"id": "487b75fb", | |||
"metadata": {}, | |||
"source": [ | |||
"### 2.2 sampler 的概念与使用" | |||
"注意:使用自定义的`collate_fn`函数,`trainer`的`collate_fn`变量也会自动调整为`function`类型" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"id": "b0c3c58d", | |||
"metadata": { | |||
"pycharm": { | |||
"name": "#%%\n" | |||
"execution_count": 6, | |||
"id": "e916d1ac", | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"<class 'fastNLP.core.dataloaders.torch_dataloader.fdl.TorchDataLoader'>\n", | |||
"<class 'function'>\n", | |||
"{'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", | |||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,\n", | |||
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", | |||
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", | |||
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", | |||
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,\n", | |||
" 0, 0, 0, 0, 0, 0, 0, 0]),\n", | |||
" 'input_ids': tensor([[ 101, 1037, 4038, 1011, 3689, 1997, 3053, 8680, 19173, 15685,\n", | |||
" 1999, 1037, 18006, 2836, 2011, 1996, 2516, 2839, 14996, 3054,\n", | |||
" 15509, 5325, 1012, 102, 0, 0, 0, 0, 0, 0,\n", | |||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||
" 0, 0, 0, 0],\n", | |||
" [ 101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262,\n", | |||
" 3351, 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036,\n", | |||
" 2204, 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681,\n", | |||
" 2572, 25581, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997,\n", | |||
" 1037, 2466, 1012, 102],\n", | |||
" [ 101, 2130, 4599, 1997, 19214, 6432, 1005, 1055, 2147, 1010,\n", | |||
" 1045, 8343, 1010, 2052, 2031, 1037, 2524, 2051, 3564, 2083,\n", | |||
" 2023, 2028, 1012, 102, 0, 0, 0, 0, 0, 0,\n", | |||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||
" 0, 0, 0, 0],\n", | |||
" [ 101, 1037, 13567, 26162, 5257, 1997, 3802, 7295, 9888, 1998,\n", | |||
" 2035, 1996, 20014, 27611, 1010, 14583, 1010, 11703, 20175, 1998,\n", | |||
" 4028, 1997, 1037, 8101, 2319, 10576, 2030, 1037, 28900, 7815,\n", | |||
" 3850, 1012, 102, 0, 0, 0, 0, 0, 0, 0,\n", | |||
" 0, 0, 0, 0]]),\n", | |||
" 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", | |||
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", | |||
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", | |||
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", | |||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}\n" | |||
] | |||
} | |||
}, | |||
"outputs": [], | |||
], | |||
"source": [ | |||
"dataloader.batch_sampler" | |||
"train_dataloader = prepare_torch_dataloader(train_dataset, collate_fn=collate_fn, shuffle=True)\n", | |||
"evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, collate_fn=collate_fn, shuffle=True)\n", | |||
"\n", | |||
"print(type(train_dataloader))\n", | |||
"print(type(train_dataloader.collate_fn))\n", | |||
"\n", | |||
"for batch in train_dataloader:\n", | |||
" pprint.pprint(batch, width=1)" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"id": "51bf0878", | |||
"id": "0bd98365", | |||
"metadata": {}, | |||
"source": [ | |||
"  " | |||
"### 2.2 fastNLP 与 datasets 的结合\n", | |||
"\n", | |||
"从`tutorial-1`至`tutorial-3`,我们已经完成了对`fastNLP v0.8`数据读取、预处理、加载,整个流程的介绍\n", | |||
"\n", | |||
"  不过在实际使用中,我们往往也会采取更为简便的方法读取数据,例如使用`huggingface`的`datasets`模块\n", | |||
"\n", | |||
"**使用`datasets`模块中的`load_dataset`函数**,通过指定数据集两级的名称,示例中即是**`GLUE`标准中的`SST-2`数据集**\n", | |||
"\n", | |||
"  即可以快速从网上下载好`SST-2`数据集读入,之后以`pandas.DataFrame`作为中介,再转化成`fastNLP.DataSet`\n", | |||
"\n", | |||
"  之后的步骤就和其他关于`dataset`、`databundle`、`vocabulary`、`dataloader`中介绍的相关使用相同了" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"id": "3fd2486f", | |||
"metadata": { | |||
"pycharm": { | |||
"name": "#%%\n" | |||
"execution_count": 7, | |||
"id": "91879c30", | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stderr", | |||
"output_type": "stream", | |||
"text": [ | |||
"Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" | |||
] | |||
}, | |||
{ | |||
"data": { | |||
"application/vnd.jupyter.widget-view+json": { | |||
"model_id": "639a0ad3c63944c6abef4e8ee1f7bf7c", | |||
"version_major": 2, | |||
"version_minor": 0 | |||
}, | |||
"text/plain": [ | |||
" 0%| | 0/3 [00:00<?, ?it/s]" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
} | |||
}, | |||
"outputs": [], | |||
"source": [] | |||
], | |||
"source": [ | |||
"from datasets import load_dataset\n", | |||
"\n", | |||
"sst2data = load_dataset('glue', 'sst2')\n", | |||
"\n", | |||
"dataset = DataSet.from_pandas(sst2data['train'].to_pandas())" | |||
] | |||
} | |||
], | |||
"metadata": { | |||
@@ -296,7 +296,7 @@ | |||
"\n", | |||
"    在`fastNLP v0.8`中,`update`函数的默认输入参数:`pred`,对应预测值;`target`,对应真实值\n", | |||
"\n", | |||
"    此处刻意调整为:`pred`,对应预测值,和模型输出一致;`true`,对应真实值,数据集字段需要调整\n", | |||
"    此处仍然沿用,因为接下来会需要使用`fastNLP`的预定义模型,其输入参数格式即是如此\n", | |||
"\n", | |||
"  在`get_metric`函数中,根据需求定义评测指标最终的计算,此处直接计算准确率,该函数必须返回一个字典\n", | |||
"\n", | |||
@@ -307,10 +307,24 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"execution_count": 1, | |||
"id": "08a872e9", | |||
"metadata": {}, | |||
"outputs": [], | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
} | |||
], | |||
"source": [ | |||
"import sys\n", | |||
"sys.path.append('..')\n", | |||
@@ -320,16 +334,16 @@ | |||
"class MyMetric(Metric):\n", | |||
"\n", | |||
" def __init__(self):\n", | |||
" MyMetric.__init__(self)\n", | |||
" Metric.__init__(self)\n", | |||
" self.total_num = 0\n", | |||
" self.right_num = 0\n", | |||
"\n", | |||
" def update(self, pred, true):\n", | |||
" def update(self, pred, target):\n", | |||
" self.total_num += target.size(0)\n", | |||
" self.right_num += target.eq(pred).sum().item()\n", | |||
"\n", | |||
" def get_metric(self, reset=True):\n", | |||
" acc = self.acc_count / self.total_num\n", | |||
" acc = self.right_num / self.total_num\n", | |||
" if reset:\n", | |||
" self.total_num = 0\n", | |||
" self.right_num = 0\n", | |||
@@ -346,14 +360,36 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"execution_count": 2, | |||
"id": "5ad81ac7", | |||
"metadata": { | |||
"pycharm": { | |||
"name": "#%%\n" | |||
} | |||
}, | |||
"outputs": [], | |||
"outputs": [ | |||
{ | |||
"name": "stderr", | |||
"output_type": "stream", | |||
"text": [ | |||
"Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" | |||
] | |||
}, | |||
{ | |||
"data": { | |||
"application/vnd.jupyter.widget-view+json": { | |||
"model_id": "ef923b90b19847f4916cccda5d33fc36", | |||
"version_major": 2, | |||
"version_minor": 0 | |||
}, | |||
"text/plain": [ | |||
" 0%| | 0/3 [00:00<?, ?it/s]" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
} | |||
], | |||
"source": [ | |||
"from datasets import load_dataset\n", | |||
"\n", | |||
@@ -365,30 +401,43 @@ | |||
"id": "e9d81760", | |||
"metadata": {}, | |||
"source": [ | |||
"  在数据预处理中,需要注意的是,由于`MyMetric`的`update`函数中,输入参数名称为`pred`和`true`\n", | |||
"  在数据预处理中,需要注意的是,这里原本应该根据`metric`和`model`的输入参数格式,调整\n", | |||
"\n", | |||
"    对应地,需要将数据集中表示预测目标的字段,调整为`true`(预定义的`metric`,应调整为`target`" | |||
"    数据集中表示预测目标的字段,调整为`target`,在后文中会揭晓为什么,以及如何补救" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"execution_count": 3, | |||
"id": "cfb28b1b", | |||
"metadata": { | |||
"pycharm": { | |||
"name": "#%%\n" | |||
} | |||
}, | |||
"outputs": [], | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"application/vnd.jupyter.widget-view+json": { | |||
"model_id": "", | |||
"version_major": 2, | |||
"version_minor": 0 | |||
}, | |||
"text/plain": [ | |||
"Processing: 0%| | 0/6000 [00:00<?, ?it/s]" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
} | |||
], | |||
"source": [ | |||
"from fastNLP import DataSet\n", | |||
"\n", | |||
"dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n", | |||
"\n", | |||
"dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'true': ins['label']}, \n", | |||
" progress_bar=\"tqdm\")\n", | |||
"dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split()}, progress_bar=\"tqdm\")\n", | |||
"dataset.delete_field('sentence')\n", | |||
"dataset.delete_field('label')\n", | |||
"dataset.delete_field('idx')\n", | |||
"\n", | |||
"from fastNLP import Vocabulary\n", | |||
@@ -415,7 +464,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"execution_count": 4, | |||
"id": "2fd210c5", | |||
"metadata": {}, | |||
"outputs": [], | |||
@@ -445,10 +494,10 @@ | |||
"| <div align=\"center\">名称</div> | <div align=\"center\">参数</div> | <div align=\"center\">属性</div> | <div align=\"center\">功能</div> | <div align=\"center\">内容</div> |\n", | |||
"|:--|:--:|:--:|:--|:--|\n", | |||
"| **`model`** | √ | √ | 指定`trainer`控制的模型 | 视框架而定,如`torch.nn.Module` |\n", | |||
"| **`driver`** | √ | | 指定`trainer`驱动的框架 | 包括`'torch'`、`'paddle'`、`'jittor'` |\n", | |||
"| | | √ | 记录`trainer`驱动的框架 | `Driver`类型,在初始化阶段生成 |\n", | |||
"| `device` | √ | | 指定`trainer`运行的卡位 | 例如`'cpu'`、`'cuda'`、`0`、`[0, 1]`等 |\n", | |||
"| | | √ | 记录`trainer`运行的卡位 | `Device`类型,在初始化阶段生成 |\n", | |||
"| **`driver`** | √ | | 指定`trainer`驱动的框架 | 包括`'torch'`、`'paddle'`、`'jittor'` |\n", | |||
"| | | √ | 记录`trainer`驱动的框架 | `Driver`类型,在初始化阶段生成 |\n", | |||
"| `n_epochs` | √ | - | 指定`trainer`迭代的轮数 | 默认`20`,记录在`driver.n_epochs`中 |\n", | |||
"| **`optimizers`** | √ | √ | 指定`trainer`优化的方法 | 视框架而定,如`torch.optim.Adam` |\n", | |||
"| `metrics` | √ | √ | 指定`trainer`评测的方法 | 字典类型,如`{'acc': Metric()}` |\n", | |||
@@ -473,12 +522,34 @@ | |||
"| `fp16` | √ | √ | 指定`trainer`是否进行混合精度训练 | 布尔类型,默认`False` |" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"id": "9e13ee08", | |||
"metadata": {}, | |||
"source": [ | |||
"其中,**`input_mapping`和`output_mapping`** 定义形式如下:输入字典形式的数据,根据参数匹配要求\n", | |||
"\n", | |||
"  调整数据格式,这里就回应了前文未在数据集预处理时调整格式的问题,**总之一定要满足参数匹配的要求**" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 5, | |||
"id": "de96c1d1", | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [ | |||
"def input_mapping(data):\n", | |||
" data['target'] = data['label']\n", | |||
" return data" | |||
] | |||
}, | |||
{ | |||
"cell_type": "markdown", | |||
"id": "2fc8b9f3", | |||
"metadata": {}, | |||
"source": [ | |||
"  以及`trainer`模块内部的基础方法,相关进阶操作,如“`on`系列函数”、`callback`控制,请参考后续的`tutorial-7`\n", | |||
"  而`trainer`模块的基础方法列表如下,相关进阶操作,如“`on`系列函数”、`callback`控制,请参考后续的`tutorial-7`\n", | |||
"\n", | |||
"| <div align=\"center\">名称</div> |<div align=\"center\">功能</div> | <div align=\"center\">主要参数</div> |\n", | |||
"|:--|:--|:--|\n", | |||
@@ -539,7 +610,7 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"execution_count": 6, | |||
"id": "926a9c50", | |||
"metadata": {}, | |||
"outputs": [], | |||
@@ -552,6 +623,7 @@ | |||
" device=0, # 'cuda'\n", | |||
" n_epochs=10,\n", | |||
" optimizers=optimizers,\n", | |||
" input_mapping=input_mapping,\n", | |||
" train_dataloader=train_dataloader,\n", | |||
" evaluate_dataloaders=evaluate_dataloader,\n", | |||
" metrics={'suffix': MyMetric()}\n", | |||
@@ -580,14 +652,557 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"execution_count": 7, | |||
"id": "43be274f", | |||
"metadata": { | |||
"pycharm": { | |||
"name": "#%%\n" | |||
} | |||
}, | |||
"outputs": [], | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[09:30:35] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Running evaluator sanity check for <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> batches. <a href=\"file://../fastNLP/core/controllers/trainer.py\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">trainer.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../fastNLP/core/controllers/trainer.py#596\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">596</span></a>\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\u001b[2;36m[09:30:35]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=954293;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=366534;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"application/vnd.jupyter.widget-view+json": { | |||
"model_id": "", | |||
"version_major": 2, | |||
"version_minor": 0 | |||
}, | |||
"text/plain": [ | |||
"Output()" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", | |||
"output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", | |||
".get_parent()\n", | |||
" if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", | |||
"output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", | |||
".get_parent()\n", | |||
" if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", | |||
"output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", | |||
".get_parent()\n", | |||
" self.msg_id = ip.kernel._parent_header['header']['msg_id']\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", | |||
"output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", | |||
".get_parent()\n", | |||
" self.msg_id = ip.kernel._parent_header['header']['msg_id']\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |||
], | |||
"text/plain": [] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"application/vnd.jupyter.widget-view+json": { | |||
"model_id": "", | |||
"version_major": 2, | |||
"version_minor": 0 | |||
}, | |||
"text/plain": [ | |||
"Output()" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.6875</span>\n", | |||
"<span style=\"font-weight: bold\">}</span>\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\u001b[1m{\u001b[0m\n", | |||
" \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.6875\u001b[0m\n", | |||
"\u001b[1m}\u001b[0m\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.8125</span>\n", | |||
"<span style=\"font-weight: bold\">}</span>\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\u001b[1m{\u001b[0m\n", | |||
" \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n", | |||
"\u001b[1m}\u001b[0m\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.80625</span>\n", | |||
"<span style=\"font-weight: bold\">}</span>\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\u001b[1m{\u001b[0m\n", | |||
" \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", | |||
"\u001b[1m}\u001b[0m\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.825</span>\n", | |||
"<span style=\"font-weight: bold\">}</span>\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\u001b[1m{\u001b[0m\n", | |||
" \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.825\u001b[0m\n", | |||
"\u001b[1m}\u001b[0m\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.8125</span>\n", | |||
"<span style=\"font-weight: bold\">}</span>\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\u001b[1m{\u001b[0m\n", | |||
" \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n", | |||
"\u001b[1m}\u001b[0m\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">6</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.80625</span>\n", | |||
"<span style=\"font-weight: bold\">}</span>\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\u001b[1m{\u001b[0m\n", | |||
" \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", | |||
"\u001b[1m}\u001b[0m\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">7</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.80625</span>\n", | |||
"<span style=\"font-weight: bold\">}</span>\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\u001b[1m{\u001b[0m\n", | |||
" \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", | |||
"\u001b[1m}\u001b[0m\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">8</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.8</span>\n", | |||
"<span style=\"font-weight: bold\">}</span>\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\u001b[1m{\u001b[0m\n", | |||
" \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8\u001b[0m\n", | |||
"\u001b[1m}\u001b[0m\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">----------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">9</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.80625</span>\n", | |||
"<span style=\"font-weight: bold\">}</span>\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\u001b[1m{\u001b[0m\n", | |||
" \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", | |||
"\u001b[1m}\u001b[0m\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">---------------------------- Eval. results on Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span>, Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> -----------------------------\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n", | |||
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"prefix#suffix\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.80625</span>\n", | |||
"<span style=\"font-weight: bold\">}</span>\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\u001b[1m{\u001b[0m\n", | |||
" \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", | |||
"\u001b[1m}\u001b[0m\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |||
], | |||
"text/plain": [] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
}, | |||
{ | |||
"data": { | |||
"text/html": [ | |||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n", | |||
"</pre>\n" | |||
], | |||
"text/plain": [ | |||
"\n" | |||
] | |||
}, | |||
"metadata": {}, | |||
"output_type": "display_data" | |||
} | |||
], | |||
"source": [ | |||
"trainer.run(num_eval_batch_per_dl=10)" | |||
] | |||
@@ -1,59 +0,0 @@ | |||
{ | |||
"cells": [ | |||
{ | |||
"cell_type": "markdown", | |||
"id": "fdd7ff16", | |||
"metadata": {}, | |||
"source": [ | |||
"# T7. callback 自定义训练过程\n", | |||
"\n", | |||
"  1   \n", | |||
" \n", | |||
"    1.1   \n", | |||
"\n", | |||
"    1.2   \n", | |||
"\n", | |||
"  2   \n", | |||
"\n", | |||
"    2.1   \n", | |||
"\n", | |||
"    2.2   \n", | |||
"\n", | |||
"  3   \n", | |||
"\n", | |||
"    3.1   \n", | |||
"\n", | |||
"    3.2   " | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"id": "08752c5a", | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [] | |||
} | |||
], | |||
"metadata": { | |||
"kernelspec": { | |||
"display_name": "Python 3 (ipykernel)", | |||
"language": "python", | |||
"name": "python3" | |||
}, | |||
"language_info": { | |||
"codemirror_mode": { | |||
"name": "ipython", | |||
"version": 3 | |||
}, | |||
"file_extension": ".py", | |||
"mimetype": "text/x-python", | |||
"name": "python", | |||
"nbconvert_exporter": "python", | |||
"pygments_lexer": "ipython3", | |||
"version": "3.7.13" | |||
} | |||
}, | |||
"nbformat": 4, | |||
"nbformat_minor": 5 | |||
} |
@@ -1,59 +0,0 @@ | |||
{ | |||
"cells": [ | |||
{ | |||
"cell_type": "markdown", | |||
"id": "fdd7ff16", | |||
"metadata": {}, | |||
"source": [ | |||
"# T8. fastNLP 中的文件读取模块\n", | |||
"\n", | |||
"  1   fastNLP 中的 EmbedLoader 模块\n", | |||
" \n", | |||
"    1.1   \n", | |||
"\n", | |||
"    1.2   \n", | |||
"\n", | |||
"  2   fastNLP 中的 Loader 模块\n", | |||
"\n", | |||
"    2.1   \n", | |||
"\n", | |||
"    2.2   \n", | |||
"\n", | |||
"  3   fastNLP 中的 Pipe 模块\n", | |||
"\n", | |||
"    3.1   \n", | |||
"\n", | |||
"    3.2   " | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"id": "08752c5a", | |||
"metadata": {}, | |||
"outputs": [], | |||
"source": [] | |||
} | |||
], | |||
"metadata": { | |||
"kernelspec": { | |||
"display_name": "Python 3 (ipykernel)", | |||
"language": "python", | |||
"name": "python3" | |||
}, | |||
"language_info": { | |||
"codemirror_mode": { | |||
"name": "ipython", | |||
"version": 3 | |||
}, | |||
"file_extension": ".py", | |||
"mimetype": "text/x-python", | |||
"name": "python", | |||
"nbconvert_exporter": "python", | |||
"pygments_lexer": "ipython3", | |||
"version": "3.7.13" | |||
} | |||
}, | |||
"nbformat": 4, | |||
"nbformat_minor": 5 | |||
} |