
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
MorningForest · 2 years ago · commit 637919e45d
6 changed files with 677 additions and 1404 deletions
  1. +1    -1     fastNLP/core/controllers/trainer.py
  2. +5    -5     tutorials/fastnlp_tutorial_2.ipynb
  3. +193  -33    tutorials/fastnlp_tutorial_3.ipynb
  4. +32   -1291  tutorials/fastnlp_tutorial_4.ipynb
  5. +149  -69    tutorials/fastnlp_tutorial_5.ipynb
  6. +297  -5     tutorials/fastnlp_tutorial_6.ipynb

+1 -1   fastNLP/core/controllers/trainer.py

@@ -448,7 +448,7 @@ class Trainer(TrainerEventTrigger):
 # Initialize state, including the interface exposed to users and the one we use internally;
 self.state = State()
 self.trainer_state = TrainerState(
-    n_epochs=n_epochs if n_batches!=-1 else None,
+    n_epochs=n_epochs if n_batches==-1 else None,
     cur_epoch_idx=0,
     global_forward_batches=0,
     batch_idx_in_epoch=0,
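Aside: reading the hunk, `n_batches == -1` means training is epoch-bounded, so `n_epochs` is meaningful; otherwise a batch budget governs and the epoch count is undefined, which the old `!=` condition got exactly backwards. A minimal sketch of the corrected semantics (the dataclass below is a hypothetical stand-in for fastNLP's `TrainerState`, not its real definition):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class TrainerStateSketch:
    # Hypothetical stand-in for fastNLP's TrainerState, for illustration only.
    n_epochs: Optional[int]
    cur_epoch_idx: int = 0
    global_forward_batches: int = 0
    batch_idx_in_epoch: int = 0

def make_state(n_epochs: int, n_batches: int) -> TrainerStateSketch:
    # n_batches == -1 means "train by epochs", so n_epochs is meaningful;
    # otherwise training is bounded by a batch count and n_epochs is unknown.
    return TrainerStateSketch(n_epochs=n_epochs if n_batches == -1 else None)

assert make_state(n_epochs=20, n_batches=-1).n_epochs == 20
assert make_state(n_epochs=20, n_batches=1000).n_epochs is None
```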


+5 -5   tutorials/fastnlp_tutorial_2.ipynb

@@ -801,24 +801,24 @@
"tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n", "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
"\n", "\n",
"# 接着,导入数据,先生成为 dataset 形式,再变成 dataset-dict,并转为 databundle 形式\n", "# 接着,导入数据,先生成为 dataset 形式,再变成 dataset-dict,并转为 databundle 形式\n",
"datasets = DataSet.from_pandas(pd.read_csv('./data/test4dataset.tsv'))\n",
"datasets = DataSet.from_pandas(pd.read_csv('./data/test4dataset.tsv', sep='\\t'))\n",
"train_ds, test_ds = datasets.split(ratio=0.7)\n", "train_ds, test_ds = datasets.split(ratio=0.7)\n",
"data_bundle = DataBundle(datasets={'train': train_ds, 'test': test_ds})\n", "data_bundle = DataBundle(datasets={'train': train_ds, 'test': test_ds})\n",
"\n", "\n",
"# 然后,通过 tokenizer.encode_plus 函数,进行文本分词标注、修改并补充数据包内容\n", "# 然后,通过 tokenizer.encode_plus 函数,进行文本分词标注、修改并补充数据包内容\n",
"encode = partial(tokenizer.encode_plus, max_length=100, truncation=True,\n", "encode = partial(tokenizer.encode_plus, max_length=100, truncation=True,\n",
" return_attention_mask=True)\n", " return_attention_mask=True)\n",
"data_bundle.apply_field_more(encode, field_name='text', progress_bar='tqdm')\n",
"data_bundle.apply_field_more(encode, field_name='Sentence', progress_bar='tqdm')\n",
"\n", "\n",
"# 在修改好 'text' 字段的文本信息后,接着处理 'label' 字段的预测信息\n", "# 在修改好 'text' 字段的文本信息后,接着处理 'label' 字段的预测信息\n",
"target_vocab = Vocabulary(padding=None, unknown=None)\n", "target_vocab = Vocabulary(padding=None, unknown=None)\n",
"target_vocab.from_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='label')\n",
"target_vocab.index_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='label',\n",
"target_vocab.from_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment')\n",
"target_vocab.index_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment',\n",
" new_field_name='target')\n", " new_field_name='target')\n",
"\n", "\n",
"# 最后,通过 data_bundle 的其他一些函数,完成善后内容\n", "# 最后,通过 data_bundle 的其他一些函数,完成善后内容\n",
"data_bundle.set_pad('input_ids', pad_val=tokenizer.pad_token_id)\n", "data_bundle.set_pad('input_ids', pad_val=tokenizer.pad_token_id)\n",
"data_bundle.set_ignore('label', 'text') \n",
"data_bundle.set_ignore('SentenceId', 'Sentiment', 'Sentence') \n",
"```" "```"
] ]
}, },
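Aside: the `sep='\t'` fix above matters because `pandas.read_csv` defaults to comma separation, so a tab-separated file collapses into a single column before `DataSet.from_pandas` ever sees it. A self-contained check (the toy rows are invented for illustration):

```python
import io
import pandas as pd

tsv = "SentenceId\tSentence\tSentiment\n1\tA series of escapades\tnegative\n"

bad = pd.read_csv(io.StringIO(tsv))             # default sep=',' -> one mangled column
good = pd.read_csv(io.StringIO(tsv), sep='\t')  # three proper columns

print(bad.shape)           # (1, 1)
print(good.shape)          # (1, 3)
print(list(good.columns))  # ['SentenceId', 'Sentence', 'Sentiment']
```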


+193 -33   tutorials/fastnlp_tutorial_3.ipynb

@@ -9,9 +9,9 @@
"\n", "\n",
"  1   fastNLP 中的 dataloader\n", "  1   fastNLP 中的 dataloader\n",
" \n", " \n",
"    1.1   dataloader 的职责描述\n",
"    1.1   dataloader 的基本介绍\n",
"\n", "\n",
"    1.2   dataloader 的基本使用\n",
"    1.2   dataloader 的函数创建\n",
"\n", "\n",
"  2   fastNLP 中 dataloader 的延伸\n", "  2   fastNLP 中 dataloader 的延伸\n",
"\n", "\n",
@@ -27,32 +27,143 @@
"source": [ "source": [
"## 1. fastNLP 中的 dataloader\n", "## 1. fastNLP 中的 dataloader\n",
"\n", "\n",
"### 1.1 dataloader 的职责描述\n",
"### 1.1 dataloader 的基本介绍\n",
"\n", "\n",
"在`fastNLP 0.8`中,在数据加载模块`DataLoader`之前"
"在`fastNLP 0.8`的开发中,最关键的开发目标就是**实现`fastNLP`对当前主流机器学习框架**,例如\n",
"\n",
"  **较为火热的`pytorch`**,以及**国产的`paddle`和`jittor`的兼容**,扩大受众的同时,也是助力国产\n",
"\n",
"本着分而治之的思想,我们可以将`fastNLP 0.8`对`pytorch`、`paddle`、`jittor`框架的兼容,划分为\n",
"\n",
"    **对数据预处理**、**批量`batch`的划分与补齐**、**模型训练**、**模型评测**,**四个部分的兼容**\n",
"\n",
"  针对数据预处理,我们已经在`tutorial-1`中介绍了`dataset`和`vocabulary`的使用\n",
"\n",
"    而结合`tutorial-0`,我们可以发现**数据预处理环节本质上是框架无关的**\n",
"\n",
"    因为在不同框架下,读取的原始数据格式都差异不大,彼此也很容易转换\n",
"\n",
"只有涉及到张量、模型,不同框架才展现出其各自的特色:**`pytorch`中的`tensor`和`nn.Module`**\n",
"\n",
"    **在`paddle`中称为`tensor`和`nn.Layer`**,**在`jittor`中则称为`Var`和`Module`**\n",
"\n",
"    因此,**模型训练、模型评测**,**是兼容的重难点**,我们将会在`tutorial-5`中详细介绍\n",
"\n",
"  针对批量`batch`的处理,作为`fastNLP 0.8`中框架无关部分想框架相关部分的过渡\n",
"\n",
"    就是`dataloader`模块的职责,这也是本篇教程`tutorial-3`讲解的重点\n",
"\n",
"**`dataloader`模块的职责**,详细划分可以包含以下三部分,**采样划分、补零对齐、框架匹配**\n",
"\n",
"    第一,确定`batch`大小,确定采样方式,划分后通过迭代器即可得到`batch`序列\n",
"\n",
"    第二,对于序列处理,这也是`fastNLP`主要针对的,将同个`batch`内的数据对齐\n",
"\n",
"    第三,**`batch`内数据格式要匹配框架**,**但`batch`结构需保持一致**,**参数匹配机制**\n",
"\n",
"  对此,`fastNLP 0.8`给出了 **`TorchDataLoader`、`PaddleDataLoader`和`JittorDataLoader`**\n",
"\n",
"    分别针对并匹配不同框架,但彼此之间参数名、属性、方法仍然类似,前两者大致如下表所示\n",
"\n",
"| <div align=\"center\">名称</div> | <div align=\"center\">参数</div> | <div align=\"center\">属性</div> | <div align=\"center\">功能</div> | <div align=\"center\">内容</div> |\n",
"|:--|:--:|:--:|:--|:--|\n",
"| **`dataset`** | √ | √ | 指定`dataloader`的数据内容 | |\n",
"| `batch_size` | √ | √ | 指定`dataloader`的`batch`大小 | 默认`16` |\n",
"| `shuffle` | √ | √ | 指定`dataloader`的数据是否打乱 | 默认`False` |\n",
"| `collate_fn` | √ | √ | 指定`dataloader`的`batch`打包方法 | 视框架而定 |\n",
"| `sampler` | √ | √ | ? | 默认`None` |\n",
"| `batch_sampler` | √ | √ | ? | 默认`None` |\n",
"| `drop_last` | √ | √ | 指定`dataloader`划分`batch`时是否丢弃剩余的 | 默认`False` |\n",
"| `cur_batch_indices` | | √ | 记录`dataloader`当前遍历批量序号 | |\n",
"| `num_workers` | √ | √ | 指定`dataloader`开启子进程数量 | 默认`0` |\n",
"| `worker_init_fn` | √ | √ | 指定`dataloader`子进程初始方法 | 默认`None` |\n",
"| `generator` | √ | √ | 指定`dataloader`子进程随机种子 | 默认`None` |\n",
"| `prefetch_factor` | | √ | 指定为每个`worker`装载的`sampler`数量 | 默认`2` |"
] ]
}, },
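Aside: to exercise the parameters in the table, the simplest route is the `prepare_torch_dataloader` function introduced below; this sketch assumes it forwards `batch_size`, `shuffle`, and `drop_last` to the underlying `TorchDataLoader` (consistent with the table, but not verified against the fastNLP source):

```python
from fastNLP import DataSet, prepare_torch_dataloader

# Toy, already-indexed data; 'x' stands in for a real token-id field.
ds = DataSet({'x': [[1, 2], [3, 4, 5], [6]], 'target': [0, 1, 0]})

# batch_size / shuffle / drop_last come straight from the table above;
# that prepare_torch_dataloader forwards them is an assumption.
dl = prepare_torch_dataloader(ds, batch_size=2, shuffle=True, drop_last=False)

for batch in dl:
    print(batch)  # each batch arrives padded and collated
```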
 {
 "cell_type": "markdown",
-"id": "eb8fb51c",
+"id": "60a8a224",
 "metadata": {},
 "source": [
-"### 1.2 Basic usage of the dataloader\n",
+"&emsp; As for the `dataloader`'s own functions: `get_batch_indices` returns the index of the `batch` currently being iterated, while the others,\n",
 "\n",
-"In `fastNLP 0.8`, before the data-loading module `DataLoader`,"
+"&emsp; &emsp; including `set_ignore` and `set_pad`, mirror those of `databundle`; see `tutorial-2`, they are not covered again here\n",
+"\n",
+"&emsp; &emsp; Below is the preprocessing pipeline already shown in `tutorial-2`; afterwards the resulting data is run through a `dataloader`"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 5,
 "id": "aca72b49",
 "metadata": {
 "pycharm": {
 "name": "#%%\n"
 }
 },
-"outputs": [],
+"outputs": [
+{
+"data": {
+"application/vnd.jupyter.widget-view+json": {
+"model_id": "",
+"version_major": 2,
+"version_minor": 0
+},
+"text/plain": [
+"Processing: 0%| | 0/4 [00:00<?, ?it/s]"
+]
+},
+"metadata": {},
+"output_type": "display_data"
+},
+{
+"data": {
+"application/vnd.jupyter.widget-view+json": {
+"model_id": "",
+"version_major": 2,
+"version_minor": 0
+},
+"text/plain": [
+"Processing: 0%| | 0/2 [00:00<?, ?it/s]"
+]
+},
+"metadata": {},
+"output_type": "display_data"
+},
+{
+"data": {
+"application/vnd.jupyter.widget-view+json": {
+"model_id": "",
+"version_major": 2,
+"version_minor": 0
+},
+"text/plain": [
+"Processing: 0%| | 0/2 [00:00<?, ?it/s]"
+]
+},
+"metadata": {},
+"output_type": "display_data"
+},
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"+------------+------------------+-----------+------------------+--------------------+--------------------+\n",
+"| SentenceId | Sentence         | Sentiment | input_ids        | token_type_ids     | attention_mask     |\n",
+"+------------+------------------+-----------+------------------+--------------------+--------------------+\n",
+"| 5          | A comedy-dram... | positive  | [101, 1037, 4... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n",
+"| 2          | This quiet , ... | positive  | [101, 2023, 4... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n",
+"| 1          | A series of e... | negative  | [101, 1037, 2... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n",
+"| 6          | The Importanc... | neutral   | [101, 1996, 5... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... |\n",
+"+------------+------------------+-----------+------------------+--------------------+--------------------+\n"
+]
+}
+],
 "source": [
+"import sys\n",
+"sys.path.append('..')\n",
+"\n",
 "import pandas as pd\n",
 "from functools import partial\n",
 "from fastNLP.transformers.torch import BertTokenizer\n",
@@ -63,69 +174,112 @@
"\n", "\n",
"\n", "\n",
"class PipeDemo:\n", "class PipeDemo:\n",
" def __init__(self, tokenizer='bert-base-uncased', num_proc=1):\n",
" def __init__(self, tokenizer='bert-base-uncased'):\n",
" self.tokenizer = BertTokenizer.from_pretrained(tokenizer)\n", " self.tokenizer = BertTokenizer.from_pretrained(tokenizer)\n",
" self.num_proc = num_proc\n",
"\n", "\n",
" def process_from_file(self, path='./data/test4dataset.tsv'):\n", " def process_from_file(self, path='./data/test4dataset.tsv'):\n",
" datasets = DataSet.from_pandas(pd.read_csv(path))\n",
" datasets = DataSet.from_pandas(pd.read_csv(path, sep='\\t'))\n",
" train_ds, test_ds = datasets.split(ratio=0.7)\n", " train_ds, test_ds = datasets.split(ratio=0.7)\n",
" train_ds, dev_ds = datasets.split(ratio=0.8)\n", " train_ds, dev_ds = datasets.split(ratio=0.8)\n",
" data_bundle = DataBundle(datasets={'train': train_ds, 'dev': dev_ds, 'test': test_ds})\n", " data_bundle = DataBundle(datasets={'train': train_ds, 'dev': dev_ds, 'test': test_ds})\n",
"\n", "\n",
" encode = partial(self.tokenizer.encode_plus, max_length=100, truncation=True,\n", " encode = partial(self.tokenizer.encode_plus, max_length=100, truncation=True,\n",
" return_attention_mask=True)\n", " return_attention_mask=True)\n",
" data_bundle.apply_field_more(encode, field_name='text', num_proc=self.num_proc)\n",
"\n",
" data_bundle.apply_field_more(encode, field_name='Sentence', progress_bar='tqdm')\n",
" \n",
" target_vocab = Vocabulary(padding=None, unknown=None)\n", " target_vocab = Vocabulary(padding=None, unknown=None)\n",
"\n", "\n",
" target_vocab.from_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='label')\n",
" target_vocab.index_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='label',\n",
" target_vocab.from_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment')\n",
" target_vocab.index_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment',\n",
" new_field_name='target')\n", " new_field_name='target')\n",
"\n", "\n",
" data_bundle.set_pad('input_ids', pad_val=self.tokenizer.pad_token_id)\n", " data_bundle.set_pad('input_ids', pad_val=self.tokenizer.pad_token_id)\n",
" data_bundle.set_ignore('label', 'text') \n",
" return data_bundle"
" data_bundle.set_ignore('SentenceId', 'Sentence', 'Sentiment') \n",
" return data_bundle\n",
"\n",
" \n",
"pipe = PipeDemo(tokenizer='bert-base-uncased')\n",
"\n",
"data_bundle = pipe.process_from_file('./data/test4dataset.tsv')"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "de53bff4",
"id": "76e6b8ab",
"metadata": {}, "metadata": {},
"source": [ "source": [
"&emsp; "
"### 1.2 dataloader 的函数创建\n",
"\n",
"在`fastNLP 0.8`中,**更方便、可能更常用的`dataloader`创建方法是通过`prepare_xx_dataloader`函数**\n",
"\n",
"&emsp; 例如下方的`prepare_torch_dataloader`函数,指定必要参数,读取数据集,生成对应`dataloader`\n",
"\n",
"&emsp; 类型为`TorchDataLoader`,只能适用于`pytorch`框架,因此对应`trainer`初始化时`driver='torch'`"
] ]
}, },
 {
 "cell_type": "code",
-"execution_count": null,
-"id": "57a29cb9",
+"execution_count": 7,
+"id": "5fd60e42",
 "metadata": {},
 "outputs": [],
 "source": [
-"pipe = PipeDemo(tokenizer='bert-base-uncased', num_proc=4)\n",
+"from fastNLP import prepare_torch_dataloader\n",
 "\n",
-"data_bundle = pipe.process_from_file('./data/test4dataset.tsv')"
+"train_dataset = data_bundle.get_dataset('train')\n",
+"evaluate_dataset = data_bundle.get_dataset('dev')\n",
+"\n",
+"train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
+"evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)"
 ]
 },
 {
 "cell_type": "markdown",
-"id": "226bb081",
+"id": "7c53f181",
 "metadata": {},
 "source": [
-"&emsp; "
+"```python\n",
+"trainer = Trainer(\n",
+"    model=model,\n",
+"    train_dataloader=train_dataloader,\n",
+"    optimizers=optimizer,\n",
+"\t...\n",
+"\tdriver='torch',\n",
+"\tdevice='cuda',\n",
+"\t...\n",
+"    evaluate_dataloaders=evaluate_dataloader, \n",
+"    metrics={'acc': Accuracy()},\n",
+"\t...\n",
+")\n",
+"```"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "9f457a6e",
+"metadata": {},
+"source": [
+"The `prepare_xx_dataloader` functions are the more convenient route because **their input can be not only a `DataSet`** **but also**\n",
+"\n",
+"&emsp; **a `DataBundle`**, provided the dataset names are `'train'`, `'dev'`, `'test'` so `fastNLP` can recognize them\n",
+"\n",
+"&emsp; For example, below, **`prepare_paddle_dataloader` directly produces a dict of `PaddleDataLoader`s from a `DataBundle`**,\n",
+"\n",
+"&emsp; which the upcoming `trainer` initialization uses as shown; besides `driver='paddle'` at init time,\n",
+"\n",
+"&emsp; &emsp; this also shows **the beauty of `evaluate_dataloaders`**: a single evaluation can cover several datasets"
 ]
 },
 {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"execution_count": 6,
"id": "7827557d", "id": "7827557d",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from fastNLP import prepare_torch_dataloader\n",
"from fastNLP import prepare_paddle_dataloader\n",
"\n", "\n",
"dl_bundle = prepare_torch_dataloader(data_bundle, batch_size=arg.batch_size)"
"dl_bundle = prepare_paddle_dataloader(data_bundle, batch_size=16, shuffle=True)"
] ]
}, },
{ {
@@ -133,16 +287,14 @@
"id": "d898cf40", "id": "d898cf40",
"metadata": {}, "metadata": {},
"source": [ "source": [
"&emsp; \n",
"\n",
"```python\n", "```python\n",
"trainer = Trainer(\n", "trainer = Trainer(\n",
" model=model,\n", " model=model,\n",
" train_dataloader=dl_bundle['train'],\n", " train_dataloader=dl_bundle['train'],\n",
" optimizers=optimizer,\n", " optimizers=optimizer,\n",
"\t...\n", "\t...\n",
"\tdriver=\"torch\",\n",
"\tdevice='cuda',\n",
"\tdriver='paddle',\n",
"\tdevice='gpu',\n",
"\t...\n", "\t...\n",
" evaluate_dataloaders={'dev': dl_bundle['dev'], 'test': dl_bundle['test']}, \n", " evaluate_dataloaders={'dev': dl_bundle['dev'], 'test': dl_bundle['test']}, \n",
" metrics={'acc': Accuracy()},\n", " metrics={'acc': Accuracy()},\n",
@@ -187,6 +339,14 @@
"print(type(dl_bundle), type(dl_bundle['train']))" "print(type(dl_bundle), type(dl_bundle['train']))"
] ]
}, },
{
"cell_type": "markdown",
"id": "5f816ef5",
"metadata": {},
"source": [
"&emsp; "
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,


+32 -1291   tutorials/fastnlp_tutorial_4.ipynb (file diff suppressed because it is too large)


+149 -69   tutorials/fastnlp_tutorial_5.ipynb

@@ -312,6 +312,9 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import sys\n",
"sys.path.append('..')\n",
"\n",
"from fastNLP import Metric\n", "from fastNLP import Metric\n",
"\n", "\n",
"class MyMetric(Metric):\n", "class MyMetric(Metric):\n",
@@ -333,33 +336,6 @@
" return {'prefix': acc}" " return {'prefix': acc}"
] ]
}, },
{
"cell_type": "markdown",
"id": "af3f8c63",
"metadata": {},
"source": [
"&emsp; 模型使用方面,此处仍然使用`tutorial-4`中介绍过的预定义`CNNText`模型,实现`SST-2`二分类"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2fd210c5",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('..')\n",
"\n",
"from fastNLP.models.torch import CNNText\n",
"\n",
"model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n",
"\n",
"from torch.optim import AdamW\n",
"\n",
"optimizers = AdamW(params=model.parameters(), lr=5e-4)"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "0155f447", "id": "0155f447",
@@ -389,9 +365,9 @@
"id": "e9d81760", "id": "e9d81760",
"metadata": {}, "metadata": {},
"source": [ "source": [
"接着是数据预处理,需要注意的是,由于`MyMetric`的`update`函数中,输入参数名称为`pred`和`true`\n",
"&emsp; 在数据预处理中,需要注意的是,由于`MyMetric`的`update`函数中,输入参数名称为`pred`和`true`\n",
"\n", "\n",
"&emsp; 对应地,需要将数据集中表示预测目标的字段,调整为`true`(预定义的`metric`,应调整为`target`"
"&emsp; &emsp; 对应地,需要将数据集中表示预测目标的字段,调整为`true`(预定义的`metric`,应调整为`target`"
] ]
}, },
{ {
@@ -429,14 +405,136 @@
"evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)" "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)"
] ]
}, },
{
"cell_type": "markdown",
"id": "af3f8c63",
"metadata": {},
"source": [
"&emsp; 模型使用方面,此处仍然使用`tutorial-4`中介绍过的预定义`CNNText`模型,实现`SST-2`二分类"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2fd210c5",
"metadata": {},
"outputs": [],
"source": [
"from fastNLP.models.torch import CNNText\n",
"\n",
"model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n",
"\n",
"from torch.optim import AdamW\n",
"\n",
"optimizers = AdamW(params=model.parameters(), lr=5e-4)"
]
},
+{
+"cell_type": "markdown",
+"id": "6e723b87",
+"metadata": {},
+"source": [
+"## 3. More on the trainer in fastNLP\n",
+"\n",
+"### 3.1 The trainer's internal structure\n",
+"\n",
+"`tutorial-0` already introduced the basic use of the `trainer`, and `tutorial-1` through `tutorial-4` showed\n",
+"\n",
+"&emsp; many `trainer` use cases; the table here gives a fairly complete view of the `trainer` module's attributes and init parameters (bold marks required parameters\n",
+"\n",
+"| <div align=\"center\">Name</div> | <div align=\"center\">Parameter</div> | <div align=\"center\">Attribute</div> | <div align=\"center\">Function</div> | <div align=\"center\">Details</div> |\n",
+"|:--|:--:|:--:|:--|:--|\n",
+"| **`model`** | √ | √ | the model the `trainer` drives | framework-dependent, e.g. `torch.nn.Module` |\n",
+"| **`driver`** | √ | | the framework backing the `trainer` | one of `'torch'`, `'paddle'`, `'jittor'` |\n",
+"| | | √ | records the framework backing the `trainer` | a `Driver`, created during initialization |\n",
+"| `device` | √ | | the device(s) the `trainer` runs on | e.g. `'cpu'`, `'cuda'`, `0`, `[0, 1]` |\n",
+"| | | √ | records the device(s) the `trainer` runs on | a `Device`, created during initialization |\n",
+"| `n_epochs` | √ | - | how many epochs the `trainer` iterates | defaults to `20`, recorded in `driver.n_epochs` |\n",
+"| **`optimizers`** | √ | √ | the optimization method the `trainer` uses | framework-dependent, e.g. `torch.optim.Adam` |\n",
+"| `metrics` | √ | √ | the evaluation methods the `trainer` uses | a dict, e.g. `{'acc': Metric()}` |\n",
+"| `evaluator` | | √ | the `trainer`'s built-in evaluation module | an `Evaluator`, created during initialization |\n",
+"| `input_mapping` | √ | √ | reconciles mismatched `dataloader` parameters | a function whose output dict matches `forward`'s input parameters |\n",
+"| `output_mapping` | √ | √ | reconciles mismatched `forward` outputs | a function whose output dict matches the `xx_step` input parameters |\n",
+"| **`train_dataloader`** | √ | √ | the data the `trainer` trains on | a `DataLoader`, built per framework |\n",
+"| `evaluate_dataloaders` | √ | √ | the data the `trainer` evaluates on | a `DataLoader`, built per framework |\n",
+"| `train_fn` | √ | √ | how the `trainer` obtains the loss of a batch | a function, defaults to `model.train_step` |\n",
+"| `evaluate_fn` | √ | √ | how the `trainer` obtains the evaluation quantities of a batch | a function, defaults to `model.evaluate_step` |\n",
+"| `batch_step_fn` | √ | √ | how the `trainer` forwards one batch during training | a function, defaults to `TrainBatchLoop.batch_step_fn` |\n",
+"| `evaluate_batch_step_fn` | √ | √ | how the `trainer` forwards one batch during evaluation | a function, defaults to `EvaluateBatchLoop.batch_step_fn` |\n",
+"| `accumulation_steps` | √ | √ | how often backpropagation runs during training | defaults to `1`, i.e. backpropagate on every batch |\n",
+"| `evaluate_every` | √ | √ | how often the `evaluator` computes during training | defaults to `-1`, once per epoch; conversely `1` means once per batch |\n",
+"| `progress_bar` | √ | √ | the progress-bar style during training and evaluation | one of `'auto'`, `'tqdm'`, `'raw'`, `'rich'` |\n",
+"| `callbacks` | √ | | functions to trigger while the `trainer` runs | a list of `Callback`s, see `tutorial-7` |\n",
+"| `callback_manager` | | √ | records and manages everything around `callbacks` | a `CallbackManager`, see `tutorial-7` |\n",
+"| `monitor` | √ | √ | assists part of the `callbacks` machinery | a string or function, see `tutorial-7` |\n",
+"| `marker` | √ | √ | labels the `trainer` instance for the `callbacks` machinery | a string, see `tutorial-7` |\n",
+"| `trainer_state` | | √ | records `trainer` state for the `callbacks` machinery | a `TrainerState`, see `tutorial-7` |\n",
+"| `state` | | √ | records `trainer` state for the `callbacks` machinery | a `State`, see `tutorial-7` |\n",
+"| `fp16` | √ | √ | whether the `trainer` uses mixed-precision training | boolean, defaults to `False` |"
+]
+},
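Aside: condensing the table above, a minimal initialization using only the bolded required parameters plus the optional ones this tutorial needs; `model`, `optimizers`, the dataloaders, and `MyMetric` are the objects defined in the preceding cells, and the values are placeholders rather than recommendations:

```python
from fastNLP import Trainer

trainer = Trainer(
    model=model,                               # required: the model the trainer drives
    driver='torch',                            # required: backing framework
    optimizers=optimizers,                     # required: framework-native optimizer
    train_dataloader=train_dataloader,         # required: training data
    evaluate_dataloaders=evaluate_dataloader,  # optional: evaluation data
    metrics={'suffix': MyMetric()},            # optional: shown as 'prefix#suffix'
    n_epochs=10,                               # optional: defaults to 20
)
```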
+{
+"cell_type": "markdown",
+"id": "2fc8b9f3",
+"metadata": {},
+"source": [
+"&emsp; plus the `trainer` module's basic internal methods; for the advanced pieces, such as the \"`on` family\" of functions and `callback` control, see the later `tutorial-7`\n",
+"\n",
+"| <div align=\"center\">Name</div> |<div align=\"center\">Function</div> | <div align=\"center\">Main parameters</div> |\n",
+"|:--|:--|:--|\n",
+"| `run` | drives the training and evaluation of the model in the `trainer` | see below |\n",
+"| `train_step` | the forward pass of one batch of data during training | takes `batch` |\n",
+"| `backward` | one backward pass of the loss during training | takes `output` |\n",
+"| `zero_grad` | zeroes the gradients of the `optimizers` during training | no input |\n",
+"| `step` | the parameter update of the `optimizers` during training | no input |\n",
+"| `epoch_evaluate` | per-epoch evaluation during training; whether it runs depends on the evaluation frequency | no input |\n",
+"| `step_evaluate` | per-batch evaluation during training; whether it runs depends on the evaluation frequency | no input |\n",
+"| `save_model` | saves the `trainer`'s model parameters/state dict to `fastnlp_model.pkl.tar` | `folder` gives the path; `only_state_dict` says whether to save only the state dict, defaults to `False` |\n",
+"| `load_model` | loads the `trainer`'s model parameters/state dict from `fastnlp_model.pkl.tar` | `folder` gives the path; `only_state_dict` says whether to load only the state dict, defaults to `True` |\n",
+"| `save_checkpoint` | <div style=\"line-height:25px;\">saves the `trainer`'s model parameters/state dict plus the state of the `callback`s, `sampler`,<br>and `optimizer`s to `fastnlp_model/checkpoint.pkl.tar`</div> | `folder` gives the path; `only_state_dict` says whether to save only state dicts, defaults to `True` |\n",
+"| `load_checkpoint` | <div style=\"line-height:25px;\">loads the `trainer`'s model parameters/state dict plus the state of the `callback`s, `sampler`,<br>and `optimizer`s from `fastnlp_model/checkpoint.pkl.tar`</div> | <div style=\"line-height:25px;\">`folder` gives the path; `only_state_dict` says whether to load only state dicts, defaults to `True`<br>`resume_training` says whether to resume exactly at the last trained batch, defaults to `True`</div> |\n",
+"| `add_callback_fn` | adds a `callback` function after the `trainer` has been initialized | `event` gives the trigger point, `fn` the callback function |\n",
+"| `on` | a decorator turning a function into a `callback` function | see `tutorial-7` |\n",
+"\n",
+"<!-- ```python\n",
+"Trainer.__init__():\n",
+"\ton_after_trainer_initialized(trainer, driver)\n",
+"Trainer.run():\n",
+"\tif num_eval_sanity_batch > 0:  # if num_eval_sanity_batch is set\n",
+"\t\ton_sanity_check_begin(trainer)\n",
+"\t\ton_sanity_check_end(trainer, sanity_check_res)\n",
+"\ttry:\n",
+"\t\ton_train_begin(trainer)\n",
+"\t\twhile cur_epoch_idx < n_epochs:\n",
+"\t\t\ton_train_epoch_begin(trainer)\n",
+"\t\t\twhile batch_idx_in_epoch <= num_batches_per_epoch:\n",
+"\t\t\t\ton_fetch_data_begin(trainer)\n",
+"\t\t\t\tbatch = next(dataloader)\n",
+"\t\t\t\ton_fetch_data_end(trainer)\n",
+"\t\t\t\ton_train_batch_begin(trainer, batch, indices)\n",
+"\t\t\t\ton_before_backward(trainer, outputs)  # where outputs has passed through output_mapping\n",
+"\t\t\t\ton_after_backward(trainer)\n",
+"\t\t\t\ton_before_zero_grad(trainer, optimizers)  # actual invocation depends on accumulation_steps\n",
+"\t\t\t\ton_after_zero_grad(trainer, optimizers)  # actual invocation depends on accumulation_steps\n",
+"\t\t\t\ton_before_optimizers_step(trainer, optimizers)  # actual invocation depends on accumulation_steps\n",
+"\t\t\t\ton_after_optimizers_step(trainer, optimizers)  # actual invocation depends on accumulation_steps\n",
+"\t\t\t\ton_train_batch_end(trainer)\n",
+"\t\t\ton_train_epoch_end(trainer)\n",
+"\texcept BaseException:\n",
+"\t\tself.on_exception(trainer, exception)\n",
+"\tfinally:\n",
+"\t\ton_train_end(trainer)\n",
+"``` -->"
+]
+},
 {
 "cell_type": "markdown",
 "id": "1e21df35",
 "metadata": {},
 "source": [
-"Then we initialize the `trainer` instance; in the key-value pairs passed to `metrics`, the string `'suffix'` and the previously defined string `'prefix'`\n",
+"Right after that, we initialize the `trainer` instance to carry on with `SST-2` classification; in the key-value pairs passed to `metrics`, the string `'suffix'` and the previously\n",
 "\n",
-"&emsp; will be joined and shown in the `trainer`'s `progress bar`, so the full output takes the form `{'prefix#suffix': float}`"
+"&emsp; defined string `'prefix'` will be joined and shown in the `progress bar`, so the full output takes the form `{'prefix#suffix': float}`"
 ]
 },
 {
@@ -462,61 +560,43 @@
 },
 {
 "cell_type": "markdown",
-"id": "6e723b87",
-"metadata": {},
+"id": "b1b2e8b7",
+"metadata": {
+"pycharm": {
+"name": "#%%\n"
+}
+},
 "source": [
-"## 3. More on the trainer in fastNLP\n",
-"\n",
-"### 3.1 The trainer's internal structure\n",
-"\n",
-"`tutorial-0` already introduced the basic use of the `trainer`, and from `tutorial-1` to `tutorial-4` we have\n",
-"\n",
-"&emsp; shown many `trainer` use cases; below we first add some notes on the internal structure of the training module `trainer`\n",
-"\n",
-"\n",
-"\n",
-"'accumulation_steps', 'add_callback_fn', 'backward', 'batch_idx_in_epoch', 'batch_step_fn',\n",
-"'callback_manager', 'check_batch_step_fn', 'cur_epoch_idx', 'data_device', 'dataloader',\n",
-"'device', 'driver', 'driver_name', 'epoch_evaluate', 'evaluate_batch_step_fn', 'evaluate_dataloaders',\n",
-"'evaluate_every', 'evaluate_fn', 'evaluator', 'extract_loss_from_outputs', 'fp16',\n",
-"'get_no_sync_context', 'global_forward_batches', 'has_checked_train_batch_loop',\n",
-"'input_mapping', 'kwargs', 'larger_better', 'load_checkpoint', 'load_model', 'marker',\n",
-"'metrics', 'model', 'model_device', 'monitor', 'move_data_to_device', 'n_epochs', 'num_batches_per_epoch',\n",
-"'on', 'on_after_backward', 'on_after_optimizers_step', 'on_after_trainer_initialized',\n",
-"'on_after_zero_grad', 'on_before_backward', 'on_before_optimizers_step', 'on_before_zero_grad',\n",
-"'on_evaluate_begin', 'on_evaluate_end', 'on_exception', 'on_fetch_data_begin', 'on_fetch_data_end',\n",
-"'on_load_checkpoint', 'on_load_model', 'on_sanity_check_begin', 'on_sanity_check_end',\n",
-"'on_save_checkpoint', 'on_save_model', 'on_train_batch_begin', 'on_train_batch_end',\n",
-"'on_train_begin', 'on_train_end', 'on_train_epoch_begin', 'on_train_epoch_end',\n",
-"'optimizers', 'output_mapping', 'progress_bar', 'run', 'run_evaluate',\n",
-"'save_checkpoint', 'save_model', 'start_batch_idx_in_epoch', 'state',\n",
-"'step', 'step_evaluate', 'total_batches', 'train_batch_loop', 'train_dataloader', 'train_fn', 'train_step',\n",
-"'trainer_state', 'zero_grad'\n",
+"Last comes the use of the `run` function; its parameters are also laid out in the table here, which answers what `num_eval_batch_per_dl=10` means\n",
 "\n",
-"&emsp; run(num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, catch_KeyboardInterrupt=None)"
+"| <div align=\"center\">Name</div> | <div align=\"center\">Function</div> | <div align=\"center\">Default</div> |\n",
+"|:--|:--|:--|\n",
+"| `num_train_batch_per_epoch` | how many batches the `trainer` computes per epoch while training | int, defaults to `-1`: compute every batch in each epoch |\n",
+"| `num_eval_batch_per_dl` | how many batches the `trainer` computes per evaluation loop | int, defaults to `-1`: compute every batch in each loop |\n",
+"| `num_eval_sanity_batch` | how many batches are evaluated as a sanity check before training starts | int, defaults to `2`: evaluate two batches before training |\n",
+"| `resume_from` | the path (must be a folder) from which the `trainer` restores its state | string, defaults to `None`; see `CheckpointCallback` for usage |\n",
+"| `resume_training` | how much state the `trainer` restores | boolean; `True` (default) restores everything, `False` only the `model` and `optimizers` state |"
 ]
 },
 {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "c348864c",
"id": "43be274f",
"metadata": { "metadata": {
"pycharm": { "pycharm": {
"name": "#%%\n" "name": "#%%\n"
} }
}, },
"outputs": [], "outputs": [],
"source": []
"source": [
"trainer.run(num_eval_batch_per_dl=10)"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "43be274f",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"id": "f1abfa0a",
"metadata": {},
"outputs": [], "outputs": [],
"source": [] "source": []
} }


+297 -5   tutorials/fastnlp_tutorial_6.ipynb

@@ -19,20 +19,312 @@
"\n", "\n",
"&emsp; &emsp; 2.2 &ensp; 使用 jittor 搭建并训练模型\n", "&emsp; &emsp; 2.2 &ensp; 使用 jittor 搭建并训练模型\n",
"\n", "\n",
"&emsp; 3 &ensp; fastNLP 实现 paddle 与 pytorch 互转\n",
"&emsp; 3 &ensp; fastNLP 实现 paddle 与 pytorch 互转"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "08752c5a",
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"\n",
"sst2data = load_dataset('glue', 'sst2')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7e8cc210",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('..')\n",
"\n",
"from fastNLP import DataSet\n",
"\n",
"dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n",
"\n", "\n",
"&emsp; &emsp; 3.1 &ensp; \n",
"dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'target': ins['label']}, \n",
" progress_bar=\"tqdm\")\n",
"dataset.delete_field('sentence')\n",
"dataset.delete_field('label')\n",
"dataset.delete_field('idx')\n",
"\n", "\n",
"&emsp; &emsp; 3.2 &ensp; "
"from fastNLP import Vocabulary\n",
"\n",
"vocab = Vocabulary()\n",
"vocab.from_dataset(dataset, field_name='words')\n",
"vocab.index_dataset(dataset, field_name='words')\n",
"\n",
"train_dataset, evaluate_dataset = dataset.split(ratio=0.85)\n",
"print(type(train_dataset), isinstance(train_dataset, DataSet))\n",
"\n",
"from fastNLP.io import DataBundle\n",
"\n",
"data_bundle = DataBundle(datasets={'train': train_dataset, 'dev': evaluate_dataset})"
]
},
+{
+"cell_type": "markdown",
+"id": "57a3272f",
+"metadata": {},
+"source": [
+"## 1. Training a model with fastNLP and paddle\n",
+"\n",
+"```python\n",
+"import paddle\n",
+"\n",
+"lstm = paddle.nn.LSTM(16, 32, 2)\n",
+"\n",
+"x = paddle.randn((4, 23, 16))\n",
+"h = paddle.randn((2, 4, 32))\n",
+"c = paddle.randn((2, 4, 32))\n",
+"\n",
+"y, (h, c) = lstm(x, (h, c))\n",
+"\n",
+"print(y.shape)  # [4, 23, 32]\n",
+"print(h.shape)  # [2, 4, 32]\n",
+"print(c.shape)  # [2, 4, 32]\n",
+"```"
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "08752c5a",
+"id": "e31b3198",
+"metadata": {},
+"outputs": [],
+"source": [
+"import paddle\n",
+"import paddle.nn as nn\n",
+"\n",
+"\n",
+"class ClsByPaddle(nn.Layer):\n",
+"    def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n",
+"        nn.Layer.__init__(self)\n",
+"        self.hidden_dim = hidden_dim\n",
+"\n",
+"        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)\n",
+"        # self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, \n",
+"        #                     num_layers=num_layers, direction='bidirectional', dropout=dropout)\n",
+"        self.mlp = nn.Sequential(('linear_1', nn.Linear(hidden_dim * 2, hidden_dim * 2)),\n",
+"                                 ('activate', nn.ReLU()),\n",
+"                                 ('linear_2', nn.Linear(hidden_dim * 2, output_dim)))\n",
+"        \n",
+"        self.loss_fn = nn.CrossEntropyLoss()\n",
+"\n",
+"    def forward(self, words):\n",
+"        output = self.embedding(words)\n",
+"        # output, (hidden, cell) = self.lstm(output)\n",
+"        hidden = paddle.randn((2, words.shape[0], self.hidden_dim))\n",
+"        output = self.mlp(paddle.concat((hidden[-1], hidden[-2]), axis=1))\n",
+"        return output\n",
+"    \n",
+"    def train_step(self, words, target):\n",
+"        pred = self(words)\n",
+"        return {\"loss\": self.loss_fn(pred, target)}\n",
+"\n",
+"    def evaluate_step(self, words, target):\n",
+"        pred = self(words)\n",
+"        pred = paddle.max(pred, axis=-1)[1]\n",
+"        return {\"pred\": pred, \"target\": target}"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "c63b030f",
+"metadata": {},
+"outputs": [],
+"source": [
+"model = ClsByPaddle(vocab_size=len(vocab), embedding_dim=100, output_dim=2)\n",
+"\n",
+"model"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "2997c0aa",
+"metadata": {},
+"outputs": [],
+"source": [
+"from paddle.optimizer import AdamW\n",
+"\n",
+"optimizers = AdamW(parameters=model.parameters(), learning_rate=1e-2)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "ead35fb8",
+"metadata": {},
+"outputs": [],
+"source": [
+"from fastNLP import prepare_paddle_dataloader\n",
+"\n",
+"# train_dataloader = prepare_paddle_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
+"# evaluate_dataloader = prepare_paddle_dataloader(evaluate_dataset, batch_size=16)\n",
+"\n",
+"dl_bundle = prepare_paddle_dataloader(data_bundle, batch_size=16, shuffle=True)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "25e8da83",
 "metadata": {},
 "outputs": [],
-"source": []
+"source": [
+"from fastNLP import Trainer, Accuracy\n",
+"\n",
+"trainer = Trainer(\n",
+"    model=model,\n",
+"    driver='paddle',\n",
+"    device='gpu',  # 'cpu', 'gpu', 'gpu:x'\n",
+"    n_epochs=10,\n",
+"    optimizers=optimizers,\n",
+"    train_dataloader=dl_bundle['train'],    # train_dataloader,\n",
+"    evaluate_dataloaders=dl_bundle['dev'],  # evaluate_dataloader,\n",
+"    metrics={'acc': Accuracy()}\n",
+")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "d63c5d74",
+"metadata": {},
+"outputs": [],
+"source": [
+"trainer.run(num_eval_batch_per_dl=10)  # and then it hangs?"
+]
+},
+{
+"cell_type": "markdown",
+"id": "cb9a0b3c",
+"metadata": {},
+"source": [
+"## 2. Training a model with fastNLP and jittor"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "c600191d",
+"metadata": {},
+"outputs": [],
+"source": [
+"import jittor\n",
+"import jittor.nn as nn\n",
+"\n",
+"from jittor import Module\n",
+"\n",
+"\n",
+"class ClsByJittor(Module):\n",
+"    def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n",
+"        Module.__init__(self)\n",
+"        self.hidden_dim = hidden_dim\n",
+"\n",
+"        self.embedding = nn.Embedding(num=vocab_size, dim=embedding_dim)\n",
+"        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, \n",
+"                            num_layers=num_layers, bidirectional=True, dropout=dropout)\n",
+"        self.mlp = nn.Sequential([nn.Linear(hidden_dim * 2, hidden_dim * 2),\n",
+"                                  nn.ReLU(),\n",
+"                                  nn.Linear(hidden_dim * 2, output_dim)])\n",
+"\n",
+"        self.loss_fn = nn.BCELoss()\n",
+"\n",
+"    def execute(self, words):\n",
+"        output = self.embedding(words)\n",
+"        output, (hidden, cell) = self.lstm(output)\n",
+"        # hidden = jittor.randn((2, words.shape[0], self.hidden_dim))\n",
+"        output = self.mlp(jittor.concat((hidden[-1], hidden[-2]), axis=1))\n",
+"        return output\n",
+"    \n",
+"    def train_step(self, words, target):\n",
+"        pred = self(words)\n",
+"        return {\"loss\": self.loss_fn(pred, target)}\n",
+"\n",
+"    def evaluate_step(self, words, target):\n",
+"        pred = self(words)\n",
+"        pred = jittor.max(pred, axis=-1)[1]\n",
+"        return {\"pred\": pred, \"target\": target}"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "a94ed8c4",
+"metadata": {},
+"outputs": [],
+"source": [
+"model = ClsByJittor(vocab_size=len(vocab), embedding_dim=100, output_dim=2)\n",
+"\n",
+"model"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "6d15ebc1",
+"metadata": {},
+"outputs": [],
+"source": [
+"from jittor.optim import AdamW\n",
+"\n",
+"optimizers = AdamW(params=model.parameters(), lr=1e-2)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "95d8d09e",
+"metadata": {},
+"outputs": [],
+"source": [
+"from fastNLP import prepare_jittor_dataloader\n",
+"\n",
+"# train_dataloader = prepare_jittor_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
+"# evaluate_dataloader = prepare_jittor_dataloader(evaluate_dataset, batch_size=16)\n",
+"\n",
+"dl_bundle = prepare_jittor_dataloader(data_bundle, batch_size=16, shuffle=True)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "917eab81",
+"metadata": {},
+"outputs": [],
+"source": [
+"from fastNLP import Trainer, Accuracy\n",
+"\n",
+"trainer = Trainer(\n",
+"    model=model,\n",
+"    driver='jittor',\n",
+"    device='gpu',  # 'cpu', 'gpu', 'cuda'\n",
+"    n_epochs=10,\n",
+"    optimizers=optimizers,\n",
+"    train_dataloader=dl_bundle['train'],    # train_dataloader,\n",
+"    evaluate_dataloaders=dl_bundle['dev'],  # evaluate_dataloader,\n",
+"    metrics={'acc': Accuracy()}\n",
+")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "f7c4ac5a",
+"metadata": {},
+"outputs": [],
+"source": [
+"trainer.run(num_eval_batch_per_dl=10)"
+]
+}
 ],
 "metadata": {

