Browse Source

添加 fastnlp_torch_tutorial

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
a862c31b71
1 changed files with 869 additions and 0 deletions
  1. +869
    -0
      docs/source/tutorials/fastnlp_torch_tutorial.ipynb

+ 869
- 0
docs/source/tutorials/fastnlp_torch_tutorial.ipynb View File

@@ -0,0 +1,869 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "6011adf8",
"metadata": {},
"source": [
"# 10 分钟快速上手 fastNLP torch\n",
"\n",
"在这个例子中,我们将使用BERT来解决conll2003数据集中的命名实体识别任务。"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "e166c051",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2022-07-07 10:12:29-- https://data.deepai.org/conll2003.zip\n",
"Resolving data.deepai.org (data.deepai.org)... 138.201.36.183\n",
"Connecting to data.deepai.org (data.deepai.org)|138.201.36.183|:443... connected.\n",
"WARNING: cannot verify data.deepai.org's certificate, issued by ‘CN=R3,O=Let's Encrypt,C=US’:\n",
" Issued certificate has expired.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 982975 (960K) [application/x-zip-compressed]\n",
"Saving to: ‘conll2003.zip’\n",
"\n",
"conll2003.zip 100%[===================>] 959.94K 653KB/s in 1.5s \n",
"\n",
"2022-07-07 10:12:32 (653 KB/s) - ‘conll2003.zip’ saved [982975/982975]\n",
"\n",
"Archive: conll2003.zip\n",
" inflating: conll2003/metadata \n",
" inflating: conll2003/test.txt \n",
" inflating: conll2003/train.txt \n",
" inflating: conll2003/valid.txt \n"
]
}
],
"source": [
"# Linux/Mac 下载数据,并解压\n",
"import platform\n",
"if platform.system() != \"Windows\":\n",
" !wget https://data.deepai.org/conll2003.zip --no-check-certificate -O conll2003.zip\n",
" !unzip conll2003.zip -d conll2003\n",
"# Windows用户请通过复制该url到浏览器下载该数据并解压"
]
},
{
"cell_type": "markdown",
"id": "f7acbf1f",
"metadata": {},
"source": [
"## 目录\n",
"接下来我们将按照以下的内容介绍在如何通过fastNLP减少工程性代码的撰写 \n",
"- 1. 数据加载\n",
"- 2. 数据预处理、数据缓存\n",
"- 3. DataLoader\n",
"- 4. 模型准备\n",
"- 5. Trainer的使用\n",
"- 6. Evaluator的使用\n",
"- 7. 其它【待补充】\n",
" - 7.1 使用多卡进行训练、评测\n",
" - 7.2 使用ZeRO优化\n",
" - 7.3 通过overfit测试快速验证模型\n",
" - 7.4 复杂Monitor的使用\n",
" - 7.5 训练过程中,使用不同的测试函数\n",
" - 7.6 更有效率的Sampler\n",
" - 7.7 保存模型\n",
" - 7.8 断点重训\n",
" - 7.9 使用huggingface datasets\n",
" - 7.10 使用torchmetrics来作为metric\n",
" - 7.11 将预测结果写出到文件\n",
" - 7.12 混合 dataset 训练\n",
" - 7.13 logger的使用\n",
" - 7.14 自定义分布式 Metric 。\n",
" - 7.15 通过batch_step_fn实现R-Drop"
]
},
{
"cell_type": "markdown",
"id": "0657dfba",
"metadata": {},
"source": [
"#### 1. 数据加载\n",
"目前在``conll2003``目录下有``train.txt``, ``test.txt``与``valid.txt``三个文件,文件的格式为[conll格式](https://universaldependencies.org/format.html),其编码格式为 [BIO](https://blog.csdn.net/HappyRocking/article/details/79716212) 类型。可以通过继承 fastNLP.io.Loader 来简化加载过程,继承了 Loader 函数后,只需要在实现读取单个文件 _load() 函数即可。"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c557f0ba",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('../..')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "6f59e438",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"In total 3 datasets:\n",
"\ttrain has 14987 instances.\n",
"\ttest has 3684 instances.\n",
"\tdev has 3466 instances.\n",
"\n"
]
}
],
"source": [
"from fastNLP import DataSet, Instance\n",
"from fastNLP.io import Loader\n",
"\n",
"\n",
"# 继承Loader之后,我们只需要实现其中_load()方法,_load()方法传入一个文件路径,返回一个fastNLP DataSet对象,其目的是读取一个文件。\n",
"class ConllLoader(Loader):\n",
" def _load(self, path):\n",
" ds = DataSet()\n",
" with open(path, 'r') as f:\n",
" segments = []\n",
" for line in f:\n",
" line = line.strip()\n",
" if line == '': # 如果为空行,说明需要切换到下一句了。\n",
" if segments:\n",
" raw_words = [s[0] for s in segments]\n",
" raw_target = [s[1] for s in segments]\n",
" # 将一个 sample 插入到 DataSet中\n",
" ds.append(Instance(raw_words=raw_words, raw_target=raw_target)) \n",
" segments = []\n",
" else:\n",
" parts = line.split()\n",
" assert len(parts)==4\n",
" segments.append([parts[0], parts[-1]])\n",
" return ds\n",
" \n",
"\n",
"# 直接使用 load() 方法加载数据集, 返回的 data_bundle 是一个 fastNLP.io.DataBundle 对象,该对象相当于将多个 dataset 放置在一起,\n",
"# 可以方便之后的预处理,DataBundle 支持的接口可以在 !!! 查看。\n",
"data_bundle = ConllLoader().load({\n",
" 'train': 'conll2003/train.txt',\n",
" 'test': 'conll2003/test.txt',\n",
" 'dev': 'conll2003/valid.txt'\n",
"})\n",
"\"\"\"\n",
"也可以通过 ConllLoader().load('conll2003/') 来读取,其原理是load()函数将尝试从'conll2003/'文件夹下寻找文件名称中包含了\n",
"'train'、'test'和'dev'的文件,并分别读取将其命名为'train'、'test'和'dev'(如文件夹中同一个关键字出现在了多个文件名中将导致报错,\n",
"此时请通过dict的方式传入路径信息)。但在我们这里的数据里,没有文件包含dev,所以无法直接使用文件夹读取,转而通过dict的方式传入读取的路径,\n",
"该dict的key也将作为读取的数据集的名称,value即对应的文件路径。\n",
"\"\"\"\n",
"\n",
"print(data_bundle) # 打印 data_bundle 可以查看包含的 DataSet \n",
"# data_bundle.get_dataset('train') # 可以获取单个 dataset"
]
},
{
"cell_type": "markdown",
"id": "57ae314d",
"metadata": {},
"source": [
"#### 2. 数据预处理\n",
"接下来,我们将演示如何通过fastNLP提供的apply函数方便快捷地进行预处理。我们需要进行的预处理操作有: \n",
"(1)使用BertTokenizer将文本转换为index;同时记录每个word被bpe之后第一个bpe的index,用于得到word的hidden state; \n",
"(2)使用[Vocabulary](../fastNLP)来将raw_target转换为序号。 "
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "96389988",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c3bd41a323c94a41b409d29a5d4079b6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"IOPub message rate exceeded.\n",
"The notebook server will temporarily stop sending output\n",
"to the client in order to avoid crashing it.\n",
"To change this limit, set the config variable\n",
"`--NotebookApp.iopub_msg_rate_limit`.\n",
"\n",
"Current values:\n",
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
"NotebookApp.rate_limit_window=3.0 (secs)\n",
"\n"
]
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[10:48:13] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Save cache to <span style=\"color: #800080; text-decoration-color: #800080\">/remote-home/hyan01/exps/fastNLP/fastN</span> <a href=\"file://../../fastNLP/core/utils/cache_results.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">cache_results.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../../fastNLP/core/utils/cache_results.py#332\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">332</span></a>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #800080; text-decoration-color: #800080\">LP/demo/torch_tutorial/caches/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">c7f74559_cache.pkl.</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[2;36m[10:48:13]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Save cache to \u001b[35m/remote-home/hyan01/exps/fastNLP/fastN\u001b[0m \u001b]8;id=831330;file://../../fastNLP/core/utils/cache_results.py\u001b\\\u001b[2mcache_results.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=609545;file://../../fastNLP/core/utils/cache_results.py#332\u001b\\\u001b[2m332\u001b[0m\u001b]8;;\u001b\\\n",
"\u001b[2;36m \u001b[0m \u001b[35mLP/demo/torch_tutorial/caches/\u001b[0m\u001b[95mc7f74559_cache.pkl.\u001b[0m \u001b[2m \u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# fastNLP 中提供了BERT, RoBERTa, GPT, BART 模型,更多的预训练模型请直接使用transformers\n",
"from fastNLP.transformers.torch import BertTokenizer\n",
"from fastNLP import cache_results, Vocabulary\n",
"\n",
"# 使用cache_results来装饰函数,会将函数的返回结果缓存到'caches/{param_hash_id}_cache.pkl'路径中(其中{param_hash_id}是根据\n",
"# 传递给 process_data 函数参数决定的,因此当函数的参数变化时,会再生成新的缓存文件。如果需要重新生成新的缓存,(a) 可以在调用process_data\n",
"# 函数时,额外传入一个_refresh=True的参数; 或者(b)删除相应的缓存文件。此外,保存结果时,cache_results默认还会\n",
"# 记录 process_data 函数源码的hash值,当其源码发生了变动,直接读取缓存会发出警告,以防止在修改预处理代码之后,忘记刷新缓存。)\n",
"@cache_results('caches/cache.pkl')\n",
"def process_data(data_bundle, model_name):\n",
" tokenizer = BertTokenizer.from_pretrained(model_name)\n",
" def bpe(raw_words):\n",
" bpes = [tokenizer.cls_token_id]\n",
" first = [0]\n",
" first_index = 1 # 记录第一个bpe的位置\n",
" for word in raw_words:\n",
" bpe = tokenizer.encode(word, add_special_tokens=False)\n",
" bpes.extend(bpe)\n",
" first.append(first_index)\n",
" first_index += len(bpe)\n",
" bpes.append(tokenizer.sep_token_id)\n",
" first.append(first_index)\n",
" return {'input_ids': bpes, 'input_len': len(bpes), 'first': first, 'first_len': len(raw_words)}\n",
" # 对data_bundle中每个dataset的每一条数据中的raw_words使用bpe函数,并且将返回的结果加入到每条数据中。\n",
" data_bundle.apply_field_more(bpe, field_name='raw_words', num_proc=4)\n",
" # 对应我们还有 apply_field() 函数,该函数和 apply_field_more() 的区别在于传入到 apply_field() 中的函数应该返回一个 field 的\n",
" # 内容(即不需要用dict包裹了)。此外,我们还提供了 data_bundle.apply() ,传入 apply() 的函数需要支持传入一个Instance对象,\n",
" # 更多信息可以参考对应的文档。\n",
" \n",
" # tag的词表,由于这是词表,所以不需要有padding和unk\n",
" tag_vocab = Vocabulary(padding=None, unknown=None)\n",
" # 从 train 数据的 raw_target 中获取建立词表\n",
" tag_vocab.from_dataset(data_bundle.get_dataset('train'), field_name='raw_target')\n",
" # 使用词表将每个 dataset 中的raw_target转为数字,并且将写入到target这个field中\n",
" tag_vocab.index_dataset(data_bundle.datasets.values(), field_name='raw_target', new_field_name='target')\n",
" \n",
" # 可以将 vocabulary 绑定到 data_bundle 上,方便之后使用。\n",
" data_bundle.set_vocab(tag_vocab, field_name='target')\n",
" \n",
" return data_bundle, tokenizer\n",
"\n",
"data_bundle, tokenizer = process_data(data_bundle, 'bert-base-cased', _refresh=True) # 第一次调用耗时较长,第二次调用则会直接读取缓存的文件\n",
"# data_bundle = process_data(data_bundle, 'bert-base-uncased') # 由于参数变化,fastNLP 会再次生成新的缓存文件。 "
]
},
{
"cell_type": "markdown",
"id": "80036fcd",
"metadata": {},
"source": [
"### 3. DataLoader \n",
"由于现在的深度学习算法大都基于 mini-batch 进行优化,因此需要将多个 sample 组合成一个 batch 再输入到模型之中。在自然语言处理中,不同的 sample 往往长度不一致,需要进行 padding 操作。在fastNLP中,我们使用 fastNLP.TorchDataLoader 帮助用户快速进行 padding ,我们使用了 !!!fastNLP.Collator!!! 对象来进行 pad ,Collator 会在迭代过程中根据第一个 batch 的数据自动判定每个 field 是否可以进行 pad ,可以通过 Collator.set_pad() 函数修改某个 field 的 pad 行为。"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "09494695",
"metadata": {},
"outputs": [],
"source": [
"from fastNLP import prepare_dataloader\n",
"\n",
"# 将 data_bundle 中每个 dataset 取出并构造出相应的 DataLoader 对象。返回的 dls 是一个 dict ,包含了 'train', 'test', 'dev' 三个\n",
"# fastNLP.TorchDataLoader 对象。\n",
"dls = prepare_dataloader(data_bundle, batch_size=24) \n",
"\n",
"\n",
"# fastNLP 将默认尝试对所有 field 都进行 pad ,如果当前 field 是不可 pad 的类型,则不进行pad;如果是可以 pad 的类型\n",
"# 默认使用 0 进行 pad 。\n",
"for dl in dls.values():\n",
" # 可以通过 set_pad 修改 padding 的行为。\n",
" dl.set_pad('input_ids', pad_val=tokenizer.pad_token_id)\n",
" # 如果希望忽略某个 field ,可以通过 set_ignore 方法。\n",
" dl.set_ignore('raw_target')\n",
" dl.set_pad('target', pad_val=-100)\n",
"# 另一种设置的方法是,可以在 dls = prepare_dataloader(data_bundle, batch_size=32) 之前直接调用 \n",
"# data_bundle.set_pad('input_ids', pad_val=tokenizer.pad_token_id); data_bundle.set_ignore('raw_target')来进行设置。\n",
"# DataSet 也支持这两个方法。\n",
"# 若此时调用 batch = next(dls['train']),则 batch 是一个 dict ,其中包含了\n",
"# 'input_ids': torch.LongTensor([batch_size, max_len])\n",
"# 'input_len': torch.LongTensor([batch_size])\n",
"# 'first': torch.LongTensor([batch_size, max_len'])\n",
"# 'first_len': torch.LongTensor([batch_size])\n",
"# 'target': torch.LongTensor([batch_size, max_len'-2])\n",
"# 'raw_words': List[List[str]] # 因为无法判断,所以 Collator 不会做任何处理"
]
},
{
"cell_type": "markdown",
"id": "3583df6d",
"metadata": {},
"source": [
"### 4. 模型准备\n",
"传入给fastNLP的模型,需要有两个特殊的方法``train_step``、``evaluate_step``,前者默认在 fastNLP.Trainer 中进行调用,后者默认在 fastNLP.Evaluator 中调用。如果模型中没有``train_step``方法,则Trainer会直接使用模型的``forward``函数;如果模型没有``evaluate_step``方法,则Evaluator会直接使用模型的``forward``函数。``train_step``方法(或当其不存在时,``forward``方法)的返回值必须为 dict 类型,并且必须包含``loss``这个 key 。\n",
"\n",
"此外fastNLP会使用形参名匹配的方式进行参数传递,例如以下模型\n",
"```python\n",
"class Model(nn.Module):\n",
" def train_step(self, x, y):\n",
" return {'loss': (x-y).abs().mean()}\n",
"```\n",
"fastNLP将尝试从 DataLoader 返回的 batch(假设包含的 key 为 input_ids, target) 中寻找 'x' 和 'y' 这两个 key ,如果没有找到则会报错。有以下的方法可以解决报错\n",
"- 修改 train_step 的参数为(input_ids, target),以保证和 DataLoader 返回的 batch 中的 key 匹配\n",
"- 修改 DataLoader 中返回 batch 的 key 的名字为 (x, y)\n",
"- 在 Trainer 中传入参数 train_input_mapping={'input_ids': 'x', 'target': 'y'} 将输入进行映射,train_input_mapping 也可以是一个函数,更多 train_input_mapping 的介绍可以参考文档。\n",
"\n",
"``evaluate_step``也是使用同样的匹配方式,前两条解决方法是一致的,第三种解决方案中,需要在 Evaluator 中传入 evaluate_input_mapping={'input_ids': 'x', 'target': 'y'}。"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "f131c1a3",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[10:48:21] </span><span style=\"color: #800000; text-decoration-color: #800000\">WARNING </span> Some weights of the model checkpoint at <a href=\"file://../../fastNLP/transformers/torch/modeling_utils.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">modeling_utils.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../../fastNLP/transformers/torch/modeling_utils.py#1490\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">1490</span></a>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> bert-base-uncased were not used when initializing <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> BertModel: <span style=\"font-weight: bold\">[</span><span style=\"color: #008000; text-decoration-color: #008000\">'cls.predictions.bias'</span>, <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008000; text-decoration-color: #008000\">'cls.predictions.transform.LayerNorm.weight'</span>, <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008000; text-decoration-color: #008000\">'cls.seq_relationship.weight'</span>, <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008000; text-decoration-color: #008000\">'cls.predictions.decoder.weight'</span>, <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008000; text-decoration-color: #008000\">'cls.predictions.transform.dense.weight'</span>, <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008000; text-decoration-color: #008000\">'cls.predictions.transform.LayerNorm.bias'</span>, <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008000; text-decoration-color: #008000\">'cls.predictions.transform.dense.bias'</span>, <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008000; text-decoration-color: #008000\">'cls.seq_relationship.bias'</span><span style=\"font-weight: bold\">]</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> - This IS expected if you are initializing <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> BertModel from the checkpoint of a model trained <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> on another task or with another architecture <span style=\"font-weight: bold\">(</span>e.g. <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> initializing a BertForSequenceClassification model <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> from a BertForPreTraining model<span style=\"font-weight: bold\">)</span>. <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> - This IS NOT expected if you are initializing <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> BertModel from the checkpoint of a model that you <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> expect to be exactly identical <span style=\"font-weight: bold\">(</span>initializing a <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> BertForSequenceClassification model from a <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> BertForSequenceClassification model<span style=\"font-weight: bold\">)</span>. <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[2;36m[10:48:21]\u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m Some weights of the model checkpoint at \u001b]8;id=387614;file://../../fastNLP/transformers/torch/modeling_utils.py\u001b\\\u001b[2mmodeling_utils.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=648168;file://../../fastNLP/transformers/torch/modeling_utils.py#1490\u001b\\\u001b[2m1490\u001b[0m\u001b]8;;\u001b\\\n",
"\u001b[2;36m \u001b[0m bert-base-uncased were not used when initializing \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m BertModel: \u001b[1m[\u001b[0m\u001b[32m'cls.predictions.bias'\u001b[0m, \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.LayerNorm.weight'\u001b[0m, \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m \u001b[32m'cls.seq_relationship.weight'\u001b[0m, \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.decoder.weight'\u001b[0m, \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.dense.weight'\u001b[0m, \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.LayerNorm.bias'\u001b[0m, \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.dense.bias'\u001b[0m, \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m \u001b[32m'cls.seq_relationship.bias'\u001b[0m\u001b[1m]\u001b[0m \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m - This IS expected if you are initializing \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m BertModel from the checkpoint of a model trained \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m on another task or with another architecture \u001b[1m(\u001b[0me.g. \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m initializing a BertForSequenceClassification model \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m from a BertForPreTraining model\u001b[1m)\u001b[0m. \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m - This IS NOT expected if you are initializing \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m BertModel from the checkpoint of a model that you \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m expect to be exactly identical \u001b[1m(\u001b[0minitializing a \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m BertForSequenceClassification model from a \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m BertForSequenceClassification model\u001b[1m)\u001b[0m. \u001b[2m \u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> All the weights of BertModel were initialized from <a href=\"file://../../fastNLP/transformers/torch/modeling_utils.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">modeling_utils.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../../fastNLP/transformers/torch/modeling_utils.py#1507\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">1507</span></a>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> the model checkpoint at bert-base-uncased. <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> If your task is similar to the task the model of <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> the checkpoint was trained on, you can already use <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> BertModel for predictions without further <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> training. <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m All the weights of BertModel were initialized from \u001b]8;id=544687;file://../../fastNLP/transformers/torch/modeling_utils.py\u001b\\\u001b[2mmodeling_utils.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=934505;file://../../fastNLP/transformers/torch/modeling_utils.py#1507\u001b\\\u001b[2m1507\u001b[0m\u001b]8;;\u001b\\\n",
"\u001b[2;36m \u001b[0m the model checkpoint at bert-base-uncased. \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m If your task is similar to the task the model of \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m the checkpoint was trained on, you can already use \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m BertModel for predictions without further \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m training. \u001b[2m \u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import torch\n",
"from torch import nn\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"from fastNLP.transformers.torch import BertModel\n",
"from fastNLP import seq_len_to_mask\n",
"import torch.nn.functional as F\n",
"\n",
"\n",
"class BertNER(nn.Module):\n",
" def __init__(self, model_name, num_class, tag_vocab=None):\n",
" super().__init__()\n",
" self.bert = BertModel.from_pretrained(model_name)\n",
" self.mlp = nn.Sequential(nn.Linear(self.bert.config.hidden_size, self.bert.config.hidden_size),\n",
" nn.Dropout(0.3),\n",
" nn.Linear(self.bert.config.hidden_size, num_class))\n",
" self.tag_vocab = tag_vocab # 这里传入 tag_vocab 的目的是为了演示 constrined_decode \n",
" if tag_vocab is not None:\n",
" self._init_constrained_transition()\n",
" \n",
" def forward(self, input_ids, input_len, first):\n",
" attention_mask = seq_len_to_mask(input_len)\n",
" outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n",
" last_hidden_state = outputs.last_hidden_state\n",
" first = first.unsqueeze(-1).repeat(1, 1, last_hidden_state.size(-1))\n",
" first_bpe_state = last_hidden_state.gather(dim=1, index=first)\n",
" first_bpe_state = first_bpe_state[:, 1:-1] # 删除 cls 和 sep\n",
" \n",
" pred = self.mlp(first_bpe_state)\n",
" return {'pred': pred}\n",
" \n",
" def train_step(self, input_ids, input_len, first, target):\n",
" pred = self(input_ids, input_len, first)['pred']\n",
" loss = F.cross_entropy(pred.transpose(1, 2), target)\n",
" return {'loss': loss}\n",
" \n",
" def evaluate_step(self, input_ids, input_len, first):\n",
" pred = self(input_ids, input_len, first)['pred'].argmax(dim=-1)\n",
" return {'pred': pred}\n",
" \n",
" def constrained_decode(self, input_ids, input_len, first, first_len):\n",
" # 这个函数在推理时,将保证解码出来的 tag 一定不与前一个 tag 矛盾【例如一定不会出现 B-person 后面接着 I-Location 的情况】\n",
" # 本身这个需求可以在 Metric 中实现,这里在模型中实现的目的是为了方便演示:如何在fastNLP中使用不同的评测函数\n",
" pred = self(input_ids, input_len, first)['pred']\n",
" cons_pred = []\n",
" for _pred, _len in zip(pred, first_len):\n",
" _pred = _pred[:_len]\n",
" tags = [_pred[0].argmax(dim=-1).item()] # 这里就不考虑第一个位置非法的情况了\n",
" for i in range(1, _len):\n",
" tags.append((_pred[i] + self.transition[tags[-1]]).argmax().item())\n",
" cons_pred.append(torch.LongTensor(tags))\n",
" cons_pred = pad_sequence(cons_pred, batch_first=True)\n",
" return {'pred': cons_pred}\n",
" \n",
" def _init_constrained_transition(self):\n",
" from fastNLP.modules.torch import allowed_transitions\n",
" allowed_trans = allowed_transitions(self.tag_vocab)\n",
" transition = torch.ones((len(self.tag_vocab), len(self.tag_vocab)))*-100000.0\n",
" for s, e in allowed_trans:\n",
" transition[s, e] = 0\n",
" self.register_buffer('transition', transition)\n",
"\n",
"model = BertNER('bert-base-uncased', len(data_bundle.get_vocab('target')), data_bundle.get_vocab('target'))"
]
},
{
"cell_type": "markdown",
"id": "5aeee1e9",
"metadata": {},
"source": [
"### Trainer 的使用\n",
"fastNLP 的 Trainer 是用于对模型进行训练的部件。"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "f4250f0b",
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[10:49:22] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Running evaluator sanity check for <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> batches. <a href=\"file://../../fastNLP/core/controllers/trainer.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">trainer.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../../fastNLP/core/controllers/trainer.py#661\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">661</span></a>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[2;36m[10:49:22]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=246773;file://../../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=639347;file://../../fastNLP/core/controllers/trainer.py#661\u001b\\\u001b[2m661\u001b[0m\u001b]8;;\u001b\\\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #00d75f; text-decoration-color: #00d75f\">+++++++++++++++++++++++++++++ </span><span style=\"font-weight: bold\">Eval. results on Epoch:</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span><span style=\"font-weight: bold\">, Batch:</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span><span style=\"color: #00d75f; text-decoration-color: #00d75f\"> +++++++++++++++++++++++++++++</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[38;5;41m+++++++++++++++++++++++++++++ \u001b[0m\u001b[1mEval. results on Epoch:\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1m, Batch:\u001b[0m\u001b[1;36m0\u001b[0m\u001b[38;5;41m +++++++++++++++++++++++++++++\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"f#f\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.402447</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"pre#f\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.447906</span>,\n",
" <span style=\"color: #000080; text-decoration-color: #000080; font-weight: bold\">\"rec#f\"</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.365365</span>\n",
"<span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\n",
" \u001b[1;34m\"f#f\"\u001b[0m: \u001b[1;36m0.402447\u001b[0m,\n",
" \u001b[1;34m\"pre#f\"\u001b[0m: \u001b[1;36m0.447906\u001b[0m,\n",
" \u001b[1;34m\"rec#f\"\u001b[0m: \u001b[1;36m0.365365\u001b[0m\n",
"\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[10:51:15] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> The best performance for monitor f#<span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">f:0</span>.<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">402447</span> was <a href=\"file://../../fastNLP/core/callbacks/progress_callback.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">progress_callback.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../../fastNLP/core/callbacks/progress_callback.py#37\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">37</span></a>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> achieved in Epoch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>, Global Batch:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">625</span>. The <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> evaluation result: <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'f#f'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.402447</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'pre#f'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.447906</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'rec#f'</span>: <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.365365</span><span style=\"font-weight: bold\">}</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[2;36m[10:51:15]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m The best performance for monitor f#\u001b[1;92mf:0\u001b[0m.\u001b[1;36m402447\u001b[0m was \u001b]8;id=192029;file://../../fastNLP/core/callbacks/progress_callback.py\u001b\\\u001b[2mprogress_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=994998;file://../../fastNLP/core/callbacks/progress_callback.py#37\u001b\\\u001b[2m37\u001b[0m\u001b]8;;\u001b\\\n",
"\u001b[2;36m \u001b[0m achieved in Epoch:\u001b[1;36m1\u001b[0m, Global Batch:\u001b[1;36m625\u001b[0m. The \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m evaluation result: \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m \u001b[1m{\u001b[0m\u001b[32m'f#f'\u001b[0m: \u001b[1;36m0.402447\u001b[0m, \u001b[32m'pre#f'\u001b[0m: \u001b[1;36m0.447906\u001b[0m, \u001b[32m'rec#f'\u001b[0m: \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m \u001b[1;36m0.365365\u001b[0m\u001b[1m}\u001b[0m \u001b[2m \u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Loading best model from buffer with f#f: <a href=\"file://../../fastNLP/core/callbacks/load_best_model_callback.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">load_best_model_callback.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file://../../fastNLP/core/callbacks/load_best_model_callback.py#115\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">115</span></a>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.402447</span><span style=\"color: #808000; text-decoration-color: #808000\">...</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Loading best model from buffer with f#f: \u001b]8;id=654516;file://../../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=96586;file://../../fastNLP/core/callbacks/load_best_model_callback.py#115\u001b\\\u001b[2m115\u001b[0m\u001b]8;;\u001b\\\n",
"\u001b[2;36m \u001b[0m \u001b[1;36m0.402447\u001b[0m\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from torch import optim\n",
"from fastNLP import Trainer, LoadBestModelCallback, TorchWarmupCallback\n",
"from fastNLP import SpanFPreRecMetric\n",
"\n",
"optimizer = optim.AdamW(model.parameters(), lr=2e-5)\n",
"callbacks = [\n",
" LoadBestModelCallback(), # 用于在训练结束之后加载性能最好的model的权重\n",
" TorchWarmupCallback()\n",
"] \n",
"\n",
"trainer = Trainer(model=model, train_dataloader=dls['train'], optimizers=optimizer, \n",
" evaluate_dataloaders=dls['dev'], \n",
" metrics={'f': SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))}, \n",
" n_epochs=1, callbacks=callbacks, \n",
" # 在评测时将 dataloader 中的 first_len 映射 seq_len, 因为 Accuracy.update 接口需要输入一个名为 seq_len 的参数\n",
" evaluate_input_mapping={'first_len': 'seq_len'}, overfit_batches=0,\n",
" device=0, monitor='f#f', fp16=False) # fp16 为 True 的话,将使用 float16 进行训练。\n",
"trainer.run()"
]
},
{
"cell_type": "markdown",
"id": "c600a450",
"metadata": {},
"source": [
"### Evaluator的使用\n",
"fastNLP中用于评测数据的对象。"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "1b19f0ba",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'f#f'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.390326</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'pre#f'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.414741</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'rec#f'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.368626</span><span style=\"font-weight: bold\">}</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m{\u001b[0m\u001b[32m'f#f'\u001b[0m: \u001b[1;36m0.390326\u001b[0m, \u001b[32m'pre#f'\u001b[0m: \u001b[1;36m0.414741\u001b[0m, \u001b[32m'rec#f'\u001b[0m: \u001b[1;36m0.368626\u001b[0m\u001b[1m}\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"{'f#f': 0.390326, 'pre#f': 0.414741, 'rec#f': 0.368626}"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from fastNLP import Evaluator\n",
"from fastNLP import SpanFPreRecMetric\n",
"\n",
"evaluator = Evaluator(model=model, dataloaders=dls['test'], \n",
" metrics={'f': SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))}, \n",
" evaluate_input_mapping={'first_len': 'seq_len'}, \n",
" device=0)\n",
"evaluator.run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "52f87770",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f723fe399df34917875ad74c2542508c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# 如果想评测一下使用 constrained decoding的性能,则可以通过传入 evaluate_fn 指定使用的函数\n",
"def input_mapping(x):\n",
" x['seq_len'] = x['first_len']\n",
" return x\n",
"evaluator = Evaluator(model=model, dataloaders=dls['test'], device=0,\n",
" metrics={'f': SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))},\n",
" evaluate_fn='constrained_decode',\n",
" # 如果将 first_len 重新命名为了 seq_len, 将导致 constrained_decode 的输入缺少 first_len 参数,因此\n",
" # 额外重复一下 'first_len': 'first_len',使得这个参数不会消失。\n",
" evaluate_input_mapping=input_mapping)\n",
"evaluator.run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "419e718b",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

Loading…
Cancel
Save