From ec76ba8887f3e3df9778ae3db36bc5320ec52f62 Mon Sep 17 00:00:00 2001 From: lxr-tech <1838593642@qq.com> Date: Sat, 28 May 2022 17:06:57 +0800 Subject: [PATCH] update example-2 lxr 220528 --- tutorials/fastnlp_tutorial_e2.ipynb | 181 +++++++++------------------- 1 file changed, 54 insertions(+), 127 deletions(-) diff --git a/tutorials/fastnlp_tutorial_e2.ipynb b/tutorials/fastnlp_tutorial_e2.ipynb index 9185102f..1d7746be 100644 --- a/tutorials/fastnlp_tutorial_e2.ipynb +++ b/tutorials/fastnlp_tutorial_e2.ipynb @@ -39,7 +39,6 @@ "from torch.utils.data import DataLoader, Dataset\n", "\n", "import torch.nn as nn\n", - "from torch.nn.utils.rnn import pad_sequence\n", "\n", "import transformers\n", "from transformers import AutoTokenizer\n", @@ -50,7 +49,6 @@ "\n", "import fastNLP\n", "from fastNLP import Trainer\n", - "from fastNLP.core.utils.utils import dataclass_to_dict\n", "from fastNLP.core.metrics import Accuracy\n", "\n", "print(transformers.__version__)" @@ -73,134 +71,80 @@ "execution_count": 3, "metadata": {}, "outputs": [], - "source": [ - "class PromptEncoder(nn.Module):\n", - " def __init__(self, template, hidden_size):\n", - " nn.Module.__init__(self)\n", - " self.template = template\n", - " self.hidden_size = hidden_size\n", - " self.cloze_mask = [[1] * self.template[0] + [1] * self.template[1]]\n", - " self.cloze_mask = torch.LongTensor(self.cloze_mask).bool()\n", - "\n", - " self.seq_indices = torch.LongTensor(list(range(len(self.cloze_mask[0]))))\n", - " # embed\n", - " self.embedding = torch.nn.Embedding(len(self.cloze_mask[0]), hidden_size)\n", - " # LSTM\n", - " self.lstm_head = torch.nn.LSTM(input_size=hidden_size,\n", - " hidden_size=hidden_size // 2,\n", - " num_layers=2, dropout=0.0,\n", - " bidirectional=True, batch_first=True)\n", - " # MLP\n", - " self.mlp_head = nn.Sequential(nn.Linear(hidden_size, hidden_size),\n", - " nn.ReLU(),\n", - " nn.Linear(hidden_size, hidden_size))\n", - " print(\"init prompt encoder...\")\n", - "\n", - " def forward(self, device):\n", - " input_embeds = self.embedding(self.seq_indices.to(device)).unsqueeze(0)\n", - " output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0]).squeeze()\n", - " return output_embeds" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], "source": [ "class ClassModel(nn.Module):\n", - " def __init__(self, num_labels, model_checkpoint, pseudo_token='[PROMPT]', template=(3, 3)):\n", + " def __init__(self, model_checkpoint, num_labels, pre_seq_len):\n", " nn.Module.__init__(self)\n", - " self.template = template\n", " self.num_labels = num_labels\n", - " self.spell_length = sum(template)\n", - " self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n", " self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n", " num_labels=num_labels)\n", + " self.embeddings = self.back_bone.get_input_embeddings()\n", + "\n", " for param in self.back_bone.parameters():\n", " param.requires_grad = False\n", - " self.embeddings = self.back_bone.get_input_embeddings()\n", - " \n", - " self.hidden_size = self.embeddings.embedding_dim\n", - " self.tokenizer.add_special_tokens({'additional_special_tokens': [pseudo_token]})\n", - " self.pseudo_token_id = self.tokenizer.get_vocab()[pseudo_token]\n", - " self.pad_token_id = self.tokenizer.pad_token_id\n", " \n", - " self.prompt_encoder = PromptEncoder(self.template, self.hidden_size)\n", - "\n", - " self.loss_fn = nn.CrossEntropyLoss()\n", - "\n", - " def get_query(self, query):\n", - " device = query.device\n", - " return torch.cat([torch.tensor([self.tokenizer.cls_token_id]).to(device), # [CLS]\n", - " torch.tensor([self.pseudo_token_id] * self.template[0]).to(device), # [PROMPT]\n", - " torch.tensor([self.tokenizer.mask_token_id]).to(device), # [MASK] \n", - " torch.tensor([self.pseudo_token_id] * self.template[1]).to(device), # [PROMPT]\n", - " query, \n", - " torch.tensor([self.tokenizer.sep_token_id]).to(device)], dim=0) # [SEP]\n", + " self.pre_seq_len = pre_seq_len\n", + " self.prefix_tokens = torch.arange(self.pre_seq_len).long()\n", + " self.prefix_encoder = nn.Embedding(self.pre_seq_len, self.embeddings.embedding_dim)\n", + " \n", + " def get_prompt(self, batch_size):\n", + " prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.back_bone.device)\n", + " prompts = self.prefix_encoder(prefix_tokens)\n", + " return prompts\n", "\n", - " def forward(self, input_ids):\n", - " input_ids = torch.stack([self.get_query(input_ids[i]) for i in range(len(input_ids))])\n", - " attention_mask = input_ids != self.pad_token_id\n", + " def forward(self, input_ids, attention_mask, labels):\n", " \n", - " bz = input_ids.shape[0]\n", - " inputs_embeds = input_ids.clone()\n", - " inputs_embeds[(input_ids == self.pseudo_token_id)] = self.tokenizer.unk_token_id\n", - " inputs_embeds = self.embeddings(inputs_embeds)\n", - "\n", - " blocked_indices = (input_ids == self.pseudo_token_id).nonzero().reshape((bz, self.spell_length, 2))[:, :, 1] # bz\n", - " replace_embeds = self.prompt_encoder(input_ids.device)\n", - " for bidx in range(bz):\n", - " for i in range(self.spell_length):\n", - " inputs_embeds[bidx, blocked_indices[bidx, i], :] = replace_embeds[i, :]\n", + " batch_size = input_ids.shape[0]\n", + " raw_embedding = self.embeddings(input_ids)\n", " \n", - " return self.back_bone(inputs_embeds=inputs_embeds, attention_mask=attention_mask)\n", + " prompts = self.get_prompt(batch_size=batch_size)\n", + " inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)\n", + " prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.back_bone.device)\n", + " attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)\n", + "\n", + " outputs = self.back_bone(inputs_embeds=inputs_embeds, \n", + " attention_mask=attention_mask, labels=labels)\n", + " return outputs\n", "\n", " def train_step(self, input_ids, attention_mask, labels):\n", - " pred = self(input_ids).logits\n", - " return {\"loss\": self.loss_fn(pred, labels)}\n", + " return {\"loss\": self(input_ids, attention_mask, labels).loss}\n", "\n", " def evaluate_step(self, input_ids, attention_mask, labels):\n", - " pred = self(input_ids).logits\n", + " pred = self(input_ids, attention_mask, labels).logits\n", " pred = torch.max(pred, dim=-1)[1]\n", " return {\"pred\": pred, \"target\": labels}" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_transform.weight']\n", + "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight']\n", "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", - "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'classifier.weight', 'pre_classifier.weight', 'classifier.bias']\n", + "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight', 'classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "init prompt encoder...\n" - ] } ], "source": [ "num_labels = 3 if task.startswith(\"mnli\") else 1 if task == \"stsb\" else 2\n", "\n", - "model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n", + "model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint, pre_seq_len=16)\n", + "\n", + "# Generally, simple classification tasks prefer shorter prompts (less than 20)\n", "\n", - "optimizers = AdamW(params=model.parameters(), lr=5e-4)" + "optimizers = AdamW(params=model.parameters(), lr=5e-3)" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": { "scrolled": false }, @@ -209,13 +153,14 @@ "name": "stderr", "output_type": "stream", "text": [ + "Using the latest cached version of the module from /remote-home/xrliu/.cache/huggingface/modules/datasets_modules/datasets/glue/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad (last modified on Thu May 26 15:30:15 2022) since it couldn't be found locally at glue., or remotely on the Hugging Face Hub.\n", "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "f82d2ccee863492582f94552654482f9", + "model_id": "1b73650d43f245ac8a5501dc91c6fe8c", "version_major": 2, "version_minor": 0 }, @@ -230,46 +175,28 @@ "source": [ "from datasets import load_dataset, load_metric\n", "\n", - "dataset = load_dataset(\"glue\", \"mnli\" if task == \"mnli-mm\" else task)" + "dataset = load_dataset(\"glue\", \"mnli\" if task == \"mnli-mm\" else task)\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "cf324902e7b94ea9be709b979b425c96", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/68 [00:00