| @@ -39,7 +39,6 @@ | |||
| "from torch.utils.data import DataLoader, Dataset\n", | |||
| "\n", | |||
| "import torch.nn as nn\n", | |||
| "from torch.nn.utils.rnn import pad_sequence\n", | |||
| "\n", | |||
| "import transformers\n", | |||
| "from transformers import AutoTokenizer\n", | |||
| @@ -50,7 +49,6 @@ | |||
| "\n", | |||
| "import fastNLP\n", | |||
| "from fastNLP import Trainer\n", | |||
| "from fastNLP.core.utils.utils import dataclass_to_dict\n", | |||
| "from fastNLP.core.metrics import Accuracy\n", | |||
| "\n", | |||
| "print(transformers.__version__)" | |||
| @@ -73,134 +71,80 @@ | |||
| "execution_count": 3, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "class PromptEncoder(nn.Module):\n", | |||
| " def __init__(self, template, hidden_size):\n", | |||
| " nn.Module.__init__(self)\n", | |||
| " self.template = template\n", | |||
| " self.hidden_size = hidden_size\n", | |||
| " self.cloze_mask = [[1] * self.template[0] + [1] * self.template[1]]\n", | |||
| " self.cloze_mask = torch.LongTensor(self.cloze_mask).bool()\n", | |||
| "\n", | |||
| " self.seq_indices = torch.LongTensor(list(range(len(self.cloze_mask[0]))))\n", | |||
| " # embed\n", | |||
| " self.embedding = torch.nn.Embedding(len(self.cloze_mask[0]), hidden_size)\n", | |||
| " # LSTM\n", | |||
| " self.lstm_head = torch.nn.LSTM(input_size=hidden_size,\n", | |||
| " hidden_size=hidden_size // 2,\n", | |||
| " num_layers=2, dropout=0.0,\n", | |||
| " bidirectional=True, batch_first=True)\n", | |||
| " # MLP\n", | |||
| " self.mlp_head = nn.Sequential(nn.Linear(hidden_size, hidden_size),\n", | |||
| " nn.ReLU(),\n", | |||
| " nn.Linear(hidden_size, hidden_size))\n", | |||
| " print(\"init prompt encoder...\")\n", | |||
| "\n", | |||
| " def forward(self, device):\n", | |||
| " input_embeds = self.embedding(self.seq_indices.to(device)).unsqueeze(0)\n", | |||
| " output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0]).squeeze()\n", | |||
| " return output_embeds" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 4, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "class ClassModel(nn.Module):\n", | |||
| " def __init__(self, num_labels, model_checkpoint, pseudo_token='[PROMPT]', template=(3, 3)):\n", | |||
| " def __init__(self, model_checkpoint, num_labels, pre_seq_len):\n", | |||
| " nn.Module.__init__(self)\n", | |||
| " self.template = template\n", | |||
| " self.num_labels = num_labels\n", | |||
| " self.spell_length = sum(template)\n", | |||
| " self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n", | |||
| " self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n", | |||
| " num_labels=num_labels)\n", | |||
| " self.embeddings = self.back_bone.get_input_embeddings()\n", | |||
| "\n", | |||
| " for param in self.back_bone.parameters():\n", | |||
| " param.requires_grad = False\n", | |||
| " self.embeddings = self.back_bone.get_input_embeddings()\n", | |||
| " \n", | |||
| " self.hidden_size = self.embeddings.embedding_dim\n", | |||
| " self.tokenizer.add_special_tokens({'additional_special_tokens': [pseudo_token]})\n", | |||
| " self.pseudo_token_id = self.tokenizer.get_vocab()[pseudo_token]\n", | |||
| " self.pad_token_id = self.tokenizer.pad_token_id\n", | |||
| " \n", | |||
| " self.prompt_encoder = PromptEncoder(self.template, self.hidden_size)\n", | |||
| "\n", | |||
| " self.loss_fn = nn.CrossEntropyLoss()\n", | |||
| "\n", | |||
| " def get_query(self, query):\n", | |||
| " device = query.device\n", | |||
| " return torch.cat([torch.tensor([self.tokenizer.cls_token_id]).to(device), # [CLS]\n", | |||
| " torch.tensor([self.pseudo_token_id] * self.template[0]).to(device), # [PROMPT]\n", | |||
| " torch.tensor([self.tokenizer.mask_token_id]).to(device), # [MASK] \n", | |||
| " torch.tensor([self.pseudo_token_id] * self.template[1]).to(device), # [PROMPT]\n", | |||
| " query, \n", | |||
| " torch.tensor([self.tokenizer.sep_token_id]).to(device)], dim=0) # [SEP]\n", | |||
| " self.pre_seq_len = pre_seq_len\n", | |||
| " self.prefix_tokens = torch.arange(self.pre_seq_len).long()\n", | |||
| " self.prefix_encoder = nn.Embedding(self.pre_seq_len, self.embeddings.embedding_dim)\n", | |||
| " \n", | |||
| " def get_prompt(self, batch_size):\n", | |||
| " prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.back_bone.device)\n", | |||
| " prompts = self.prefix_encoder(prefix_tokens)\n", | |||
| " return prompts\n", | |||
| "\n", | |||
| " def forward(self, input_ids):\n", | |||
| " input_ids = torch.stack([self.get_query(input_ids[i]) for i in range(len(input_ids))])\n", | |||
| " attention_mask = input_ids != self.pad_token_id\n", | |||
| " def forward(self, input_ids, attention_mask, labels):\n", | |||
| " \n", | |||
| " bz = input_ids.shape[0]\n", | |||
| " inputs_embeds = input_ids.clone()\n", | |||
| " inputs_embeds[(input_ids == self.pseudo_token_id)] = self.tokenizer.unk_token_id\n", | |||
| " inputs_embeds = self.embeddings(inputs_embeds)\n", | |||
| "\n", | |||
| " blocked_indices = (input_ids == self.pseudo_token_id).nonzero().reshape((bz, self.spell_length, 2))[:, :, 1] # bz\n", | |||
| " replace_embeds = self.prompt_encoder(input_ids.device)\n", | |||
| " for bidx in range(bz):\n", | |||
| " for i in range(self.spell_length):\n", | |||
| " inputs_embeds[bidx, blocked_indices[bidx, i], :] = replace_embeds[i, :]\n", | |||
| " batch_size = input_ids.shape[0]\n", | |||
| " raw_embedding = self.embeddings(input_ids)\n", | |||
| " \n", | |||
| " return self.back_bone(inputs_embeds=inputs_embeds, attention_mask=attention_mask)\n", | |||
| " prompts = self.get_prompt(batch_size=batch_size)\n", | |||
| " inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)\n", | |||
| " prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.back_bone.device)\n", | |||
| " attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)\n", | |||
| "\n", | |||
| " outputs = self.back_bone(inputs_embeds=inputs_embeds, \n", | |||
| " attention_mask=attention_mask, labels=labels)\n", | |||
| " return outputs\n", | |||
| "\n", | |||
| " def train_step(self, input_ids, attention_mask, labels):\n", | |||
| " pred = self(input_ids).logits\n", | |||
| " return {\"loss\": self.loss_fn(pred, labels)}\n", | |||
| " return {\"loss\": self(input_ids, attention_mask, labels).loss}\n", | |||
| "\n", | |||
| " def evaluate_step(self, input_ids, attention_mask, labels):\n", | |||
| " pred = self(input_ids).logits\n", | |||
| " pred = self(input_ids, attention_mask, labels).logits\n", | |||
| " pred = torch.max(pred, dim=-1)[1]\n", | |||
| " return {\"pred\": pred, \"target\": labels}" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 5, | |||
| "execution_count": 17, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| "name": "stderr", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_transform.weight']\n", | |||
| "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight']\n", | |||
| "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", | |||
| "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", | |||
| "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'classifier.weight', 'pre_classifier.weight', 'classifier.bias']\n", | |||
| "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight', 'classifier.bias']\n", | |||
| "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" | |||
| ] | |||
| }, | |||
| { | |||
| "name": "stdout", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "init prompt encoder...\n" | |||
| ] | |||
| } | |||
| ], | |||
| "source": [ | |||
| "num_labels = 3 if task.startswith(\"mnli\") else 1 if task == \"stsb\" else 2\n", | |||
| "\n", | |||
| "model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n", | |||
| "model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint, pre_seq_len=16)\n", | |||
| "\n", | |||
| "# Generally, simple classification tasks prefer shorter prompts (less than 20)\n", | |||
| "\n", | |||
| "optimizers = AdamW(params=model.parameters(), lr=5e-4)" | |||
| "optimizers = AdamW(params=model.parameters(), lr=5e-3)" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 6, | |||
| "execution_count": 5, | |||
| "metadata": { | |||
| "scrolled": false | |||
| }, | |||
| @@ -209,13 +153,14 @@ | |||
| "name": "stderr", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "Using the latest cached version of the module from /remote-home/xrliu/.cache/huggingface/modules/datasets_modules/datasets/glue/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad (last modified on Thu May 26 15:30:15 2022) since it couldn't be found locally at glue., or remotely on the Hugging Face Hub.\n", | |||
| "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" | |||
| ] | |||
| }, | |||
| { | |||
| "data": { | |||
| "application/vnd.jupyter.widget-view+json": { | |||
| "model_id": "f82d2ccee863492582f94552654482f9", | |||
| "model_id": "1b73650d43f245ac8a5501dc91c6fe8c", | |||
| "version_major": 2, | |||
| "version_minor": 0 | |||
| }, | |||
| @@ -230,46 +175,28 @@ | |||
| "source": [ | |||
| "from datasets import load_dataset, load_metric\n", | |||
| "\n", | |||
| "dataset = load_dataset(\"glue\", \"mnli\" if task == \"mnli-mm\" else task)" | |||
| "dataset = load_dataset(\"glue\", \"mnli\" if task == \"mnli-mm\" else task)\n", | |||
| "\n", | |||
| "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 7, | |||
| "execution_count": 6, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| "data": { | |||
| "application/vnd.jupyter.widget-view+json": { | |||
| "model_id": "cf324902e7b94ea9be709b979b425c96", | |||
| "version_major": 2, | |||
| "version_minor": 0 | |||
| }, | |||
| "text/plain": [ | |||
| " 0%| | 0/68 [00:00<?, ?ba/s]" | |||
| ] | |||
| }, | |||
| "metadata": {}, | |||
| "output_type": "display_data" | |||
| }, | |||
| { | |||
| "data": { | |||
| "application/vnd.jupyter.widget-view+json": { | |||
| "model_id": "21eb6203ec6f4592b8cb8530a59eda49", | |||
| "version_major": 2, | |||
| "version_minor": 0 | |||
| }, | |||
| "text/plain": [ | |||
| " 0%| | 0/1 [00:00<?, ?ba/s]" | |||
| ] | |||
| }, | |||
| "metadata": {}, | |||
| "output_type": "display_data" | |||
| "name": "stderr", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-294e481a713c5754.arrow\n", | |||
| "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-ed9d9258aaf0fb54.arrow\n" | |||
| ] | |||
| }, | |||
| { | |||
| "data": { | |||
| "application/vnd.jupyter.widget-view+json": { | |||
| "model_id": "05b83c4b1a9f44aea805788e1e52db78", | |||
| "model_id": "0be84915c90f460896b8e67299e09df4", | |||
| "version_major": 2, | |||
| "version_minor": 0 | |||
| }, | |||
| @@ -283,14 +210,14 @@ | |||
| ], | |||
| "source": [ | |||
| "def preprocess_function(examples):\n", | |||
| " return model.tokenizer(examples['sentence'], truncation=True)\n", | |||
| " return tokenizer(examples['sentence'], truncation=True)\n", | |||
| "\n", | |||
| "encoded_dataset = dataset.map(preprocess_function, batched=True)" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 8, | |||
| "execution_count": 7, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| @@ -309,7 +236,7 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 9, | |||
| "execution_count": 8, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| @@ -335,7 +262,7 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 10, | |||
| "execution_count": 9, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| @@ -349,7 +276,7 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 11, | |||
| "execution_count": 18, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| @@ -367,7 +294,7 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 12, | |||
| "execution_count": 19, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| @@ -410,7 +337,7 @@ | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 13, | |||
| "execution_count": 20, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| @@ -436,10 +363,10 @@ | |||
| { | |||
| "data": { | |||
| "text/plain": [ | |||
| "{'acc#acc': 0.565367, 'total#acc': 872.0, 'correct#acc': 493.0}" | |||
| "{'acc#acc': 0.644495, 'total#acc': 872.0, 'correct#acc': 562.0}" | |||
| ] | |||
| }, | |||
| "execution_count": 13, | |||
| "execution_count": 20, | |||
| "metadata": {}, | |||
| "output_type": "execute_result" | |||
| } | |||