From 19a48c7101d8c95ade89bdf09ad907ee13d7e78b Mon Sep 17 00:00:00 2001 From: lxr-tech <1838593642@qq.com> Date: Fri, 27 May 2022 22:47:11 +0800 Subject: [PATCH] update example-12 lxr 220527 --- tutorials/fastnlp_tutorial_e1.ipynb | 11 +- tutorials/fastnlp_tutorial_e2.ipynb | 815 +++++++--------------------- 2 files changed, 218 insertions(+), 608 deletions(-) diff --git a/tutorials/fastnlp_tutorial_e1.ipynb b/tutorials/fastnlp_tutorial_e1.ipynb index 92a49925..628dd7ae 100644 --- a/tutorials/fastnlp_tutorial_e1.ipynb +++ b/tutorials/fastnlp_tutorial_e1.ipynb @@ -233,7 +233,7 @@ } ], "source": [ - "num_labels = 3 if task.startswith(\"mnli\") else 1 if task==\"stsb\" else 2\n", + "num_labels = 3 if task.startswith(\"mnli\") else 1 if task == \"stsb\" else 2\n", "\n", "model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n", "\n", @@ -881,6 +881,15 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } } }, "nbformat": 4, diff --git a/tutorials/fastnlp_tutorial_e2.ipynb b/tutorials/fastnlp_tutorial_e2.ipynb index 8e734f01..9185102f 100644 --- a/tutorials/fastnlp_tutorial_e2.ipynb +++ b/tutorials/fastnlp_tutorial_e2.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# E2. 使用 PrefixTuning 完成 SST2 分类" + "# E2. 使用 continuous prompt 完成 SST2 分类" ] }, { @@ -35,10 +35,12 @@ ], "source": [ "import torch\n", - "import torch.nn as nn\n", "from torch.optim import AdamW\n", "from torch.utils.data import DataLoader, Dataset\n", "\n", + "import torch.nn as nn\n", + "from torch.nn.utils.rnn import pad_sequence\n", + "\n", "import transformers\n", "from transformers import AutoTokenizer\n", "from transformers import AutoModelForSequenceClassification\n", @@ -69,180 +71,226 @@ { "cell_type": "code", "execution_count": 3, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using the latest cached version of the module from /remote-home/xrliu/.cache/huggingface/modules/datasets_modules/datasets/glue/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad (last modified on Thu May 26 15:30:15 2022) since it couldn't be found locally at glue., or remotely on the Hugging Face Hub.\n", - "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "253d79d7a67e4dc88338448b5bcb3fb9", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/3 [00:00[21:00:11] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", - "\n" - ], - "text/plain": [ - "\u001b[2;36m[21:00:11]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=22992;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=669026;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" + "
\n"
       ],
-      "text/plain": [
-       "\n"
-      ]
+      "text/plain": []
      },
      "metadata": {},
      "output_type": "display_data"
@@ -374,12 +383,9 @@
     {
      "data": {
       "text/html": [
-       "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
-       "
\n" + "
\n"
       ],
-      "text/plain": [
-       "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
-      ]
+      "text/plain": []
      },
      "metadata": {},
      "output_type": "display_data"
@@ -387,33 +393,32 @@
     {
      "data": {
       "text/html": [
-       "
{\n",
-       "  \"acc#acc\": 0.871875,\n",
-       "  \"total#acc\": 320.0,\n",
-       "  \"correct#acc\": 279.0\n",
-       "}\n",
+       "
\n",
        "
\n" ], "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" + "\n" ] }, "metadata": {}, "output_type": "display_data" - }, + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ { "data": { "text/html": [ - "
\n",
-       "
\n" + "
\n"
       ],
-      "text/plain": [
-       "\n"
-      ]
+      "text/plain": []
      },
      "metadata": {},
      "output_type": "display_data"
@@ -421,439 +426,26 @@
     {
      "data": {
       "text/html": [
-       "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
-       "
\n" + "
\n"
       ],
-      "text/plain": [
-       "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n"
-      ]
+      "text/plain": []
      },
      "metadata": {},
      "output_type": "display_data"
     },
     {
      "data": {
-      "text/html": [
-       "
{\n",
-       "  \"acc#acc\": 0.878125,\n",
-       "  \"total#acc\": 320.0,\n",
-       "  \"correct#acc\": 281.0\n",
-       "}\n",
-       "
\n" - ], "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" + "{'acc#acc': 0.565367, 'total#acc': 872.0, 'correct#acc': 493.0}" ] }, + "execution_count": 13, "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.871875,\n",
-       "  \"total#acc\": 320.0,\n",
-       "  \"correct#acc\": 279.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.903125,\n",
-       "  \"total#acc\": 320.0,\n",
-       "  \"correct#acc\": 289.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.903125\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m289.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.871875,\n",
-       "  \"total#acc\": 320.0,\n",
-       "  \"correct#acc\": 279.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.871875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m279.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.890625,\n",
-       "  \"total#acc\": 320.0,\n",
-       "  \"correct#acc\": 285.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.890625\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m285.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.875,\n",
-       "  \"total#acc\": 320.0,\n",
-       "  \"correct#acc\": 280.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m280.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.8875,\n",
-       "  \"total#acc\": 320.0,\n",
-       "  \"correct#acc\": 284.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.8875,\n",
-       "  \"total#acc\": 320.0,\n",
-       "  \"correct#acc\": 284.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
-       "
\n" - ], - "text/plain": [ - "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
{\n",
-       "  \"acc#acc\": 0.890625,\n",
-       "  \"total#acc\": 320.0,\n",
-       "  \"correct#acc\": 285.0\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m{\u001b[0m\n", - " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.890625\u001b[0m,\n", - " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", - " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m285.0\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" + "output_type": "execute_result" } ], "source": [ - "trainer.run(num_eval_batch_per_dl=10)" + "trainer.evaluator.run()" ] }, { @@ -881,6 +473,15 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } } }, "nbformat": 4,