diff --git a/tutorials/fastnlp_tutorial_1.ipynb b/tutorials/fastnlp_tutorial_1.ipynb index 09e8821d..db77e6c3 100644 --- a/tutorials/fastnlp_tutorial_1.ipynb +++ b/tutorials/fastnlp_tutorial_1.ipynb @@ -1325,7 +1325,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.4" + "version": "3.7.13" } }, "nbformat": 4, diff --git a/tutorials/fastnlp_tutorial_3.ipynb b/tutorials/fastnlp_tutorial_3.ipynb index 8c3c935e..353e4645 100644 --- a/tutorials/fastnlp_tutorial_3.ipynb +++ b/tutorials/fastnlp_tutorial_3.ipynb @@ -288,7 +288,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.4" + "version": "3.7.13" }, "pycharm": { "stem_cell": { diff --git a/tutorials/fastnlp_tutorial_e1.ipynb b/tutorials/fastnlp_tutorial_e1.ipynb index 628dd7ae..6ec04cb4 100644 --- a/tutorials/fastnlp_tutorial_e1.ipynb +++ b/tutorials/fastnlp_tutorial_e1.ipynb @@ -4,7 +4,22 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# E1. 使用 DistilBert 完成 SST2 分类" + " 从这篇开始,我们将开启**`fastNLP v0.8 tutorial`的`example`系列**,在接下来的\n", + "\n", + " 每篇`tutorial`里,我们将会介绍`fastNLP v0.8`在一些自然语言处理任务上的应用" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# E1. 使用 Bert + fine-tuning 完成 SST2 分类\n", + "\n", + " 1 基础介绍:`GLUE`通用语言理解评估、`SST2`文本情感二分类数据集 \n", + "\n", + " 2 准备工作:加载`tokenizer`、预处理`dataset`、`dataloader`使用\n", + "\n", + " 3 模型训练:加载`distilbert-base`、`fastNLP`参数匹配、`fine-tuning`" ] }, { @@ -48,22 +63,64 @@ "\n", "import fastNLP\n", "from fastNLP import Trainer\n", - "from fastNLP.core.utils.utils import dataclass_to_dict\n", "from fastNLP.core.metrics import Accuracy\n", "\n", "print(transformers.__version__)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. 基础介绍:GLUE 通用语言理解评估、SST2 文本情感二分类数据集\n", + "\n", + " 本示例使用`GLUE`评估基准中的`SST2`数据集,通过`fine-tuning`方式\n", + "\n", + " 调整`distilbert-bert`分类模型,以下首先简单介绍下`GLUE`和`SST2`\n", + "\n", + "**`GLUE`**,**全称`General Language Understanding Evaluation`**,**通用语言理解评估**,\n", + "\n", + " 包含9个数据集,各语料的语言均为英语,涉及多个自然语言理解`NLU`任务,包括\n", + "\n", + " **`CoLA`**,文本分类任务,预测单句语法正误分类;**`SST2`**,文本分类任务,预测单句情感二分类\n", + "\n", + " **`MRPC`**,句对分类任务,预测句对语义一致性;**`STSB`**,相似度打分任务,预测句对语义相似度回归\n", + "\n", + " **`QQP`**,句对分类任务,预测问题对语义一致性;**`MNLI`**,文本推理任务,预测句对蕴含/矛盾/中立预测\n", + "\n", + " **`QNLI`/`RTE`/`WNLI`**,文本推理,预测是否蕴含二分类(其中,`QNLI`从`SQuAD`转化而来\n", + "\n", + " 诸如`BERT`、`T5`等经典模型都会在此基准上验证效果,更多参考[GLUE论文](https://arxiv.org/pdf/1804.07461v3.pdf)\n", + "\n", + " 此处,我们使用`SST2`来训练`bert`,实现文本分类,其他任务描述见下图" + ] + }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ - "GLUE_TASKS = [\"cola\", \"mnli\", \"mnli-mm\", \"mrpc\", \"qnli\", \"qqp\", \"rte\", \"sst2\", \"stsb\", \"wnli\"]\n", + "GLUE_TASKS = ['cola', 'mnli', 'mrpc', 'qnli', 'qqp', 'rte', 'sst2', 'stsb', 'wnli']\n", + "\n", + "task = 'sst2'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "\n", + "**`SST`**,**全称`Stanford Sentiment Treebank`**,**斯坦福情感树库**,**单句情感分类**数据集\n", + "\n", + " 包含电影评论语句和对应的情感极性,1 对应`positive` 正面情感,0 对应`negative` 负面情感\n", + "\n", + " 数据集包括三部分:训练集 67350 条,开发集 873 条,测试集 1821 条,更多参考[下载链接](https://gluebenchmark.com/tasks)\n", "\n", - "task = \"sst2\"\n", - "model_checkpoint = \"distilbert-base-uncased\"" + "对应到代码上,此处使用`datasets`模块中的`load_dataset`函数,指定`SST2`数据集,自动加载\n", + "\n", + " 首次下载后会保存至`~/.cache/huggingface/modules/datasets_modules/datasets/glue/`目录下" ] }, { @@ -84,7 +141,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "253d79d7a67e4dc88338448b5bcb3fb9", + "model_id": "adc9449171454f658285f220b70126e1", "version_major": 2, "version_minor": 0 }, @@ -97,9 +154,16 @@ } ], "source": [ - "from datasets import load_dataset, load_metric\n", + "from datasets import load_dataset\n", "\n", - "dataset = load_dataset(\"glue\", \"mnli\" if task == \"mnli-mm\" else task)" + "dataset = load_dataset('glue', task)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " 加载之后,根据`GLUE`中`SST2`数据集的格式,尝试打印部分数据,检查加载结果" ] }, { @@ -111,62 +175,89 @@ "name": "stdout", "output_type": "stream", "text": [ - "{'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n" + "Sentence: hide new secretions from the parental units \n" ] } ], "source": [ - "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n", + "task_to_keys = {\n", + " 'cola': ('sentence', None),\n", + " 'mnli': ('premise', 'hypothesis'),\n", + " 'mnli': ('premise', 'hypothesis'),\n", + " 'mrpc': ('sentence1', 'sentence2'),\n", + " 'qnli': ('question', 'sentence'),\n", + " 'qqp': ('question1', 'question2'),\n", + " 'rte': ('sentence1', 'sentence2'),\n", + " 'sst2': ('sentence', None),\n", + " 'stsb': ('sentence1', 'sentence2'),\n", + " 'wnli': ('sentence1', 'sentence2'),\n", + "}\n", "\n", - "print(tokenizer(\"Hello, this one sentence!\", \"And this sentence goes with it.\"))" + "sentence1_key, sentence2_key = task_to_keys[task]\n", + "\n", + "if sentence2_key is None:\n", + " print(f\"Sentence: {dataset['train'][0][sentence1_key]}\")\n", + "else:\n", + " print(f\"Sentence 1: {dataset['train'][0][sentence1_key]}\")\n", + " print(f\"Sentence 2: {dataset['train'][0][sentence2_key]}\")" ] }, { - "cell_type": "code", - "execution_count": 5, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "task_to_keys = {\n", - " \"cola\": (\"sentence\", None),\n", - " \"mnli\": (\"premise\", \"hypothesis\"),\n", - " \"mnli-mm\": (\"premise\", \"hypothesis\"),\n", - " \"mrpc\": (\"sentence1\", \"sentence2\"),\n", - " \"qnli\": (\"question\", \"sentence\"),\n", - " \"qqp\": (\"question1\", \"question2\"),\n", - " \"rte\": (\"sentence1\", \"sentence2\"),\n", - " \"sst2\": (\"sentence\", None),\n", - " \"stsb\": (\"sentence1\", \"sentence2\"),\n", - " \"wnli\": (\"sentence1\", \"sentence2\"),\n", - "}\n", + "### 2. 准备工作:加载 tokenizer、预处理 dataset、dataloader 使用\n", + "\n", + " 接下来进入模型训练的准备工作,分别需要使用`tokenizer`模块对数据集进行分词与标注\n", + "\n", + " 定义`SeqClsDataset`对应`dataloader`模块用来实现数据集在训练/测试时的加载\n", + "\n", + "此处的`tokenizer`和`SequenceClassificationModel`都是基于**`distilbert-base-uncased`模型**\n", "\n", - "sentence1_key, sentence2_key = task_to_keys[task]" + " 即使用较小的、不区分大小写的数据集,**对`bert-base`进行知识蒸馏后的版本**,结构上\n", + "\n", + " 模型包含1个编码层、6个自注意力层,详解见本篇末尾,更多细节请参考[DistilBert论文](https://arxiv.org/pdf/1910.01108.pdf)\n", + "\n", + "首先,通过从`transformers`库中导入`AutoTokenizer`模块,使用`from_pretrained`函数初始化\n", + "\n", + " 此处的`use_fast`表示是否使用`tokenizer`的快速版本;尝试序列化示例数据,检查加载结果\n", + "\n", + " 需要注意的是,处理后返回的两个键值,`'input_ids'`表示原始文本对应的词素编号序列\n", + "\n", + " `'attention_mask'`表示自注意力运算时的掩模(标上`0`的部分对应`padding`的内容" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Sentence: hide new secretions from the parental units \n" + "{'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n" ] } ], "source": [ - "if sentence2_key is None:\n", - " print(f\"Sentence: {dataset['train'][0][sentence1_key]}\")\n", - "else:\n", - " print(f\"Sentence 1: {dataset['train'][0][sentence1_key]}\")\n", - " print(f\"Sentence 2: {dataset['train'][0][sentence2_key]}\")" + "model_checkpoint = 'distilbert-base-uncased'\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n", + "\n", + "print(tokenizer(\"Hello, this one sentence!\", \"And this sentence goes with it.\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "接着,定义预处理函数,**通过`dataset.map`方法**,**将数据集中的文本**,**替换为词素编号序列**" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -189,66 +280,27 @@ ] }, { - "cell_type": "code", - "execution_count": 8, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "class ClassModel(nn.Module):\n", - " def __init__(self, num_labels, model_checkpoint):\n", - " nn.Module.__init__(self)\n", - " self.num_labels = num_labels\n", - " self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n", - " num_labels=num_labels)\n", - " self.loss_fn = nn.CrossEntropyLoss()\n", - "\n", - " def forward(self, input_ids, attention_mask):\n", - " return self.back_bone(input_ids, attention_mask)\n", + "然后,通过继承`torch`中的`Dataset`类,定义`SeqClsDataset`类,需要注意的是\n", "\n", - " def train_step(self, input_ids, attention_mask, labels):\n", - " pred = self(input_ids, attention_mask).logits\n", - " return {\"loss\": self.loss_fn(pred, labels)}\n", - "\n", - " def evaluate_step(self, input_ids, attention_mask, labels):\n", - " pred = self(input_ids, attention_mask).logits\n", - " pred = torch.max(pred, dim=-1)[1]\n", - " return {\"pred\": pred, \"target\": labels}" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight']\n", - "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", - "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", - "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'classifier.bias', 'pre_classifier.bias']\n", - "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" - ] - } - ], - "source": [ - "num_labels = 3 if task.startswith(\"mnli\") else 1 if task == \"stsb\" else 2\n", + " 其中,**`__getitem__`函数各返回值引用的键值**,**必须和原始数据集中的属性对应**\n", "\n", - "model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n", + " 例如,`'label'`是`SST2`数据集中原有的内容(包括`'sentence'`和`'label'`\n", "\n", - "optimizers = AdamW(params=model.parameters(), lr=5e-5)" + " `'input_ids'`和`'attention_mask'`则是`tokenizer`处理后添加的字段" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ - "class TestDistilBertDataset(Dataset):\n", + "class SeqClsDataset(Dataset):\n", " def __init__(self, dataset):\n", - " super(TestDistilBertDataset, self).__init__()\n", + " Dataset.__init__(self)\n", " self.dataset = dataset\n", "\n", " def __len__(self):\n", @@ -256,16 +308,27 @@ "\n", " def __getitem__(self, item):\n", " item = self.dataset[item]\n", - " return item[\"input_ids\"], item[\"attention_mask\"], [item[\"label\"]] " + " return item['input_ids'], item['attention_mask'], [item['label']] " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "再然后,**定义校对函数`collate_fn`对齐同个`batch`内的每笔数据**,需要注意的是该函数的\n", + "\n", + " **返回值必须是字典**,**键值必须同待训练模型的`train_step`和`evaluate_step`函数的参数**\n", + "\n", + " **相对应**;这也就是在`tutorial-0`中便被强调的,`fastNLP v0.8`的第一条**参数匹配**机制" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ - "def test_bert_collate_fn(batch):\n", + "def collate_fn(batch):\n", " input_ids, atten_mask, labels = [], [], []\n", " max_length = [0] * 3\n", " for each_item in batch:\n", @@ -280,35 +343,136 @@ " each = (input_ids, atten_mask, labels)[i]\n", " for item in each:\n", " item.extend([0] * (max_length[i] - len(item)))\n", - " return {\"input_ids\": torch.cat([torch.tensor([item]) for item in input_ids], dim=0),\n", - " \"attention_mask\": torch.cat([torch.tensor([item]) for item in atten_mask], dim=0),\n", - " \"labels\": torch.cat([torch.tensor(item) for item in labels], dim=0)}" + " return {'input_ids': torch.cat([torch.tensor([item]) for item in input_ids], dim=0),\n", + " 'attention_mask': torch.cat([torch.tensor([item]) for item in atten_mask], dim=0),\n", + " 'labels': torch.cat([torch.tensor(item) for item in labels], dim=0)}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "最后,分别对`tokenizer`处理过的训练集数据、验证集数据,进行预处理和批量划分" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ - "dataset_train = TestDistilBertDataset(encoded_dataset[\"train\"])\n", + "dataset_train = SeqClsDataset(encoded_dataset['train'])\n", "dataloader_train = DataLoader(dataset=dataset_train, \n", - " batch_size=32, shuffle=True, collate_fn=test_bert_collate_fn)\n", - "dataset_valid = TestDistilBertDataset(encoded_dataset[\"validation\"])\n", + " batch_size=32, shuffle=True, collate_fn=collate_fn)\n", + "dataset_valid = SeqClsDataset(encoded_dataset['validation'])\n", "dataloader_valid = DataLoader(dataset=dataset_valid, \n", - " batch_size=32, shuffle=False, collate_fn=test_bert_collate_fn)" + " batch_size=32, shuffle=False, collate_fn=collate_fn)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. 模型训练:加载 distilbert-base、fastNLP 参数匹配、fine-tuning\n", + "\n", + " 最后就是模型训练的,分别需要使用`distilbert-base-uncased`搭建分类模型\n", + "\n", + " 初始化优化器`optimizer`、训练模块`trainer`,通过`run`函数完成训练\n", + "\n", + "此处使用的`nn.Module`模块搭建模型,与`tokenizer`类似,通过从`transformers`库中\n", + "\n", + " 导入`AutoModelForSequenceClassification`模块,基于`distilbert-base-uncased`模型初始\n", + "\n", + "需要注意的是**`AutoModelForSequenceClassification`模块的输入参数和输出结构**\n", + "\n", + " 一方面,可以**通过输入标签值`labels`**,**使用模块内的损失函数计算损失`loss`**\n", + "\n", + " 并且可以选择输入是词素编号序列`input_ids`,还是词素嵌入序列`inputs_embeds`\n", + "\n", + " 另方面,该模块不会直接输出预测结果,而是会**输出各预测分类上的几率`logits`**\n", + "\n", + " 基于上述描述,此处完成了中`train_step`和`evaluate_step`函数的定义\n", + "\n", + " 同样需要注意,函数的返回值体现了`fastNLP v0.8`的第二条**参数匹配**机制" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "class SeqClsModel(nn.Module):\n", + " def __init__(self, num_labels, model_checkpoint):\n", + " nn.Module.__init__(self)\n", + " self.num_labels = num_labels\n", + " self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n", + " num_labels=num_labels)\n", + "\n", + " def forward(self, input_ids, attention_mask, labels=None):\n", + " output = self.back_bone(input_ids=input_ids, \n", + " attention_mask=attention_mask, labels=labels)\n", + " return output\n", + "\n", + " def train_step(self, input_ids, attention_mask, labels):\n", + " loss = self(input_ids, attention_mask, labels).loss\n", + " return {'loss': loss}\n", + "\n", + " def evaluate_step(self, input_ids, attention_mask, labels):\n", + " pred = self(input_ids, attention_mask, labels).logits\n", + " pred = torch.max(pred, dim=-1)[1]\n", + " return {'pred': pred, 'target': labels}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "接着,通过确定分类数量初始化模型实例,同时调用`torch.optim.AdamW`模块初始化优化器" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_transform.weight']\n", + "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'classifier.weight', 'classifier.bias', 'pre_classifier.weight']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + } + ], + "source": [ + "num_labels = 3 if task == 'mnli' else 1 if task == 'stsb' else 2\n", + "\n", + "model = SeqClsModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=5e-5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "然后,使用之前完成的`dataloader_train`和`dataloader_valid`,定义训练模块`trainer`" + ] + }, + { + "cell_type": "code", + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "trainer = Trainer(\n", " model=model,\n", " driver='torch',\n", - " device='cuda',\n", + " device=1, # 'cuda'\n", " n_epochs=10,\n", " optimizers=optimizers,\n", " train_dataloader=dataloader_train,\n", @@ -318,42 +482,35 @@ ] }, { - "cell_type": "code", - "execution_count": 14, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "# help(model.back_bone.forward)" + "最后,使用`trainer.run`方法,训练模型,`n_epochs`参数中已经指定需要迭代`10`轮\n", + "\n", + " `num_eval_batch_per_dl`参数则指定每次只对验证集中的`10`个`batch`进行评估" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
[21:00:11] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", - "\n" + "\n" ], - "text/plain": [ - "\u001b[2;36m[21:00:11]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=22992;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=669026;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" - ] + "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] + "text/html": [ + "\n" + ], + "text/plain": [] }, "metadata": {}, "output_type": "display_data" @@ -370,16 +527,23 @@ }, "metadata": {}, "output_type": "display_data" - }, + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ { "data": { "text/html": [ - "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", - "\n" + "\n" ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] + "text/plain": [] }, "metadata": {}, "output_type": "display_data" @@ -387,473 +551,155 @@ { "data": { "text/html": [ - "
{\n", - " \"acc#acc\": 0.871875,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 279.0\n", - "}\n", - "\n" + "\n" ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] + "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.878125,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 281.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.871875,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 279.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.903125,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 289.0\n", - "}\n", - "\n" - ], "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.903125\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m289.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" + "{'acc#acc': 0.87156, 'total#acc': 872.0, 'correct#acc': 760.0}" ] }, + "execution_count": 14, "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.871875,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 279.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.890625,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 285.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.890625\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m285.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.875,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 280.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m280.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.8875,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 284.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.8875,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 284.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", - "\n" - ], - "text/plain": [ - "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n", - " \"acc#acc\": 0.890625,\n", - " \"total#acc\": 320.0,\n", - " \"correct#acc\": 285.0\n", - "}\n", - "\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.890625\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m285.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" + "output_type": "execute_result" } ], "source": [ - "trainer.run(num_eval_batch_per_dl=10)" + "trainer.evaluator.run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 附:`DistilBertForSequenceClassification`模块结构\n", + "\n", + "```\n", + "
\n", - "\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "trainer.run(num_eval_batch_per_dl=10)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " " + ] + }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "{'acc#acc': 0.644495, 'total#acc': 872.0, 'correct#acc': 562.0}" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "trainer.evaluator.run()" ] diff --git a/tutorials/figures/E1-fig-glue-benchmark.png b/tutorials/figures/E1-fig-glue-benchmark.png new file mode 100644 index 00000000..515db700 Binary files /dev/null and b/tutorials/figures/E1-fig-glue-benchmark.png differ diff --git a/tutorials/figures/E2-fig-p-tuning-v2-model.png b/tutorials/figures/E2-fig-p-tuning-v2-model.png new file mode 100644 index 00000000..b5a9c1b8 Binary files /dev/null and b/tutorials/figures/E2-fig-p-tuning-v2-model.png differ