diff --git a/tutorials/fastnlp_tutorial_1.ipynb b/tutorials/fastnlp_tutorial_1.ipynb index 09e8821d..db77e6c3 100644 --- a/tutorials/fastnlp_tutorial_1.ipynb +++ b/tutorials/fastnlp_tutorial_1.ipynb @@ -1325,7 +1325,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.4" + "version": "3.7.13" } }, "nbformat": 4, diff --git a/tutorials/fastnlp_tutorial_3.ipynb b/tutorials/fastnlp_tutorial_3.ipynb index 8c3c935e..353e4645 100644 --- a/tutorials/fastnlp_tutorial_3.ipynb +++ b/tutorials/fastnlp_tutorial_3.ipynb @@ -288,7 +288,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.4" + "version": "3.7.13" }, "pycharm": { "stem_cell": { diff --git a/tutorials/fastnlp_tutorial_e1.ipynb b/tutorials/fastnlp_tutorial_e1.ipynb index 628dd7ae..6ec04cb4 100644 --- a/tutorials/fastnlp_tutorial_e1.ipynb +++ b/tutorials/fastnlp_tutorial_e1.ipynb @@ -4,7 +4,22 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# E1. 使用 DistilBert 完成 SST2 分类" + "  从这篇开始,我们将开启**`fastNLP v0.8 tutorial`的`example`系列**,在接下来的\n", + "\n", + "  每篇`tutorial`里,我们将会介绍`fastNLP v0.8`在一些自然语言处理任务上的应用" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# E1. 使用 Bert + fine-tuning 完成 SST2 分类\n", + "\n", + "  1   基础介绍:`GLUE`通用语言理解评估、`SST2`文本情感二分类数据集 \n", + "\n", + "  2   准备工作:加载`tokenizer`、预处理`dataset`、`dataloader`使用\n", + "\n", + "  3   模型训练:加载`distilbert-base`、`fastNLP`参数匹配、`fine-tuning`" ] }, { @@ -48,22 +63,64 @@ "\n", "import fastNLP\n", "from fastNLP import Trainer\n", - "from fastNLP.core.utils.utils import dataclass_to_dict\n", "from fastNLP.core.metrics import Accuracy\n", "\n", "print(transformers.__version__)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. 
基础介绍:GLUE 通用语言理解评估、SST2 文本情感二分类数据集\n", + "\n", + "  本示例使用`GLUE`评估基准中的`SST2`数据集,通过`fine-tuning`方式\n", + "\n", + "    调整`distilbert-bert`分类模型,以下首先简单介绍下`GLUE`和`SST2`\n", + "\n", + "**`GLUE`**,**全称`General Language Understanding Evaluation`**,**通用语言理解评估**,\n", + "\n", + "  包含9个数据集,各语料的语言均为英语,涉及多个自然语言理解`NLU`任务,包括\n", + "\n", + "    **`CoLA`**,文本分类任务,预测单句语法正误分类;**`SST2`**,文本分类任务,预测单句情感二分类\n", + "\n", + "    **`MRPC`**,句对分类任务,预测句对语义一致性;**`STSB`**,相似度打分任务,预测句对语义相似度回归\n", + "\n", + "    **`QQP`**,句对分类任务,预测问题对语义一致性;**`MNLI`**,文本推理任务,预测句对蕴含/矛盾/中立预测\n", + "\n", + "    **`QNLI`/`RTE`/`WNLI`**,文本推理,预测是否蕴含二分类(其中,`QNLI`从`SQuAD`转化而来\n", + "\n", + "  诸如`BERT`、`T5`等经典模型都会在此基准上验证效果,更多参考[GLUE论文](https://arxiv.org/pdf/1804.07461v3.pdf)\n", + "\n", + "    此处,我们使用`SST2`来训练`bert`,实现文本分类,其他任务描述见下图" + ] + }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ - "GLUE_TASKS = [\"cola\", \"mnli\", \"mnli-mm\", \"mrpc\", \"qnli\", \"qqp\", \"rte\", \"sst2\", \"stsb\", \"wnli\"]\n", + "GLUE_TASKS = ['cola', 'mnli', 'mrpc', 'qnli', 'qqp', 'rte', 'sst2', 'stsb', 'wnli']\n", + "\n", + "task = 'sst2'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "\n", + "**`SST`**,**全称`Stanford Sentiment Treebank`**,**斯坦福情感树库**,**单句情感分类**数据集\n", + "\n", + "  包含电影评论语句和对应的情感极性,1 对应`positive` 正面情感,0 对应`negative` 负面情感\n", + "\n", + "  数据集包括三部分:训练集 67350 条,开发集 873 条,测试集 1821 条,更多参考[下载链接](https://gluebenchmark.com/tasks)\n", "\n", - "task = \"sst2\"\n", - "model_checkpoint = \"distilbert-base-uncased\"" + "对应到代码上,此处使用`datasets`模块中的`load_dataset`函数,指定`SST2`数据集,自动加载\n", + "\n", + "  首次下载后会保存至`~/.cache/huggingface/modules/datasets_modules/datasets/glue/`目录下" ] }, { @@ -84,7 +141,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "253d79d7a67e4dc88338448b5bcb3fb9", + "model_id": "adc9449171454f658285f220b70126e1", "version_major": 2, "version_minor": 0 }, @@ -97,9 +154,16 @@ } ], "source": [ - "from datasets import load_dataset, load_metric\n", + "from datasets import load_dataset\n", "\n", - "dataset = load_dataset(\"glue\", \"mnli\" if task == \"mnli-mm\" else task)" + "dataset = load_dataset('glue', task)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "  加载之后,根据`GLUE`中`SST2`数据集的格式,尝试打印部分数据,检查加载结果" ] }, { @@ -111,62 +175,89 @@ "name": "stdout", "output_type": "stream", "text": [ - "{'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n" + "Sentence: hide new secretions from the parental units \n" ] } ], "source": [ - "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n", + "task_to_keys = {\n", + " 'cola': ('sentence', None),\n", + " 'mnli': ('premise', 'hypothesis'),\n", + " 'mnli': ('premise', 'hypothesis'),\n", + " 'mrpc': ('sentence1', 'sentence2'),\n", + " 'qnli': ('question', 'sentence'),\n", + " 'qqp': ('question1', 'question2'),\n", + " 'rte': ('sentence1', 'sentence2'),\n", + " 'sst2': ('sentence', None),\n", + " 'stsb': ('sentence1', 'sentence2'),\n", + " 'wnli': ('sentence1', 'sentence2'),\n", + "}\n", "\n", - "print(tokenizer(\"Hello, this one sentence!\", \"And this sentence goes with it.\"))" + "sentence1_key, sentence2_key = task_to_keys[task]\n", + "\n", + "if sentence2_key is None:\n", + " print(f\"Sentence: {dataset['train'][0][sentence1_key]}\")\n", + "else:\n", + " print(f\"Sentence 1: {dataset['train'][0][sentence1_key]}\")\n", + " print(f\"Sentence 2: 
{dataset['train'][0][sentence2_key]}\")" ] }, { - "cell_type": "code", - "execution_count": 5, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "task_to_keys = {\n", - " \"cola\": (\"sentence\", None),\n", - " \"mnli\": (\"premise\", \"hypothesis\"),\n", - " \"mnli-mm\": (\"premise\", \"hypothesis\"),\n", - " \"mrpc\": (\"sentence1\", \"sentence2\"),\n", - " \"qnli\": (\"question\", \"sentence\"),\n", - " \"qqp\": (\"question1\", \"question2\"),\n", - " \"rte\": (\"sentence1\", \"sentence2\"),\n", - " \"sst2\": (\"sentence\", None),\n", - " \"stsb\": (\"sentence1\", \"sentence2\"),\n", - " \"wnli\": (\"sentence1\", \"sentence2\"),\n", - "}\n", + "### 2. 准备工作:加载 tokenizer、预处理 dataset、dataloader 使用\n", + "\n", + "  接下来进入模型训练的准备工作,分别需要使用`tokenizer`模块对数据集进行分词与标注\n", + "\n", + "    定义`SeqClsDataset`对应`dataloader`模块用来实现数据集在训练/测试时的加载\n", + "\n", + "此处的`tokenizer`和`SequenceClassificationModel`都是基于**`distilbert-base-uncased`模型**\n", "\n", - "sentence1_key, sentence2_key = task_to_keys[task]" + "  即使用较小的、不区分大小写的数据集,**对`bert-base`进行知识蒸馏后的版本**,结构上\n", + "\n", + "  模型包含1个编码层、6个自注意力层,详解见本篇末尾,更多细节请参考[DistilBert论文](https://arxiv.org/pdf/1910.01108.pdf)\n", + "\n", + "首先,通过从`transformers`库中导入`AutoTokenizer`模块,使用`from_pretrained`函数初始化\n", + "\n", + "  此处的`use_fast`表示是否使用`tokenizer`的快速版本;尝试序列化示例数据,检查加载结果\n", + "\n", + "  需要注意的是,处理后返回的两个键值,`'input_ids'`表示原始文本对应的词素编号序列\n", + "\n", + "    `'attention_mask'`表示自注意力运算时的掩模(标上`0`的部分对应`padding`的内容" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Sentence: hide new secretions from the parental units \n" + "{'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n" ] } ], "source": [ - "if sentence2_key is None:\n", - " print(f\"Sentence: {dataset['train'][0][sentence1_key]}\")\n", - "else:\n", - " print(f\"Sentence 1: {dataset['train'][0][sentence1_key]}\")\n", - " print(f\"Sentence 2: {dataset['train'][0][sentence2_key]}\")" + "model_checkpoint = 'distilbert-base-uncased'\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n", + "\n", + "print(tokenizer(\"Hello, this one sentence!\", \"And this sentence goes with it.\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "接着,定义预处理函数,**通过`dataset.map`方法**,**将数据集中的文本**,**替换为词素编号序列**" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -189,66 +280,27 @@ ] }, { - "cell_type": "code", - "execution_count": 8, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "class ClassModel(nn.Module):\n", - " def __init__(self, num_labels, model_checkpoint):\n", - " nn.Module.__init__(self)\n", - " self.num_labels = num_labels\n", - " self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n", - " num_labels=num_labels)\n", - " self.loss_fn = nn.CrossEntropyLoss()\n", - "\n", - " def forward(self, input_ids, attention_mask):\n", - " return self.back_bone(input_ids, attention_mask)\n", + "然后,通过继承`torch`中的`Dataset`类,定义`SeqClsDataset`类,需要注意的是\n", "\n", - " def train_step(self, input_ids, attention_mask, labels):\n", - " pred = self(input_ids, attention_mask).logits\n", - " return {\"loss\": self.loss_fn(pred, labels)}\n", - "\n", - " def evaluate_step(self, input_ids, attention_mask, labels):\n", - " pred = 
self(input_ids, attention_mask).logits\n", - " pred = torch.max(pred, dim=-1)[1]\n", - " return {\"pred\": pred, \"target\": labels}" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight']\n", - "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", - "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", - "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'classifier.bias', 'pre_classifier.bias']\n", - "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" - ] - } - ], - "source": [ - "num_labels = 3 if task.startswith(\"mnli\") else 1 if task == \"stsb\" else 2\n", + "  其中,**`__getitem__`函数各返回值引用的键值**,**必须和原始数据集中的属性对应**\n", "\n", - "model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n", + "  例如,`'label'`是`SST2`数据集中原有的内容(包括`'sentence'`和`'label'`\n", "\n", - "optimizers = AdamW(params=model.parameters(), lr=5e-5)" + "    `'input_ids'`和`'attention_mask'`则是`tokenizer`处理后添加的字段" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ - "class TestDistilBertDataset(Dataset):\n", + "class SeqClsDataset(Dataset):\n", " def __init__(self, dataset):\n", - " super(TestDistilBertDataset, self).__init__()\n", + " Dataset.__init__(self)\n", " self.dataset = dataset\n", "\n", " def __len__(self):\n", @@ -256,16 +308,27 @@ "\n", " def __getitem__(self, item):\n", " item = self.dataset[item]\n", - " return item[\"input_ids\"], item[\"attention_mask\"], [item[\"label\"]] " + " return item['input_ids'], item['attention_mask'], [item['label']] " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "再然后,**定义校对函数`collate_fn`对齐同个`batch`内的每笔数据**,需要注意的是该函数的\n", + "\n", + "  **返回值必须是字典**,**键值必须同待训练模型的`train_step`和`evaluate_step`函数的参数**\n", + "\n", + "  **相对应**;这也就是在`tutorial-0`中便被强调的,`fastNLP v0.8`的第一条**参数匹配**机制" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ - "def test_bert_collate_fn(batch):\n", + "def collate_fn(batch):\n", " input_ids, atten_mask, labels = [], [], []\n", " max_length = [0] * 3\n", " for each_item in batch:\n", @@ -280,35 +343,136 @@ " each = (input_ids, atten_mask, labels)[i]\n", " for item in each:\n", " item.extend([0] * (max_length[i] - len(item)))\n", - " return {\"input_ids\": torch.cat([torch.tensor([item]) for item in input_ids], dim=0),\n", - " \"attention_mask\": torch.cat([torch.tensor([item]) for item in atten_mask], dim=0),\n", - " \"labels\": torch.cat([torch.tensor(item) for item in labels], dim=0)}" + " return 
{'input_ids': torch.cat([torch.tensor([item]) for item in input_ids], dim=0),\n", + " 'attention_mask': torch.cat([torch.tensor([item]) for item in atten_mask], dim=0),\n", + " 'labels': torch.cat([torch.tensor(item) for item in labels], dim=0)}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "最后,分别对`tokenizer`处理过的训练集数据、验证集数据,进行预处理和批量划分" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ - "dataset_train = TestDistilBertDataset(encoded_dataset[\"train\"])\n", + "dataset_train = SeqClsDataset(encoded_dataset['train'])\n", "dataloader_train = DataLoader(dataset=dataset_train, \n", - " batch_size=32, shuffle=True, collate_fn=test_bert_collate_fn)\n", - "dataset_valid = TestDistilBertDataset(encoded_dataset[\"validation\"])\n", + " batch_size=32, shuffle=True, collate_fn=collate_fn)\n", + "dataset_valid = SeqClsDataset(encoded_dataset['validation'])\n", "dataloader_valid = DataLoader(dataset=dataset_valid, \n", - " batch_size=32, shuffle=False, collate_fn=test_bert_collate_fn)" + " batch_size=32, shuffle=False, collate_fn=collate_fn)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. 模型训练:加载 distilbert-base、fastNLP 参数匹配、fine-tuning\n", + "\n", + "  最后就是模型训练的,分别需要使用`distilbert-base-uncased`搭建分类模型\n", + "\n", + "    初始化优化器`optimizer`、训练模块`trainer`,通过`run`函数完成训练\n", + "\n", + "此处使用的`nn.Module`模块搭建模型,与`tokenizer`类似,通过从`transformers`库中\n", + "\n", + "  导入`AutoModelForSequenceClassification`模块,基于`distilbert-base-uncased`模型初始\n", + "\n", + "需要注意的是**`AutoModelForSequenceClassification`模块的输入参数和输出结构**\n", + "\n", + "  一方面,可以**通过输入标签值`labels`**,**使用模块内的损失函数计算损失`loss`**\n", + "\n", + "    并且可以选择输入是词素编号序列`input_ids`,还是词素嵌入序列`inputs_embeds`\n", + "\n", + "  另方面,该模块不会直接输出预测结果,而是会**输出各预测分类上的几率`logits`**\n", + "\n", + "    基于上述描述,此处完成了中`train_step`和`evaluate_step`函数的定义\n", + "\n", + "    同样需要注意,函数的返回值体现了`fastNLP v0.8`的第二条**参数匹配**机制" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "class SeqClsModel(nn.Module):\n", + " def __init__(self, num_labels, model_checkpoint):\n", + " nn.Module.__init__(self)\n", + " self.num_labels = num_labels\n", + " self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n", + " num_labels=num_labels)\n", + "\n", + " def forward(self, input_ids, attention_mask, labels=None):\n", + " output = self.back_bone(input_ids=input_ids, \n", + " attention_mask=attention_mask, labels=labels)\n", + " return output\n", + "\n", + " def train_step(self, input_ids, attention_mask, labels):\n", + " loss = self(input_ids, attention_mask, labels).loss\n", + " return {'loss': loss}\n", + "\n", + " def evaluate_step(self, input_ids, attention_mask, labels):\n", + " pred = self(input_ids, attention_mask, labels).logits\n", + " pred = torch.max(pred, dim=-1)[1]\n", + " return {'pred': pred, 'target': labels}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "接着,通过确定分类数量初始化模型实例,同时调用`torch.optim.AdamW`模块初始化优化器" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_projector.bias', 
'vocab_transform.weight']\n", + "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'classifier.weight', 'classifier.bias', 'pre_classifier.weight']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + } + ], + "source": [ + "num_labels = 3 if task == 'mnli' else 1 if task == 'stsb' else 2\n", + "\n", + "model = SeqClsModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=5e-5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "然后,使用之前完成的`dataloader_train`和`dataloader_valid`,定义训练模块`trainer`" + ] + }, + { + "cell_type": "code", + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "trainer = Trainer(\n", " model=model,\n", " driver='torch',\n", - " device='cuda',\n", + " device=1, # 'cuda'\n", " n_epochs=10,\n", " optimizers=optimizers,\n", " train_dataloader=dataloader_train,\n", @@ -318,42 +482,35 @@ ] }, { - "cell_type": "code", - "execution_count": 14, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "# help(model.back_bone.forward)" + "最后,使用`trainer.run`方法,训练模型,`n_epochs`参数中已经指定需要迭代`10`轮\n", + "\n", + "  `num_eval_batch_per_dl`参数则指定每次只对验证集中的`10`个`batch`进行评估" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
[21:00:11] INFO     Running evaluator sanity check for 2 batches.              trainer.py:592\n",
-       "\n"
+       "\n"
       ],
-      "text/plain": [
-       "\u001b[2;36m[21:00:11]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO    \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches.              \u001b]8;id=22992;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=669026;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n"
-      ]
+      "text/plain": []
      },
      "metadata": {},
      "output_type": "display_data"
     },
     {
      "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Output()"
-      ]
+      "text/html": [
+       "\n"
+      ],
+      "text/plain": []
      },
      "metadata": {},
      "output_type": "display_data"
