From 8f211ef8ab6a84b8e734312d075628635e20d546 Mon Sep 17 00:00:00 2001 From: lxr-tech <1838593642@qq.com> Date: Thu, 26 May 2022 21:57:44 +0800 Subject: [PATCH] add example-12 lxr 220526 --- tutorials/fastnlp_tutorial_0.ipynb | 13 +- tutorials/fastnlp_tutorial_e1.ipynb | 888 ++++++++++++++++++++++++++++ tutorials/fastnlp_tutorial_e2.ipynb | 888 ++++++++++++++++++++++++++++ 3 files changed, 1784 insertions(+), 5 deletions(-) create mode 100644 tutorials/fastnlp_tutorial_e1.ipynb create mode 100644 tutorials/fastnlp_tutorial_e2.ipynb diff --git a/tutorials/fastnlp_tutorial_0.ipynb b/tutorials/fastnlp_tutorial_0.ipynb index 4e4ce55e..245eaf91 100644 --- a/tutorials/fastnlp_tutorial_0.ipynb +++ b/tutorials/fastnlp_tutorial_0.ipynb @@ -464,6 +464,9 @@ } ], "source": [ + "import sys\n", + "sys.path.append('..')\n", + "\n", "from fastNLP import Trainer\n", "\n", "trainer = Trainer(\n", @@ -613,11 +616,11 @@ { "data": { "text/html": [ - "
{'acc#acc': 0.41, 'total#acc': 100.0, 'correct#acc': 41.0}\n", + "{'acc#acc': 0.37, 'total#acc': 100.0, 'correct#acc': 37.0}\n", "\n" ], "text/plain": [ - "\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.41\u001b[0m, \u001b[32m'total#acc'\u001b[0m: \u001b[1;36m100.0\u001b[0m, \u001b[32m'correct#acc'\u001b[0m: \u001b[1;36m41.0\u001b[0m\u001b[1m}\u001b[0m\n" + "\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.37\u001b[0m, \u001b[32m'total#acc'\u001b[0m: \u001b[1;36m100.0\u001b[0m, \u001b[32m'correct#acc'\u001b[0m: \u001b[1;36m37.0\u001b[0m\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, @@ -626,7 +629,7 @@ { "data": { "text/plain": [ - "{'acc#acc': 0.41, 'total#acc': 100.0, 'correct#acc': 41.0}" + "{'acc#acc': 0.37, 'total#acc': 100.0, 'correct#acc': 37.0}" ] }, "execution_count": 9, @@ -756,7 +759,7 @@ { "data": { "text/plain": [ - "{'acc#acc': 0.46, 'total#acc': 100.0, 'correct#acc': 46.0}" + "{'acc#acc': 0.47, 'total#acc': 100.0, 'correct#acc': 47.0}" ] }, "execution_count": 12, @@ -793,7 +796,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.4" + "version": "3.7.13" }, "pycharm": { "stem_cell": { diff --git a/tutorials/fastnlp_tutorial_e1.ipynb b/tutorials/fastnlp_tutorial_e1.ipynb new file mode 100644 index 00000000..92a49925 --- /dev/null +++ b/tutorials/fastnlp_tutorial_e1.ipynb @@ -0,0 +1,888 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# E1. 使用 DistilBert 完成 SST2 分类" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4.18.0\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from torch.optim import AdamW\n", + "from torch.utils.data import DataLoader, Dataset\n", + "\n", + "import transformers\n", + "from transformers import AutoTokenizer\n", + "from transformers import AutoModelForSequenceClassification\n", + "\n", + "import sys\n", + "sys.path.append('..')\n", + "\n", + "import fastNLP\n", + "from fastNLP import Trainer\n", + "from fastNLP.core.utils.utils import dataclass_to_dict\n", + "from fastNLP.core.metrics import Accuracy\n", + "\n", + "print(transformers.__version__)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "GLUE_TASKS = [\"cola\", \"mnli\", \"mnli-mm\", \"mrpc\", \"qnli\", \"qqp\", \"rte\", \"sst2\", \"stsb\", \"wnli\"]\n", + "\n", + "task = \"sst2\"\n", + "model_checkpoint = \"distilbert-base-uncased\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using the latest cached version of the module from /remote-home/xrliu/.cache/huggingface/modules/datasets_modules/datasets/glue/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad (last modified on Thu May 26 15:30:15 2022) since it couldn't be found locally at glue., or remotely on the Hugging Face Hub.\n", + "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "253d79d7a67e4dc88338448b5bcb3fb9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from datasets import load_dataset, load_metric\n", + "\n", + "dataset = load_dataset(\"glue\", \"mnli\" if task == \"mnli-mm\" else task)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n" + ] + } + ], + "source": [ + "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n", + "\n", + "print(tokenizer(\"Hello, this one sentence!\", \"And this sentence goes with it.\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "task_to_keys = {\n", + " \"cola\": (\"sentence\", None),\n", + " \"mnli\": (\"premise\", \"hypothesis\"),\n", + " \"mnli-mm\": (\"premise\", \"hypothesis\"),\n", + " \"mrpc\": (\"sentence1\", \"sentence2\"),\n", + " \"qnli\": (\"question\", \"sentence\"),\n", + " \"qqp\": (\"question1\", \"question2\"),\n", + " \"rte\": (\"sentence1\", \"sentence2\"),\n", + " \"sst2\": (\"sentence\", None),\n", + " \"stsb\": (\"sentence1\", \"sentence2\"),\n", + " \"wnli\": (\"sentence1\", \"sentence2\"),\n", + "}\n", + "\n", + "sentence1_key, sentence2_key = task_to_keys[task]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sentence: hide new secretions from the parental units \n" + ] + } + ], + "source": [ + "if sentence2_key is None:\n", + " print(f\"Sentence: {dataset['train'][0][sentence1_key]}\")\n", + "else:\n", + " print(f\"Sentence 1: {dataset['train'][0][sentence1_key]}\")\n", + " print(f\"Sentence 2: {dataset['train'][0][sentence2_key]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-ca1fbe5e8eb059f3.arrow\n", + "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-03661263fbf302f5.arrow\n", + "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-fbe8e7a4e4f18f45.arrow\n" + ] + } + ], + "source": [ + "def preprocess_function(examples):\n", + " if sentence2_key is None:\n", + " return tokenizer(examples[sentence1_key], truncation=True)\n", + " return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True)\n", + "\n", + "encoded_dataset = dataset.map(preprocess_function, batched=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "class ClassModel(nn.Module):\n", + " def __init__(self, num_labels, model_checkpoint):\n", + " nn.Module.__init__(self)\n", + " self.num_labels = num_labels\n", + " self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n", + " num_labels=num_labels)\n", + " self.loss_fn = nn.CrossEntropyLoss()\n", + "\n", + " def forward(self, input_ids, attention_mask):\n", + " return self.back_bone(input_ids, attention_mask)\n", + "\n", + " def train_step(self, input_ids, attention_mask, labels):\n", + " pred = self(input_ids, attention_mask).logits\n", + " return {\"loss\": self.loss_fn(pred, labels)}\n", + "\n", + " def evaluate_step(self, input_ids, attention_mask, labels):\n", + " pred = self(input_ids, attention_mask).logits\n", + " pred = torch.max(pred, dim=-1)[1]\n", + " return {\"pred\": pred, \"target\": labels}" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight']\n", + "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'classifier.bias', 'pre_classifier.bias']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + } + ], + "source": [ + "num_labels = 3 if task.startswith(\"mnli\") else 1 if task==\"stsb\" else 2\n", + "\n", + "model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=5e-5)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "class TestDistilBertDataset(Dataset):\n", + " def __init__(self, dataset):\n", + " super(TestDistilBertDataset, self).__init__()\n", + " self.dataset = dataset\n", + "\n", + " def __len__(self):\n", + " return len(self.dataset)\n", + "\n", + " def __getitem__(self, item):\n", + " item = self.dataset[item]\n", + " return item[\"input_ids\"], item[\"attention_mask\"], [item[\"label\"]] " + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "def test_bert_collate_fn(batch):\n", + " input_ids, atten_mask, labels = [], [], []\n", + " max_length = [0] * 3\n", + " for each_item in batch:\n", + " input_ids.append(each_item[0])\n", + " max_length[0] = max(max_length[0], len(each_item[0]))\n", + " atten_mask.append(each_item[1])\n", + " max_length[1] = max(max_length[1], len(each_item[1]))\n", + " labels.append(each_item[2])\n", + " max_length[2] = max(max_length[2], len(each_item[2]))\n", + "\n", + " for i in range(3):\n", + " each = (input_ids, atten_mask, labels)[i]\n", + " for item in each:\n", + " item.extend([0] * (max_length[i] - len(item)))\n", + " return {\"input_ids\": torch.cat([torch.tensor([item]) for item in input_ids], dim=0),\n", + " \"attention_mask\": torch.cat([torch.tensor([item]) for item in atten_mask], dim=0),\n", + " \"labels\": torch.cat([torch.tensor(item) for item in labels], dim=0)}" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "dataset_train = TestDistilBertDataset(encoded_dataset[\"train\"])\n", + "dataloader_train = DataLoader(dataset=dataset_train, \n", + " batch_size=32, shuffle=True, collate_fn=test_bert_collate_fn)\n", + "dataset_valid = TestDistilBertDataset(encoded_dataset[\"validation\"])\n", + "dataloader_valid = DataLoader(dataset=dataset_valid, \n", + " batch_size=32, shuffle=False, collate_fn=test_bert_collate_fn)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "trainer = Trainer(\n", + " model=model,\n", + " driver='torch',\n", + " device='cuda',\n", + " n_epochs=10,\n", + " optimizers=optimizers,\n", + " train_dataloader=dataloader_train,\n", + " evaluate_dataloaders=dataloader_valid,\n", + " metrics={'acc': Accuracy()}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "# help(model.back_bone.forward)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "[21:00:11] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[21:00:11]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=22992;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=669026;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "{\n", + " \"acc#acc\": 0.871875,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 279.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "{\n", + " \"acc#acc\": 0.878125,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 281.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "{\n", + " \"acc#acc\": 0.871875,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 279.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "{\n", + " \"acc#acc\": 0.903125,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 289.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.903125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m289.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "{\n", + " \"acc#acc\": 0.871875,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 279.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "{\n", + " \"acc#acc\": 0.890625,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 285.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.890625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m285.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "{\n", + " \"acc#acc\": 0.875,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 280.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m280.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "{\n", + " \"acc#acc\": 0.8875,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 284.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "{\n", + " \"acc#acc\": 0.8875,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 284.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "{\n", + " \"acc#acc\": 0.890625,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 285.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.890625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m285.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/tutorials/fastnlp_tutorial_e2.ipynb b/tutorials/fastnlp_tutorial_e2.ipynb new file mode 100644 index 00000000..8e734f01 --- /dev/null +++ b/tutorials/fastnlp_tutorial_e2.ipynb @@ -0,0 +1,888 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# E2. 使用 PrefixTuning 完成 SST2 分类" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4.18.0\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from torch.optim import AdamW\n", + "from torch.utils.data import DataLoader, Dataset\n", + "\n", + "import transformers\n", + "from transformers import AutoTokenizer\n", + "from transformers import AutoModelForSequenceClassification\n", + "\n", + "import sys\n", + "sys.path.append('..')\n", + "\n", + "import fastNLP\n", + "from fastNLP import Trainer\n", + "from fastNLP.core.utils.utils import dataclass_to_dict\n", + "from fastNLP.core.metrics import Accuracy\n", + "\n", + "print(transformers.__version__)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "GLUE_TASKS = [\"cola\", \"mnli\", \"mnli-mm\", \"mrpc\", \"qnli\", \"qqp\", \"rte\", \"sst2\", \"stsb\", \"wnli\"]\n", + "\n", + "task = \"sst2\"\n", + "model_checkpoint = \"distilbert-base-uncased\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using the latest cached version of the module from /remote-home/xrliu/.cache/huggingface/modules/datasets_modules/datasets/glue/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad (last modified on Thu May 26 15:30:15 2022) since it couldn't be found locally at glue., or remotely on the Hugging Face Hub.\n", + "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "253d79d7a67e4dc88338448b5bcb3fb9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from datasets import load_dataset, load_metric\n", + "\n", + "dataset = load_dataset(\"glue\", \"mnli\" if task == \"mnli-mm\" else task)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n" + ] + } + ], + "source": [ + "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n", + "\n", + "print(tokenizer(\"Hello, this one sentence!\", \"And this sentence goes with it.\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "task_to_keys = {\n", + " \"cola\": (\"sentence\", None),\n", + " \"mnli\": (\"premise\", \"hypothesis\"),\n", + " \"mnli-mm\": (\"premise\", \"hypothesis\"),\n", + " \"mrpc\": (\"sentence1\", \"sentence2\"),\n", + " \"qnli\": (\"question\", \"sentence\"),\n", + " \"qqp\": (\"question1\", \"question2\"),\n", + " \"rte\": (\"sentence1\", \"sentence2\"),\n", + " \"sst2\": (\"sentence\", None),\n", + " \"stsb\": (\"sentence1\", \"sentence2\"),\n", + " \"wnli\": (\"sentence1\", \"sentence2\"),\n", + "}\n", + "\n", + "sentence1_key, sentence2_key = task_to_keys[task]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sentence: hide new secretions from the parental units \n" + ] + } + ], + "source": [ + "if sentence2_key is None:\n", + " print(f\"Sentence: {dataset['train'][0][sentence1_key]}\")\n", + "else:\n", + " print(f\"Sentence 1: {dataset['train'][0][sentence1_key]}\")\n", + " print(f\"Sentence 2: {dataset['train'][0][sentence2_key]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-ca1fbe5e8eb059f3.arrow\n", + "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-03661263fbf302f5.arrow\n", + "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-fbe8e7a4e4f18f45.arrow\n" + ] + } + ], + "source": [ + "def preprocess_function(examples):\n", + " if sentence2_key is None:\n", + " return tokenizer(examples[sentence1_key], truncation=True)\n", + " return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True)\n", + "\n", + "encoded_dataset = dataset.map(preprocess_function, batched=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "class ClassModel(nn.Module):\n", + " def __init__(self, num_labels, model_checkpoint):\n", + " nn.Module.__init__(self)\n", + " self.num_labels = num_labels\n", + " self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n", + " num_labels=num_labels)\n", + " self.loss_fn = nn.CrossEntropyLoss()\n", + "\n", + " def forward(self, input_ids, attention_mask):\n", + " return self.back_bone(input_ids, attention_mask)\n", + "\n", + " def train_step(self, input_ids, attention_mask, labels):\n", + " pred = self(input_ids, attention_mask).logits\n", + " return {\"loss\": self.loss_fn(pred, labels)}\n", + "\n", + " def evaluate_step(self, input_ids, attention_mask, labels):\n", + " pred = self(input_ids, attention_mask).logits\n", + " pred = torch.max(pred, dim=-1)[1]\n", + " return {\"pred\": pred, \"target\": labels}" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight']\n", + "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'classifier.bias', 'pre_classifier.bias']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + } + ], + "source": [ + "num_labels = 3 if task.startswith(\"mnli\") else 1 if task==\"stsb\" else 2\n", + "\n", + "model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=5e-5)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "class TestDistilBertDataset(Dataset):\n", + " def __init__(self, dataset):\n", + " super(TestDistilBertDataset, self).__init__()\n", + " self.dataset = dataset\n", + "\n", + " def __len__(self):\n", + " return len(self.dataset)\n", + "\n", + " def __getitem__(self, item):\n", + " item = self.dataset[item]\n", + " return item[\"input_ids\"], item[\"attention_mask\"], [item[\"label\"]] " + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "def test_bert_collate_fn(batch):\n", + " input_ids, atten_mask, labels = [], [], []\n", + " max_length = [0] * 3\n", + " for each_item in batch:\n", + " input_ids.append(each_item[0])\n", + " max_length[0] = max(max_length[0], len(each_item[0]))\n", + " atten_mask.append(each_item[1])\n", + " max_length[1] = max(max_length[1], len(each_item[1]))\n", + " labels.append(each_item[2])\n", + " max_length[2] = max(max_length[2], len(each_item[2]))\n", + "\n", + " for i in range(3):\n", + " each = (input_ids, atten_mask, labels)[i]\n", + " for item in each:\n", + " item.extend([0] * (max_length[i] - len(item)))\n", + " return {\"input_ids\": torch.cat([torch.tensor([item]) for item in input_ids], dim=0),\n", + " \"attention_mask\": torch.cat([torch.tensor([item]) for item in atten_mask], dim=0),\n", + " \"labels\": torch.cat([torch.tensor(item) for item in labels], dim=0)}" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "dataset_train = TestDistilBertDataset(encoded_dataset[\"train\"])\n", + "dataloader_train = DataLoader(dataset=dataset_train, \n", + " batch_size=32, shuffle=True, collate_fn=test_bert_collate_fn)\n", + "dataset_valid = TestDistilBertDataset(encoded_dataset[\"validation\"])\n", + "dataloader_valid = DataLoader(dataset=dataset_valid, \n", + " batch_size=32, shuffle=False, collate_fn=test_bert_collate_fn)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "trainer = Trainer(\n", + " model=model,\n", + " driver='torch',\n", + " device='cuda',\n", + " n_epochs=10,\n", + " optimizers=optimizers,\n", + " train_dataloader=dataloader_train,\n", + " evaluate_dataloaders=dataloader_valid,\n", + " metrics={'acc': Accuracy()}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "# help(model.back_bone.forward)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "[21:00:11] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[21:00:11]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=22992;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=669026;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "{\n", + " \"acc#acc\": 0.871875,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 279.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "{\n", + " \"acc#acc\": 0.878125,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 281.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "{\n", + " \"acc#acc\": 0.871875,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 279.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "{\n", + " \"acc#acc\": 0.903125,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 289.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.903125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m289.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "{\n", + " \"acc#acc\": 0.871875,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 279.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "{\n", + " \"acc#acc\": 0.890625,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 285.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.890625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m285.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "{\n", + " \"acc#acc\": 0.875,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 280.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m280.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "{\n", + " \"acc#acc\": 0.8875,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 284.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "{\n", + " \"acc#acc\": 0.8875,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 284.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "{\n", + " \"acc#acc\": 0.890625,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 285.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.890625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m285.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +}