From a862c31b71b0b7e6b9a3ce32eddfe48d5c58f7cc Mon Sep 17 00:00:00 2001
From: x54-729 <17307130121@fudan.edu.cn>
Date: Thu, 7 Jul 2022 14:55:42 +0000
Subject: [PATCH] Add fastnlp_torch_tutorial
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../tutorials/fastnlp_torch_tutorial.ipynb | 869 ++++++++++++++++++
 1 file changed, 869 insertions(+)
 create mode 100644 docs/source/tutorials/fastnlp_torch_tutorial.ipynb

diff --git a/docs/source/tutorials/fastnlp_torch_tutorial.ipynb b/docs/source/tutorials/fastnlp_torch_tutorial.ipynb
new file mode 100644
index 00000000..9633ac7f
--- /dev/null
+++ b/docs/source/tutorials/fastnlp_torch_tutorial.ipynb
@@ -0,0 +1,869 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "6011adf8",
+   "metadata": {},
+   "source": [
+    "# Getting started with fastNLP torch in 10 minutes\n",
+    "\n",
+    "In this example we will use BERT to solve the named entity recognition task on the conll2003 dataset."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "e166c051",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "--2022-07-07 10:12:29--  https://data.deepai.org/conll2003.zip\n",
+      "Resolving data.deepai.org (data.deepai.org)... 138.201.36.183\n",
+      "Connecting to data.deepai.org (data.deepai.org)|138.201.36.183|:443... connected.\n",
+      "WARNING: cannot verify data.deepai.org's certificate, issued by ‘CN=R3,O=Let's Encrypt,C=US’:\n",
+      "  Issued certificate has expired.\n",
+      "HTTP request sent, awaiting response... 200 OK\n",
+      "Length: 982975 (960K) [application/x-zip-compressed]\n",
+      "Saving to: ‘conll2003.zip’\n",
+      "\n",
+      "conll2003.zip       100%[===================>] 959.94K   653KB/s    in 1.5s    \n",
+      "\n",
+      "2022-07-07 10:12:32 (653 KB/s) - ‘conll2003.zip’ saved [982975/982975]\n",
+      "\n",
+      "Archive:  conll2003.zip\n",
+      "  inflating: conll2003/metadata      \n",
+      "  inflating: conll2003/test.txt      \n",
+      "  inflating: conll2003/train.txt     \n",
+      "  inflating: conll2003/valid.txt     \n"
+     ]
+    }
+   ],
+   "source": [
+    "# On Linux/Mac, download the data and unpack it\n",
+    "import platform\n",
+    "if platform.system() != \"Windows\":\n",
+    "    !wget https://data.deepai.org/conll2003.zip --no-check-certificate -O conll2003.zip\n",
+    "    !unzip conll2003.zip -d conll2003\n",
+    "# Windows users: please download the data by copying the url into a browser, then unpack it"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f7acbf1f",
+   "metadata": {},
+   "source": [
+    "## Contents\n",
+    "The following sections show how fastNLP cuts down the amount of engineering boilerplate you have to write \n",
+    "- 1. Data loading\n",
+    "- 2. Data preprocessing and caching\n",
+    "- 3. DataLoader\n",
+    "- 4. Preparing the model\n",
+    "- 5. Using the Trainer\n",
+    "- 6. Using the Evaluator\n",
+    "- 7. Other topics [to be added]\n",
+    "    - 7.1 Training and evaluating on multiple GPUs\n",
+    "    - 7.2 Using the ZeRO optimization\n",
+    "    - 7.3 Quickly validating a model with an overfit run\n",
+    "    - 7.4 Using complex monitors\n",
+    "    - 7.5 Using different evaluation functions during training\n",
+    "    - 7.6 More efficient Samplers\n",
+    "    - 7.7 Saving the model\n",
+    "    - 7.8 Resuming training from a checkpoint\n",
+    "    - 7.9 Using huggingface datasets\n",
+    "    - 7.10 Using torchmetrics as metrics\n",
+    "    - 7.11 Writing predictions to a file\n",
+    "    - 7.12 Training on mixed datasets\n",
+    "    - 7.13 Using the logger\n",
+    "    - 7.14 Custom distributed Metrics\n",
+    "    - 7.15 Implementing R-Drop via batch_step_fn"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0657dfba",
+   "metadata": {},
+   "source": [
+    "### 1. Data loading\n",
+    "The ``conll2003`` directory now contains the three files ``train.txt``, ``test.txt`` and ``valid.txt``. The files are in the [conll format](https://universaldependencies.org/format.html) and the tags use the [BIO](https://blog.csdn.net/HappyRocking/article/details/79716212) scheme. Loading can be simplified by subclassing fastNLP.io.Loader: after inheriting from Loader, you only need to implement the _load() function, which reads a single file."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "c557f0ba",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sys\n",
+    "sys.path.append('../..')"
+   ]
+  },
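+  {
+   "cell_type": "markdown",
+   "id": "1f88cad4",
+   "metadata": {},
+   "source": [
+    "For reference, each non-empty line of these files carries four whitespace-separated columns (word, POS tag, chunk tag, NER tag), and sentences are separated by blank lines. The snippet below is only an illustration of this well-known layout, not read from the files here:\n",
+    "```\n",
+    "EU NNP I-NP I-ORG\n",
+    "rejects VBZ I-VP O\n",
+    "```\n",
+    "The _load() implementation below relies on exactly this layout: it asserts four columns per line and keeps only the first column (the word) and the last one (the NER tag)."
+   ]
+  },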
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "In total 3 datasets:\n", + "\ttrain has 14987 instances.\n", + "\ttest has 3684 instances.\n", + "\tdev has 3466 instances.\n", + "\n" + ] + } + ], + "source": [ + "from fastNLP import DataSet, Instance\n", + "from fastNLP.io import Loader\n", + "\n", + "\n", + "# 继承Loader之后,我们只需要实现其中_load()方法,_load()方法传入一个文件路径,返回一个fastNLP DataSet对象,其目的是读取一个文件。\n", + "class ConllLoader(Loader):\n", + " def _load(self, path):\n", + " ds = DataSet()\n", + " with open(path, 'r') as f:\n", + " segments = []\n", + " for line in f:\n", + " line = line.strip()\n", + " if line == '': # 如果为空行,说明需要切换到下一句了。\n", + " if segments:\n", + " raw_words = [s[0] for s in segments]\n", + " raw_target = [s[1] for s in segments]\n", + " # 将一个 sample 插入到 DataSet中\n", + " ds.append(Instance(raw_words=raw_words, raw_target=raw_target)) \n", + " segments = []\n", + " else:\n", + " parts = line.split()\n", + " assert len(parts)==4\n", + " segments.append([parts[0], parts[-1]])\n", + " return ds\n", + " \n", + "\n", + "# 直接使用 load() 方法加载数据集, 返回的 data_bundle 是一个 fastNLP.io.DataBundle 对象,该对象相当于将多个 dataset 放置在一起,\n", + "# 可以方便之后的预处理,DataBundle 支持的接口可以在 !!! 查看。\n", + "data_bundle = ConllLoader().load({\n", + " 'train': 'conll2003/train.txt',\n", + " 'test': 'conll2003/test.txt',\n", + " 'dev': 'conll2003/valid.txt'\n", + "})\n", + "\"\"\"\n", + "也可以通过 ConllLoader().load('conll2003/') 来读取,其原理是load()函数将尝试从'conll2003/'文件夹下寻找文件名称中包含了\n", + "'train'、'test'和'dev'的文件,并分别读取将其命名为'train'、'test'和'dev'(如文件夹中同一个关键字出现在了多个文件名中将导致报错,\n", + "此时请通过dict的方式传入路径信息)。但在我们这里的数据里,没有文件包含dev,所以无法直接使用文件夹读取,转而通过dict的方式传入读取的路径,\n", + "该dict的key也将作为读取的数据集的名称,value即对应的文件路径。\n", + "\"\"\"\n", + "\n", + "print(data_bundle) # 打印 data_bundle 可以查看包含的 DataSet \n", + "# data_bundle.get_dataset('train') # 可以获取单个 dataset" + ] + }, + { + "cell_type": "markdown", + "id": "57ae314d", + "metadata": {}, + "source": [ + "#### 2. 数据预处理\n", + "接下来,我们将演示如何通过fastNLP提供的apply函数方便快捷地进行预处理。我们需要进行的预处理操作有: \n", + "(1)使用BertTokenizer将文本转换为index;同时记录每个word被bpe之后第一个bpe的index,用于得到word的hidden state; \n", + "(2)使用[Vocabulary](../fastNLP)来将raw_target转换为序号。 " + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "96389988", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c3bd41a323c94a41b409d29a5d4079b6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
[10:48:13] INFO     Save cache to /remote-home/hyan01/exps/fastNLP/fastN cache_results.py:332\n",
+       "                    LP/demo/torch_tutorial/caches/c7f74559_cache.pkl.                        \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[10:48:13]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Save cache to \u001b[35m/remote-home/hyan01/exps/fastNLP/fastN\u001b[0m \u001b]8;id=831330;file://../../fastNLP/core/utils/cache_results.py\u001b\\\u001b[2mcache_results.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=609545;file://../../fastNLP/core/utils/cache_results.py#332\u001b\\\u001b[2m332\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m \u001b[35mLP/demo/torch_tutorial/caches/\u001b[0m\u001b[95mc7f74559_cache.pkl.\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# fastNLP 中提供了BERT, RoBERTa, GPT, BART 模型,更多的预训练模型请直接使用transformers\n", + "from fastNLP.transformers.torch import BertTokenizer\n", + "from fastNLP import cache_results, Vocabulary\n", + "\n", + "# 使用cache_results来装饰函数,会将函数的返回结果缓存到'caches/{param_hash_id}_cache.pkl'路径中(其中{param_hash_id}是根据\n", + "# 传递给 process_data 函数参数决定的,因此当函数的参数变化时,会再生成新的缓存文件。如果需要重新生成新的缓存,(a) 可以在调用process_data\n", + "# 函数时,额外传入一个_refresh=True的参数; 或者(b)删除相应的缓存文件。此外,保存结果时,cache_results默认还会\n", + "# 记录 process_data 函数源码的hash值,当其源码发生了变动,直接读取缓存会发出警告,以防止在修改预处理代码之后,忘记刷新缓存。)\n", + "@cache_results('caches/cache.pkl')\n", + "def process_data(data_bundle, model_name):\n", + " tokenizer = BertTokenizer.from_pretrained(model_name)\n", + " def bpe(raw_words):\n", + " bpes = [tokenizer.cls_token_id]\n", + " first = [0]\n", + " first_index = 1 # 记录第一个bpe的位置\n", + " for word in raw_words:\n", + " bpe = tokenizer.encode(word, add_special_tokens=False)\n", + " bpes.extend(bpe)\n", + " first.append(first_index)\n", + " first_index += len(bpe)\n", + " bpes.append(tokenizer.sep_token_id)\n", + " first.append(first_index)\n", + " return {'input_ids': bpes, 'input_len': len(bpes), 'first': first, 'first_len': len(raw_words)}\n", + " # 对data_bundle中每个dataset的每一条数据中的raw_words使用bpe函数,并且将返回的结果加入到每条数据中。\n", + " data_bundle.apply_field_more(bpe, field_name='raw_words', num_proc=4)\n", + " # 对应我们还有 apply_field() 函数,该函数和 apply_field_more() 的区别在于传入到 apply_field() 中的函数应该返回一个 field 的\n", + " # 内容(即不需要用dict包裹了)。此外,我们还提供了 data_bundle.apply() ,传入 apply() 的函数需要支持传入一个Instance对象,\n", + " # 更多信息可以参考对应的文档。\n", + " \n", + " # tag的词表,由于这是词表,所以不需要有padding和unk\n", + " tag_vocab = Vocabulary(padding=None, unknown=None)\n", + " # 从 train 数据的 raw_target 中获取建立词表\n", + " tag_vocab.from_dataset(data_bundle.get_dataset('train'), field_name='raw_target')\n", + " # 使用词表将每个 dataset 中的raw_target转为数字,并且将写入到target这个field中\n", + " tag_vocab.index_dataset(data_bundle.datasets.values(), field_name='raw_target', new_field_name='target')\n", + " \n", + " # 可以将 vocabulary 绑定到 data_bundle 上,方便之后使用。\n", + " data_bundle.set_vocab(tag_vocab, field_name='target')\n", + " \n", + " return data_bundle, tokenizer\n", + "\n", + "data_bundle, tokenizer = process_data(data_bundle, 'bert-base-cased', _refresh=True) # 第一次调用耗时较长,第二次调用则会直接读取缓存的文件\n", + "# data_bundle = process_data(data_bundle, 'bert-base-uncased') # 由于参数变化,fastNLP 会再次生成新的缓存文件。 " + ] + }, + { + "cell_type": "markdown", + "id": "80036fcd", + "metadata": {}, + "source": [ + "### 3. DataLoader \n", + "由于现在的深度学习算法大都基于 mini-batch 进行优化,因此需要将多个 sample 组合成一个 batch 再输入到模型之中。在自然语言处理中,不同的 sample 往往长度不一致,需要进行 padding 操作。在fastNLP中,我们使用 fastNLP.TorchDataLoader 帮助用户快速进行 padding ,我们使用了 !!!fastNLP.Collator!!! 
+  {
+   "cell_type": "markdown",
+   "id": "80036fcd",
+   "metadata": {},
+   "source": [
+    "### 3. DataLoader \n",
+    "Since most modern deep-learning algorithms optimize over mini-batches, several samples have to be combined into one batch before being fed to the model. In natural language processing, samples usually differ in length, so padding is required. In fastNLP we use fastNLP.TorchDataLoader to help users pad quickly; the padding itself is done by a !!!fastNLP.Collator!!! object. During iteration, the Collator automatically decides from the data of the first batch whether each field can be padded, and the padding behaviour of a specific field can be changed with the Collator.set_pad() function."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "09494695",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from fastNLP import prepare_dataloader\n",
+    "\n",
+    "# Take each dataset out of the data_bundle and build the corresponding DataLoader. The returned dls is a dict containing\n",
+    "# the three fastNLP.TorchDataLoader objects 'train', 'test' and 'dev'.\n",
+    "dls = prepare_dataloader(data_bundle, batch_size=24) \n",
+    "\n",
+    "\n",
+    "# By default fastNLP tries to pad every field: if a field's type cannot be padded, it is left alone; if it can be padded,\n",
+    "# 0 is used as the default pad value.\n",
+    "for dl in dls.values():\n",
+    "    # The padding behaviour can be changed with set_pad.\n",
+    "    dl.set_pad('input_ids', pad_val=tokenizer.pad_token_id)\n",
+    "    # To ignore a field, use the set_ignore method.\n",
+    "    dl.set_ignore('raw_target')\n",
+    "    dl.set_pad('target', pad_val=-100)\n",
+    "# Alternatively, before calling dls = prepare_dataloader(data_bundle, batch_size=32), you can directly call\n",
+    "# data_bundle.set_pad('input_ids', pad_val=tokenizer.pad_token_id); data_bundle.set_ignore('raw_target').\n",
+    "# DataSet supports these two methods as well.\n",
+    "# If you now call batch = next(iter(dls['train'])), batch is a dict containing\n",
+    "# 'input_ids': torch.LongTensor([batch_size, max_len])\n",
+    "# 'input_len': torch.LongTensor([batch_size])\n",
+    "# 'first': torch.LongTensor([batch_size, max_len'])\n",
+    "# 'first_len': torch.LongTensor([batch_size])\n",
+    "# 'target': torch.LongTensor([batch_size, max_len'-2])\n",
+    "# 'raw_words': List[List[str]]  # The Collator cannot tell how to pad this, so it leaves it untouched"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3583df6d",
+   "metadata": {},
+   "source": [
+    "### 4. Preparing the model\n",
+    "A model passed to fastNLP needs two special methods, ``train_step`` and ``evaluate_step``; by default the former is called by fastNLP.Trainer and the latter by fastNLP.Evaluator. If the model has no ``train_step`` method, the Trainer uses the model's ``forward`` function directly; if it has no ``evaluate_step`` method, the Evaluator uses ``forward`` directly. The return value of ``train_step`` (or of ``forward`` when ``train_step`` is absent) must be a dict and must contain the key ``loss``.\n",
+    "\n",
+    "Furthermore, fastNLP passes arguments by matching parameter names. For example, for the following model\n",
+    "```python\n",
+    "class Model(nn.Module):\n",
+    "    def train_step(self, x, y):\n",
+    "        return {'loss': (x-y).abs().mean()}\n",
+    "```\n",
+    "fastNLP will try to find the keys 'x' and 'y' in the batch returned by the DataLoader (suppose its keys are input_ids and target) and raise an error if they are missing. The error can be resolved in the following ways\n",
+    "- change the parameters of train_step to (input_ids, target), so that they match the keys of the batch returned by the DataLoader\n",
+    "- change the keys of the batch returned by the DataLoader to (x, y)\n",
+    "- pass train_input_mapping={'input_ids': 'x', 'target': 'y'} to the Trainer to map the inputs; train_input_mapping can also be a function (see the sketch below); the documentation has more about train_input_mapping.\n",
+    "\n",
+    "``evaluate_step`` uses the same matching scheme; the first two solutions are identical, and for the third one evaluate_input_mapping={'input_ids': 'x', 'target': 'y'} has to be passed to the Evaluator."
+   ]
+  },
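+  {
+   "cell_type": "markdown",
+   "id": "4cd2e7a9",
+   "metadata": {},
+   "source": [
+    "A minimal sketch of the third option, using the hypothetical Model above (the dict form and the function form are interchangeable; names such as dl are placeholders, not objects defined in this notebook):\n",
+    "```python\n",
+    "# dict form: rename batch keys to the parameter names of train_step\n",
+    "trainer = Trainer(model=model, train_dataloader=dl,\n",
+    "                  train_input_mapping={'input_ids': 'x', 'target': 'y'})\n",
+    "\n",
+    "# function form: receive the batch and return the mapped version\n",
+    "def input_mapping(batch):\n",
+    "    batch['x'] = batch.pop('input_ids')\n",
+    "    batch['y'] = batch.pop('target')\n",
+    "    return batch\n",
+    "trainer = Trainer(model=model, train_dataloader=dl, train_input_mapping=input_mapping)\n",
+    "```"
+   ]
+  },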
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "f131c1a3",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "[10:48:21] WARNING  Some weights of the model checkpoint at            modeling_utils.py:1490\n",
+       "                    bert-base-uncased were not used when initializing                        \n",
+       "                    BertModel: ['cls.predictions.bias',                                      \n",
+       "                    'cls.predictions.transform.LayerNorm.weight',                            \n",
+       "                    'cls.seq_relationship.weight',                                           \n",
+       "                    'cls.predictions.decoder.weight',                                        \n",
+       "                    'cls.predictions.transform.dense.weight',                                \n",
+       "                    'cls.predictions.transform.LayerNorm.bias',                              \n",
+       "                    'cls.predictions.transform.dense.bias',                                  \n",
+       "                    'cls.seq_relationship.bias']                                             \n",
+       "                    - This IS expected if you are initializing                               \n",
+       "                    BertModel from the checkpoint of a model trained                         \n",
+       "                    on another task or with another architecture (e.g.                       \n",
+       "                    initializing a BertForSequenceClassification model                       \n",
+       "                    from a BertForPreTraining model).                                        \n",
+       "                    - This IS NOT expected if you are initializing                           \n",
+       "                    BertModel from the checkpoint of a model that you                        \n",
+       "                    expect to be exactly identical (initializing a                           \n",
+       "                    BertForSequenceClassification model from a                               \n",
+       "                    BertForSequenceClassification model).                                    \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[10:48:21]\u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m Some weights of the model checkpoint at \u001b]8;id=387614;file://../../fastNLP/transformers/torch/modeling_utils.py\u001b\\\u001b[2mmodeling_utils.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=648168;file://../../fastNLP/transformers/torch/modeling_utils.py#1490\u001b\\\u001b[2m1490\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m bert-base-uncased were not used when initializing \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m BertModel: \u001b[1m[\u001b[0m\u001b[32m'cls.predictions.bias'\u001b[0m, \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.LayerNorm.weight'\u001b[0m, \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[32m'cls.seq_relationship.weight'\u001b[0m, \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.decoder.weight'\u001b[0m, \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.dense.weight'\u001b[0m, \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.LayerNorm.bias'\u001b[0m, \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.dense.bias'\u001b[0m, \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[32m'cls.seq_relationship.bias'\u001b[0m\u001b[1m]\u001b[0m \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m - This IS expected if you are initializing \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m BertModel from the checkpoint of a model trained \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m on another task or with another architecture \u001b[1m(\u001b[0me.g. \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m initializing a BertForSequenceClassification model \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m from a BertForPreTraining model\u001b[1m)\u001b[0m. \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m - This IS NOT expected if you are initializing \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m BertModel from the checkpoint of a model that you \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m expect to be exactly identical \u001b[1m(\u001b[0minitializing a \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m BertForSequenceClassification model from a \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m BertForSequenceClassification model\u001b[1m)\u001b[0m. \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
+       "           INFO     All the weights of BertModel were initialized from modeling_utils.py:1507\n",
+       "                    the model checkpoint at bert-base-uncased.                               \n",
+       "                    If your task is similar to the task the model of                         \n",
+       "                    the checkpoint was trained on, you can already use                       \n",
+       "                    BertModel for predictions without further                                \n",
+       "                    training.                                                                \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m All the weights of BertModel were initialized from \u001b]8;id=544687;file://../../fastNLP/transformers/torch/modeling_utils.py\u001b\\\u001b[2mmodeling_utils.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=934505;file://../../fastNLP/transformers/torch/modeling_utils.py#1507\u001b\\\u001b[2m1507\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m the model checkpoint at bert-base-uncased. \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m If your task is similar to the task the model of \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m the checkpoint was trained on, you can already use \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m BertModel for predictions without further \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m training. \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import torch\n", + "from torch import nn\n", + "from torch.nn.utils.rnn import pad_sequence\n", + "from fastNLP.transformers.torch import BertModel\n", + "from fastNLP import seq_len_to_mask\n", + "import torch.nn.functional as F\n", + "\n", + "\n", + "class BertNER(nn.Module):\n", + " def __init__(self, model_name, num_class, tag_vocab=None):\n", + " super().__init__()\n", + " self.bert = BertModel.from_pretrained(model_name)\n", + " self.mlp = nn.Sequential(nn.Linear(self.bert.config.hidden_size, self.bert.config.hidden_size),\n", + " nn.Dropout(0.3),\n", + " nn.Linear(self.bert.config.hidden_size, num_class))\n", + " self.tag_vocab = tag_vocab # 这里传入 tag_vocab 的目的是为了演示 constrined_decode \n", + " if tag_vocab is not None:\n", + " self._init_constrained_transition()\n", + " \n", + " def forward(self, input_ids, input_len, first):\n", + " attention_mask = seq_len_to_mask(input_len)\n", + " outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n", + " last_hidden_state = outputs.last_hidden_state\n", + " first = first.unsqueeze(-1).repeat(1, 1, last_hidden_state.size(-1))\n", + " first_bpe_state = last_hidden_state.gather(dim=1, index=first)\n", + " first_bpe_state = first_bpe_state[:, 1:-1] # 删除 cls 和 sep\n", + " \n", + " pred = self.mlp(first_bpe_state)\n", + " return {'pred': pred}\n", + " \n", + " def train_step(self, input_ids, input_len, first, target):\n", + " pred = self(input_ids, input_len, first)['pred']\n", + " loss = F.cross_entropy(pred.transpose(1, 2), target)\n", + " return {'loss': loss}\n", + " \n", + " def evaluate_step(self, input_ids, input_len, first):\n", + " pred = self(input_ids, input_len, first)['pred'].argmax(dim=-1)\n", + " return {'pred': pred}\n", + " \n", + " def constrained_decode(self, input_ids, input_len, first, first_len):\n", + " # 这个函数在推理时,将保证解码出来的 tag 一定不与前一个 tag 矛盾【例如一定不会出现 B-person 后面接着 I-Location 的情况】\n", + " # 本身这个需求可以在 Metric 中实现,这里在模型中实现的目的是为了方便演示:如何在fastNLP中使用不同的评测函数\n", + " pred = self(input_ids, input_len, first)['pred']\n", + " cons_pred = []\n", + " for _pred, _len in zip(pred, first_len):\n", + " _pred = _pred[:_len]\n", + " tags = [_pred[0].argmax(dim=-1).item()] # 这里就不考虑第一个位置非法的情况了\n", + " for i in range(1, _len):\n", + " tags.append((_pred[i] + self.transition[tags[-1]]).argmax().item())\n", + " cons_pred.append(torch.LongTensor(tags))\n", + " cons_pred = pad_sequence(cons_pred, batch_first=True)\n", + " return {'pred': cons_pred}\n", + " \n", + " def _init_constrained_transition(self):\n", + " from fastNLP.modules.torch import allowed_transitions\n", + " allowed_trans 
+    "        allowed_trans = allowed_transitions(self.tag_vocab)\n",
+    "        transition = torch.ones((len(self.tag_vocab), len(self.tag_vocab)))*-100000.0\n",
+    "        for s, e in allowed_trans:\n",
+    "            transition[s, e] = 0\n",
+    "        self.register_buffer('transition', transition)\n",
+    "\n",
+    "model = BertNER('bert-base-uncased', len(data_bundle.get_vocab('target')), data_bundle.get_vocab('target'))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5aeee1e9",
+   "metadata": {},
+   "source": [
+    "### 5. Using the Trainer\n",
+    "fastNLP's Trainer is the component used to train a model."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "f4250f0b",
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "[10:49:22] INFO     Running evaluator sanity check for 2 batches.              trainer.py:661\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[10:49:22]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=246773;file://../../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=639347;file://../../fastNLP/core/controllers/trainer.py#661\u001b\\\u001b[2m661\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
+       "+++++++++++++++++++++++++++++ Eval. results on Epoch:1, Batch:0 +++++++++++++++++++++++++++++\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[38;5;41m+++++++++++++++++++++++++++++ \u001b[0m\u001b[1mEval. results on Epoch:\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1m, Batch:\u001b[0m\u001b[1;36m0\u001b[0m\u001b[38;5;41m +++++++++++++++++++++++++++++\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#f\": 0.402447,\n",
+       "  \"pre#f\": 0.447906,\n",
+       "  \"rec#f\": 0.365365\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#f\"\u001b[0m: \u001b[1;36m0.402447\u001b[0m,\n", + " \u001b[1;34m\"pre#f\"\u001b[0m: \u001b[1;36m0.447906\u001b[0m,\n", + " \u001b[1;34m\"rec#f\"\u001b[0m: \u001b[1;36m0.365365\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
+       "[10:51:15] INFO     The best performance for monitor f#f:0.402447 was progress_callback.py:37\n",
+       "                    achieved in Epoch:1, Global Batch:625. The                               \n",
+       "                    evaluation result:                                                       \n",
+       "                    {'f#f': 0.402447, 'pre#f': 0.447906, 'rec#f':                            \n",
+       "                    0.365365}                                                                \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[10:51:15]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m The best performance for monitor f#\u001b[1;92mf:0\u001b[0m.\u001b[1;36m402447\u001b[0m was \u001b]8;id=192029;file://../../fastNLP/core/callbacks/progress_callback.py\u001b\\\u001b[2mprogress_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=994998;file://../../fastNLP/core/callbacks/progress_callback.py#37\u001b\\\u001b[2m37\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m achieved in Epoch:\u001b[1;36m1\u001b[0m, Global Batch:\u001b[1;36m625\u001b[0m. The \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m evaluation result: \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[1m{\u001b[0m\u001b[32m'f#f'\u001b[0m: \u001b[1;36m0.402447\u001b[0m, \u001b[32m'pre#f'\u001b[0m: \u001b[1;36m0.447906\u001b[0m, \u001b[32m'rec#f'\u001b[0m: \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[1;36m0.365365\u001b[0m\u001b[1m}\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
+       "           INFO     Loading best model from buffer with f#f:  load_best_model_callback.py:115\n",
+       "                    0.402447...                                                              \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Loading best model from buffer with f#f: \u001b]8;id=654516;file://../../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=96586;file://../../fastNLP/core/callbacks/load_best_model_callback.py#115\u001b\\\u001b[2m115\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m \u001b[1;36m0.402447\u001b[0m\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from torch import optim\n", + "from fastNLP import Trainer, LoadBestModelCallback, TorchWarmupCallback\n", + "from fastNLP import SpanFPreRecMetric\n", + "\n", + "optimizer = optim.AdamW(model.parameters(), lr=2e-5)\n", + "callbacks = [\n", + " LoadBestModelCallback(), # 用于在训练结束之后加载性能最好的model的权重\n", + " TorchWarmupCallback()\n", + "] \n", + "\n", + "trainer = Trainer(model=model, train_dataloader=dls['train'], optimizers=optimizer, \n", + " evaluate_dataloaders=dls['dev'], \n", + " metrics={'f': SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))}, \n", + " n_epochs=1, callbacks=callbacks, \n", + " # 在评测时将 dataloader 中的 first_len 映射 seq_len, 因为 Accuracy.update 接口需要输入一个名为 seq_len 的参数\n", + " evaluate_input_mapping={'first_len': 'seq_len'}, overfit_batches=0,\n", + " device=0, monitor='f#f', fp16=False) # fp16 为 True 的话,将使用 float16 进行训练。\n", + "trainer.run()" + ] + }, + { + "cell_type": "markdown", + "id": "c600a450", + "metadata": {}, + "source": [ + "### Evaluator的使用\n", + "fastNLP中用于评测数据的对象。" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "1b19f0ba", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
{'f#f': 0.390326, 'pre#f': 0.414741, 'rec#f': 0.368626}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\u001b[32m'f#f'\u001b[0m: \u001b[1;36m0.390326\u001b[0m, \u001b[32m'pre#f'\u001b[0m: \u001b[1;36m0.414741\u001b[0m, \u001b[32m'rec#f'\u001b[0m: \u001b[1;36m0.368626\u001b[0m\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'f#f': 0.390326, 'pre#f': 0.414741, 'rec#f': 0.368626}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from fastNLP import Evaluator\n", + "from fastNLP import SpanFPreRecMetric\n", + "\n", + "evaluator = Evaluator(model=model, dataloaders=dls['test'], \n", + " metrics={'f': SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))}, \n", + " evaluate_input_mapping={'first_len': 'seq_len'}, \n", + " device=0)\n", + "evaluator.run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "52f87770", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f723fe399df34917875ad74c2542508c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# 如果想评测一下使用 constrained decoding的性能,则可以通过传入 evaluate_fn 指定使用的函数\n", + "def input_mapping(x):\n", + " x['seq_len'] = x['first_len']\n", + " return x\n", + "evaluator = Evaluator(model=model, dataloaders=dls['test'], device=0,\n", + " metrics={'f': SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))},\n", + " evaluate_fn='constrained_decode',\n", + " # 如果将 first_len 重新命名为了 seq_len, 将导致 constrained_decode 的输入缺少 first_len 参数,因此\n", + " # 额外重复一下 'first_len': 'first_len',使得这个参数不会消失。\n", + " evaluate_input_mapping=input_mapping)\n", + "evaluator.run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "419e718b", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}