@@ -370,16 +527,23 @@
      },
      "metadata": {},
      "output_type": "display_data"
-    },
+    }
+   ],
+   "source": [
+    "trainer.run(num_eval_batch_per_dl=10)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [
     {
      "data": {
       "text/html": [
-       "----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
-       "\n"
+       "\n"
       ],
-      "text/plain": [
-       "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
-      ]
+      "text/plain": []
      },
      "metadata": {},
      "output_type": "display_data"
@@ -387,473 +551,155 @@
     {
      "data": {
       "text/html": [
-       "{\n",
-       "  \"acc#acc\": 0.871875,\n",
-       "  \"total#acc\": 320.0,\n",
-       "  \"correct#acc\": 279.0\n",
-       "}\n",
-       "\n"
+       "\n"
       ],
-      "text/plain": [
-       "\u001b[1m{\u001b[0m\n",
-       "  \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n",
-       "  \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n",
-       "  \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n",
-       "\u001b[1m}\u001b[0m\n"
-      ]
+      "text/plain": []
      },
      "metadata": {},
      "output_type": "display_data"
     },
     {
      "data": {
-      "text/html": [
-       "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.878125,\n",
-       "  \"total#acc\": 320.0,\n",
-       "  \"correct#acc\": 281.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.871875,\n",
-       "  \"total#acc\": 320.0,\n",
-       "  \"correct#acc\": 279.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.903125,\n",
-       "  \"total#acc\": 320.0,\n",
-       "  \"correct#acc\": 289.0\n",
-       "}\n",
-       "
\n" - ], "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.903125\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m289.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" + "{'acc#acc': 0.87156, 'total#acc': 872.0, 'correct#acc': 760.0}" ] }, + "execution_count": 14, "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.871875,\n",
-       "  \"total#acc\": 320.0,\n",
-       "  \"correct#acc\": 279.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.890625,\n",
-       "  \"total#acc\": 320.0,\n",
-       "  \"correct#acc\": 285.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.890625\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m285.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.875,\n",
-       "  \"total#acc\": 320.0,\n",
-       "  \"correct#acc\": 280.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m280.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.8875,\n",
-       "  \"total#acc\": 320.0,\n",
-       "  \"correct#acc\": 284.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.8875,\n",
-       "  \"total#acc\": 320.0,\n",
-       "  \"correct#acc\": 284.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.890625,\n",
-       "  \"total#acc\": 320.0,\n",
-       "  \"correct#acc\": 285.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.890625\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m285.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" + "output_type": "execute_result" } ], "source": [ - "trainer.run(num_eval_batch_per_dl=10)" + "trainer.evaluator.run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 附:`DistilBertForSequenceClassification`模块结构\n", + "\n", + "```\n", + "\n", + "```" ] }, { diff --git a/tutorials/fastnlp_tutorial_e2.ipynb b/tutorials/fastnlp_tutorial_e2.ipynb index 1d7746be..93143090 100644 --- a/tutorials/fastnlp_tutorial_e2.ipynb +++ b/tutorials/fastnlp_tutorial_e2.ipynb @@ -4,7 +4,52 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# E2. 使用 continuous prompt 完成 SST2 分类" + "# E2. 使用 Bert + prompt 完成 SST2 分类\n", + "\n", + "  1   基础介绍:`prompt-based model`简介、与`fastNLP`的结合\n", + "\n", + "  2   准备工作:`P-Tuning v2`原理概述、`P-Tuning v2`模型搭建\n", + "\n", + "  3   模型训练:加载`tokenizer`、预处理`dataset`、模型训练与分析" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. 基础介绍:prompt-based model 简介、与 fastNLP 的结合\n", + "\n", + "  本示例使用`GLUE`评估基准中的`SST2`数据集,通过`prompt-based tuning`方式\n", + "\n", + "    微调`bert-base-uncased`模型,实现文本情感的二分类,在此之前本示例\n", + "\n", + "    将首先简单介绍提示学习模型的研究,以及与`fastNLP v0.8`结合的优势\n", + "\n", + "**`prompt`**,**提示词、提词器**,最早出自**`PET`**,\n", + "\n", + "  \n", + "\n", + "**`prompt-based tuning`**,**基于提示的微调**,描述\n", + "\n", + "  **`prompt-based model`**,**基于提示的模型**\n", + "\n", + "**`prompt-based model`**,**基于提示的模型**,举例\n", + "\n", + "  案例一:**`P-Tuning v1`**\n", + "\n", + "  案例二:**`PromptTuning`**\n", + "\n", + "  案例三:**`PrefixTuning`**\n", + "\n", + "  案例四:**`SoftPrompt`**\n", + "\n", + "使用`fastNLP v0.8`实现`prompt-based model`的优势\n", + "\n", + "  \n", + "\n", + "  本示例仍使用了`tutorial-E1`的`SST2`数据集,将`bert-base-uncased`作为基础模型\n", + "\n", + "    在后续实现中,意图通过将连续的`prompt`与`model`拼接,解决`SST2`二分类任务" ] }, { @@ -35,11 +80,10 @@ ], "source": [ "import torch\n", + "import torch.nn as nn\n", "from torch.optim import AdamW\n", "from torch.utils.data import DataLoader, Dataset\n", "\n", - "import torch.nn as nn\n", - "\n", "import transformers\n", "from transformers import AutoTokenizer\n", "from transformers import AutoModelForSequenceClassification\n", @@ -51,19 +95,31 @@ "from fastNLP import Trainer\n", "from fastNLP.core.metrics import Accuracy\n", "\n", - "print(transformers.__version__)" + "print(transformers.__version__)\n", + "\n", + "task = 'sst2'\n", + "model_checkpoint = 'bert-base-uncased'" ] }, { - "cell_type": "code", - "execution_count": 2, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "GLUE_TASKS = [\"cola\", \"mnli\", \"mnli-mm\", \"mrpc\", \"qnli\", \"qqp\", \"rte\", \"sst2\", \"stsb\", \"wnli\"]\n", + "### 2. 
准备工作:P-Tuning v2 原理概述、P-Tuning v2 模型搭建\n", "\n", - "task = \"sst2\"\n", - "model_checkpoint = \"distilbert-base-uncased\"" + "  本示例使用`P-Tuning v2`作为`prompt-based tuning`与`fastNLP v0.8`结合的案例\n", + "\n", + "    以下首先简述`P-Tuning v2`的论文原理,并由此引出`fastNLP v0.8`的代码实践\n", + "\n", + "`P-Tuning v2`出自论文 [Prompt Tuning Can Be Comparable to Fine-tuning Universally Across Scales and Tasks](https://arxiv.org/pdf/2110.07602.pdf)\n", + "\n", + "  其主要贡献在于,在`PrefixTuning`等深度提示学习基础上,提升了其在分类标注等`NLU`任务的表现\n", + "\n", + "    并使之在中等规模模型,主要是参数量在`100M-1B`区间的模型上,获得与全参数微调相同的效果\n", + "\n", + "  其结构如图所示,\n", + "\n", + "" ] }, { @@ -72,7 +128,7 @@ "metadata": {}, "outputs": [], "source": [ - "class ClassModel(nn.Module):\n", + "class SeqClsModel(nn.Module):\n", " def __init__(self, model_checkpoint, num_labels, pre_seq_len):\n", " nn.Module.__init__(self)\n", " self.num_labels = num_labels\n", @@ -92,7 +148,7 @@ " prompts = self.prefix_encoder(prefix_tokens)\n", " return prompts\n", "\n", - " def forward(self, input_ids, attention_mask, labels):\n", + " def forward(self, input_ids, attention_mask, labels=None):\n", " \n", " batch_size = input_ids.shape[0]\n", " raw_embedding = self.embeddings(input_ids)\n", @@ -107,39 +163,64 @@ " return outputs\n", "\n", " def train_step(self, input_ids, attention_mask, labels):\n", - " return {\"loss\": self(input_ids, attention_mask, labels).loss}\n", + " loss = self(input_ids, attention_mask, labels).loss\n", + " return {'loss': loss}\n", "\n", " def evaluate_step(self, input_ids, attention_mask, labels):\n", " pred = self(input_ids, attention_mask, labels).logits\n", " pred = torch.max(pred, dim=-1)[1]\n", - " return {\"pred\": pred, \"target\": labels}" + " return {'pred': pred, 'target': labels}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "接着,通过确定分类数量初始化模型实例,同时调用`torch.optim.AdamW`模块初始化优化器\n", + "\n", + "  根据`P-Tuning v2`论文:*Generally, simple classification tasks prefer shorter prompts (less than 20)*\n", + "\n", + "  此处`pre_seq_len`参数设定为`20`,学习率相应做出调整,其他内容和`tutorial-E1`中的内容一致" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight']\n", - "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", - "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", - "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight', 'classifier.bias']\n", + "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias']\n", + "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ - "num_labels = 3 if task.startswith(\"mnli\") else 1 if task == \"stsb\" else 2\n", + "model = SeqClsModel(model_checkpoint=model_checkpoint, num_labels=2, pre_seq_len=20)\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=1e-2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. 
模型训练:加载 tokenizer、预处理 dataset、模型训练与分析\n", "\n", - "model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint, pre_seq_len=16)\n", + "  本示例沿用`tutorial-E1`中的数据集,即使用`GLUE`评估基准中的`SST2`数据集\n", "\n", - "# Generally, simple classification tasks prefer shorter prompts (less than 20)\n", + "    以`bert-base-uncased`模型作为基准,基于`P-Tuning v2`方式微调\n", "\n", - "optimizers = AdamW(params=model.parameters(), lr=5e-3)" + "    数据集加载相关代码流程见下,内容和`tutorial-E1`中的内容基本一致\n", + "\n", + "首先,使用`datasets.load_dataset`加载数据集,使用`transformers.AutoTokenizer`\n", + "\n", + "  构建`tokenizer`实例,通过`dataset.map`使用`tokenizer`将文本替换为词素序号序列" ] }, { @@ -153,14 +234,13 @@ "name": "stderr", "output_type": "stream", "text": [ - "Using the latest cached version of the module from /remote-home/xrliu/.cache/huggingface/modules/datasets_modules/datasets/glue/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad (last modified on Thu May 26 15:30:15 2022) since it couldn't be found locally at glue., or remotely on the Hugging Face Hub.\n", "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "1b73650d43f245ac8a5501dc91c6fe8c", + "model_id": "b72eeebd34354a88a99b2e07ec9a86df", "version_major": 2, "version_minor": 0 }, @@ -175,7 +255,7 @@ "source": [ "from datasets import load_dataset, load_metric\n", "\n", - "dataset = load_dataset(\"glue\", \"mnli\" if task == \"mnli-mm\" else task)\n", + "dataset = load_dataset('glue', task)\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)" ] @@ -189,14 +269,14 @@ "name": "stderr", "output_type": "stream", "text": [ - "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-294e481a713c5754.arrow\n", - "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-ed9d9258aaf0fb54.arrow\n" + "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-18ec0e709f05e61e.arrow\n", + "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-e2f02ee7442ad73e.arrow\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "0be84915c90f460896b8e67299e09df4", + "model_id": "d15505d825b34f649b719f1ff0d56114", "version_major": 2, "version_minor": 0 }, @@ -215,15 +295,26 @@ "encoded_dataset = dataset.map(preprocess_function, batched=True)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "然后,定义`SeqClsDataset`类、定义校对函数`collate_fn`,这里沿用`tutorial-E1`中的内容\n", + "\n", + "  同样需要注意/强调的是,**`__getitem__`函数的返回值必须和原始数据集中的属性对应**\n", + "\n", + "  **`collate_fn`函数的返回值必须和`train_step`和`evaluate_step`函数的参数匹配**" + ] + }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ - "class TestDistilBertDataset(Dataset):\n", + "class SeqClsDataset(Dataset):\n", " def __init__(self, dataset):\n", - " super(TestDistilBertDataset, self).__init__()\n", + " Dataset.__init__(self)\n", " self.dataset = dataset\n", "\n", " def __len__(self):\n", @@ -231,16 +322,9 @@ "\n", " def __getitem__(self, 
item):\n", " item = self.dataset[item]\n", - " return item[\"input_ids\"], item[\"attention_mask\"], [item[\"label\"]] " - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "def test_bert_collate_fn(batch):\n", + " return item['input_ids'], item['attention_mask'], [item['label']] \n", + "\n", + "def collate_fn(batch):\n", " input_ids, atten_mask, labels = [], [], []\n", " max_length = [0] * 3\n", " for each_item in batch:\n", @@ -255,9 +339,16 @@ " each = (input_ids, atten_mask, labels)[i]\n", " for item in each:\n", " item.extend([0] * (max_length[i] - len(item)))\n", - " return {\"input_ids\": torch.cat([torch.tensor([item]) for item in input_ids], dim=0),\n", - " \"attention_mask\": torch.cat([torch.tensor([item]) for item in atten_mask], dim=0),\n", - " \"labels\": torch.cat([torch.tensor(item) for item in labels], dim=0)}" + " return {'input_ids': torch.cat([torch.tensor([item]) for item in input_ids], dim=0),\n", + " 'attention_mask': torch.cat([torch.tensor([item]) for item in atten_mask], dim=0),\n", + " 'labels': torch.cat([torch.tensor(item) for item in labels], dim=0)}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "再然后,分别对`tokenizer`处理过的训练集数据、验证集数据,进行预处理和批量划分" ] }, { @@ -266,25 +357,43 @@ "metadata": {}, "outputs": [], "source": [ - "dataset_train = TestDistilBertDataset(encoded_dataset[\"train\"])\n", + "dataset_train = SeqClsDataset(encoded_dataset['train'])\n", "dataloader_train = DataLoader(dataset=dataset_train, \n", - " batch_size=32, shuffle=True, collate_fn=test_bert_collate_fn)\n", - "dataset_valid = TestDistilBertDataset(encoded_dataset[\"validation\"])\n", + " batch_size=32, shuffle=True, collate_fn=collate_fn)\n", + "dataset_valid = SeqClsDataset(encoded_dataset['validation'])\n", "dataloader_valid = DataLoader(dataset=dataset_valid, \n", - " batch_size=32, shuffle=False, collate_fn=test_bert_collate_fn)" + " batch_size=32, shuffle=False, collate_fn=collate_fn)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " " ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" + ] + } + ], "source": [ "trainer = Trainer(\n", " model=model,\n", " driver='torch',\n", - " device='cuda',\n", - " n_epochs=10,\n", + " device=[0, 1],\n", + " n_epochs=20,\n", " optimizers=optimizers,\n", " train_dataloader=dataloader_train,\n", " evaluate_dataloaders=dataloader_valid,\n", @@ -292,85 +401,34 @@ ")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " " + ] + }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "trainer.run(num_eval_batch_per_dl=10)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " " + ] + }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/plain": [
-       "{'acc#acc': 0.644495, 'total#acc': 872.0, 'correct#acc': 562.0}"
-      ]
-     },
-     "execution_count": 20,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
     "trainer.evaluator.run()"
    ]
diff --git a/tutorials/figures/E1-fig-glue-benchmark.png b/tutorials/figures/E1-fig-glue-benchmark.png
new file mode 100644
index 00000000..515db700
Binary files /dev/null and b/tutorials/figures/E1-fig-glue-benchmark.png differ
diff --git a/tutorials/figures/E2-fig-p-tuning-v2-model.png b/tutorials/figures/E2-fig-p-tuning-v2-model.png
new file mode 100644
index 00000000..b5a9c1b8
Binary files /dev/null and b/tutorials/figures/E2-fig-p-tuning-v2-model.png differ