|
@@ -0,0 +1,576 @@ |
|
|
|
|
|
{ |
|
|
|
|
|
"nbformat": 4, |
|
|
|
|
|
"nbformat_minor": 0, |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"accelerator": "GPU", |
|
|
|
|
|
"colab": { |
|
|
|
|
|
"name": "hw7_bert", |
|
|
|
|
|
"provenance": [], |
|
|
|
|
|
"collapsed_sections": [], |
|
|
|
|
|
"toc_visible": true, |
|
|
|
|
|
"include_colab_link": true |
|
|
|
|
|
}, |
|
|
|
|
|
"kernelspec": { |
|
|
|
|
|
"display_name": "Python 3", |
|
|
|
|
|
"name": "python3" |
|
|
|
|
|
} |
|
|
|
|
|
}, |
|
|
|
|
|
"cells": [ |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "view-in-github", |
|
|
|
|
|
"colab_type": "text" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"<a href=\"https://colab.research.google.com/github/ga642381/ML2021-Spring/blob/main/HW07/HW07.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "xvSGDbExff_I" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"# **Homework 7 - Bert (Question Answering)**\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"If you have any questions, feel free to email us at ntu-ml-2021spring-ta@googlegroups.com\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"Slide: [Link](https://docs.google.com/presentation/d/1aQoWogAQo_xVJvMQMrGaYiWzuyfO0QyLLAhiMwFyS2w) Kaggle: [Link](https://www.kaggle.com/c/ml2021-spring-hw7) Data: [Link](https://drive.google.com/uc?id=1znKmX08v9Fygp-dgwo7BKiLIf2qL1FH1)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"\n" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "WGOr_eS3wJJf" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"## Task description\n", |
|
|
|
|
|
"- Chinese Extractive Question Answering\n", |
|
|
|
|
|
" - Input: Paragraph + Question\n", |
|
|
|
|
|
" - Output: Answer\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"- Objective: Learn how to fine tune a pretrained model on downstream task using transformers\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"- Todo\n", |
|
|
|
|
|
" - Fine tune a pretrained chinese BERT model\n", |
|
|
|
|
|
" - Change hyperparameters (e.g. doc_stride)\n", |
|
|
|
|
|
" - Apply linear learning rate decay\n", |
|
|
|
|
|
" - Try other pretrained models\n", |
|
|
|
|
|
" - Improve preprocessing\n", |
|
|
|
|
|
" - Improve postprocessing\n", |
|
|
|
|
|
"- Training tips\n", |
|
|
|
|
|
" - Automatic mixed precision\n", |
|
|
|
|
|
" - Gradient accumulation\n", |
|
|
|
|
|
" - Ensemble\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"- Estimated training time (tesla t4 with automatic mixed precision enabled)\n", |
|
|
|
|
|
" - Simple: 8mins\n", |
|
|
|
|
|
" - Medium: 8mins\n", |
|
|
|
|
|
" - Strong: 25mins\n", |
|
|
|
|
|
" - Boss: 2hrs\n", |
|
|
|
|
|
" " |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "TJ1fSAJE2oaC" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"## Download Dataset" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "YPrc4Eie9Yo5" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"# Download link 1\n", |
|
|
|
|
|
"!gdown --id '1znKmX08v9Fygp-dgwo7BKiLIf2qL1FH1' --output hw7_data.zip\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"# Download Link 2 (if the above link fails) \n", |
|
|
|
|
|
"# !gdown --id '1pOu3FdPdvzielUZyggeD7KDnVy9iW1uC' --output hw7_data.zip\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"!unzip -o hw7_data.zip\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"# For this HW, K80 < P4 < T4 < P100 <= T4(fp16) < V100\n", |
|
|
|
|
|
"!nvidia-smi" |
|
|
|
|
|
], |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"outputs": [] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "TevOvhC03m0h" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"## Install transformers\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"Documentation for the toolkit: https://huggingface.co/transformers/" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "tbxWFX_jpDom" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"# You are allowed to change version of transformers or use other toolkits\n", |
|
|
|
|
|
"!pip install transformers==4.5.0" |
|
|
|
|
|
], |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"outputs": [] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "8dKM4yCh4LI_" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"## Import Packages" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "WOTHHtWJoahe" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"import json\n", |
|
|
|
|
|
"import numpy as np\n", |
|
|
|
|
|
"import random\n", |
|
|
|
|
|
"import torch\n", |
|
|
|
|
|
"from torch.utils.data import DataLoader, Dataset \n", |
|
|
|
|
|
"from transformers import AdamW, BertForQuestionAnswering, BertTokenizerFast\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"from tqdm.auto import tqdm\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"# Fix random seed for reproducibility\n", |
|
|
|
|
|
"def same_seeds(seed):\n", |
|
|
|
|
|
"\t torch.manual_seed(seed)\n", |
|
|
|
|
|
"\t if torch.cuda.is_available():\n", |
|
|
|
|
|
"\t\t torch.cuda.manual_seed(seed)\n", |
|
|
|
|
|
"\t\t torch.cuda.manual_seed_all(seed)\n", |
|
|
|
|
|
"\t np.random.seed(seed)\n", |
|
|
|
|
|
"\t random.seed(seed)\n", |
|
|
|
|
|
"\t torch.backends.cudnn.benchmark = False\n", |
|
|
|
|
|
"\t torch.backends.cudnn.deterministic = True\n", |
|
|
|
|
|
"same_seeds(0)" |
|
|
|
|
|
], |
|
|
|
|
|
"execution_count": 16, |
|
|
|
|
|
"outputs": [] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "7pBtSZP1SKQO" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"# Change \"fp16_training\" to True to support automatic mixed precision training (fp16)\t\n", |
|
|
|
|
|
"fp16_training = False\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"if fp16_training:\n", |
|
|
|
|
|
" !pip install accelerate==0.2.0\n", |
|
|
|
|
|
" from accelerate import Accelerator\n", |
|
|
|
|
|
" accelerator = Accelerator(fp16=True)\n", |
|
|
|
|
|
" device = accelerator.device\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"# Documentation for the toolkit: https://huggingface.co/docs/accelerate/" |
|
|
|
|
|
], |
|
|
|
|
|
"execution_count": 17, |
|
|
|
|
|
"outputs": [] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "2YgXHuVLp_6j" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"## Load Model and Tokenizer\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" " |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "xyBCYGjAp3ym" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"model = BertForQuestionAnswering.from_pretrained(\"bert-base-chinese\").to(device)\n", |
|
|
|
|
|
"tokenizer = BertTokenizerFast.from_pretrained(\"bert-base-chinese\")\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"# You can safely ignore the warning message (it pops up because new prediction heads for QA are initialized randomly)" |
|
|
|
|
|
], |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"outputs": [] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "3Td-GTmk5OW4" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"## Read Data\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"- Training set: 26935 QA pairs\n", |
|
|
|
|
|
"- Dev set: 3523 QA pairs\n", |
|
|
|
|
|
"- Test set: 3492 QA pairs\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"- {train/dev/test}_questions:\t\n", |
|
|
|
|
|
" - List of dicts with the following keys:\n", |
|
|
|
|
|
" - id (int)\n", |
|
|
|
|
|
" - paragraph_id (int)\n", |
|
|
|
|
|
" - question_text (string)\n", |
|
|
|
|
|
" - answer_text (string)\n", |
|
|
|
|
|
" - answer_start (int)\n", |
|
|
|
|
|
" - answer_end (int)\n", |
|
|
|
|
|
"- {train/dev/test}_paragraphs: \n", |
|
|
|
|
|
" - List of strings\n", |
|
|
|
|
|
" - paragraph_ids in questions correspond to indexs in paragraphs\n", |
|
|
|
|
|
" - A paragraph may be used by several questions " |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "NvX7hlepogvu" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"def read_data(file):\n", |
|
|
|
|
|
" with open(file, 'r', encoding=\"utf-8\") as reader:\n", |
|
|
|
|
|
" data = json.load(reader)\n", |
|
|
|
|
|
" return data[\"questions\"], data[\"paragraphs\"]\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"train_questions, train_paragraphs = read_data(\"hw7_train.json\")\n", |
|
|
|
|
|
"dev_questions, dev_paragraphs = read_data(\"hw7_dev.json\")\n", |
|
|
|
|
|
"test_questions, test_paragraphs = read_data(\"hw7_test.json\")" |
|
|
|
|
|
], |
|
|
|
|
|
"execution_count": 19, |
|
|
|
|
|
"outputs": [] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "Fm0rpTHq0e4N" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"## Tokenize Data" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "rTZ6B70Hoxie" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"# Tokenize questions and paragraphs separately\n", |
|
|
|
|
|
"# 「add_special_tokens」 is set to False since special tokens will be added when tokenized questions and paragraphs are combined in datset __getitem__ \n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"train_questions_tokenized = tokenizer([train_question[\"question_text\"] for train_question in train_questions], add_special_tokens=False)\n", |
|
|
|
|
|
"dev_questions_tokenized = tokenizer([dev_question[\"question_text\"] for dev_question in dev_questions], add_special_tokens=False)\n", |
|
|
|
|
|
"test_questions_tokenized = tokenizer([test_question[\"question_text\"] for test_question in test_questions], add_special_tokens=False) \n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"train_paragraphs_tokenized = tokenizer(train_paragraphs, add_special_tokens=False)\n", |
|
|
|
|
|
"dev_paragraphs_tokenized = tokenizer(dev_paragraphs, add_special_tokens=False)\n", |
|
|
|
|
|
"test_paragraphs_tokenized = tokenizer(test_paragraphs, add_special_tokens=False)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"# You can safely ignore the warning message as tokenized sequences will be futher processed in datset __getitem__ before passing to model" |
|
|
|
|
|
], |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"outputs": [] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "Ws8c8_4d5UCI" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"## Dataset and Dataloader" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "Xjooag-Swnuh" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"class QA_Dataset(Dataset):\n", |
|
|
|
|
|
" def __init__(self, split, questions, tokenized_questions, tokenized_paragraphs):\n", |
|
|
|
|
|
" self.split = split\n", |
|
|
|
|
|
" self.questions = questions\n", |
|
|
|
|
|
" self.tokenized_questions = tokenized_questions\n", |
|
|
|
|
|
" self.tokenized_paragraphs = tokenized_paragraphs\n", |
|
|
|
|
|
" self.max_question_len = 40\n", |
|
|
|
|
|
" self.max_paragraph_len = 150\n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" ##### TODO: Change value of doc_stride #####\n", |
|
|
|
|
|
" self.doc_stride = 150\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" # Input sequence length = [CLS] + question + [SEP] + paragraph + [SEP]\n", |
|
|
|
|
|
" self.max_seq_len = 1 + self.max_question_len + 1 + self.max_paragraph_len + 1\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" def __len__(self):\n", |
|
|
|
|
|
" return len(self.questions)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" def __getitem__(self, idx):\n", |
|
|
|
|
|
" question = self.questions[idx]\n", |
|
|
|
|
|
" tokenized_question = self.tokenized_questions[idx]\n", |
|
|
|
|
|
" tokenized_paragraph = self.tokenized_paragraphs[question[\"paragraph_id\"]]\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" ##### TODO: Preprocessing #####\n", |
|
|
|
|
|
" # Hint: How to prevent model from learning something it should not learn\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" if self.split == \"train\":\n", |
|
|
|
|
|
" # Convert answer's start/end positions in paragraph_text to start/end positions in tokenized_paragraph \n", |
|
|
|
|
|
" answer_start_token = tokenized_paragraph.char_to_token(question[\"answer_start\"])\n", |
|
|
|
|
|
" answer_end_token = tokenized_paragraph.char_to_token(question[\"answer_end\"])\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" # A single window is obtained by slicing the portion of paragraph containing the answer\n", |
|
|
|
|
|
" mid = (answer_start_token + answer_end_token) // 2\n", |
|
|
|
|
|
" paragraph_start = max(0, min(mid - self.max_paragraph_len // 2, len(tokenized_paragraph) - self.max_paragraph_len))\n", |
|
|
|
|
|
" paragraph_end = paragraph_start + self.max_paragraph_len\n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" # Slice question/paragraph and add special tokens (101: CLS, 102: SEP)\n", |
|
|
|
|
|
" input_ids_question = [101] + tokenized_question.ids[:self.max_question_len] + [102] \n", |
|
|
|
|
|
" input_ids_paragraph = tokenized_paragraph.ids[paragraph_start : paragraph_end] + [102]\t\t\n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" # Convert answer's start/end positions in tokenized_paragraph to start/end positions in the window \n", |
|
|
|
|
|
" answer_start_token += len(input_ids_question) - paragraph_start\n", |
|
|
|
|
|
" answer_end_token += len(input_ids_question) - paragraph_start\n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" # Pad sequence and obtain inputs to model \n", |
|
|
|
|
|
" input_ids, token_type_ids, attention_mask = self.padding(input_ids_question, input_ids_paragraph)\n", |
|
|
|
|
|
" return torch.tensor(input_ids), torch.tensor(token_type_ids), torch.tensor(attention_mask), answer_start_token, answer_end_token\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" # Validation/Testing\n", |
|
|
|
|
|
" else:\n", |
|
|
|
|
|
" input_ids_list, token_type_ids_list, attention_mask_list = [], [], []\n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" # Paragraph is split into several windows, each with start positions separated by step \"doc_stride\"\n", |
|
|
|
|
|
" for i in range(0, len(tokenized_paragraph), self.doc_stride):\n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" # Slice question/paragraph and add special tokens (101: CLS, 102: SEP)\n", |
|
|
|
|
|
" input_ids_question = [101] + tokenized_question.ids[:self.max_question_len] + [102]\n", |
|
|
|
|
|
" input_ids_paragraph = tokenized_paragraph.ids[i : i + self.max_paragraph_len] + [102]\n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" # Pad sequence and obtain inputs to model\n", |
|
|
|
|
|
" input_ids, token_type_ids, attention_mask = self.padding(input_ids_question, input_ids_paragraph)\n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" input_ids_list.append(input_ids)\n", |
|
|
|
|
|
" token_type_ids_list.append(token_type_ids)\n", |
|
|
|
|
|
" attention_mask_list.append(attention_mask)\n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" return torch.tensor(input_ids_list), torch.tensor(token_type_ids_list), torch.tensor(attention_mask_list)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" def padding(self, input_ids_question, input_ids_paragraph):\n", |
|
|
|
|
|
" # Pad zeros if sequence length is shorter than max_seq_len\n", |
|
|
|
|
|
" padding_len = self.max_seq_len - len(input_ids_question) - len(input_ids_paragraph)\n", |
|
|
|
|
|
" # Indices of input sequence tokens in the vocabulary\n", |
|
|
|
|
|
" input_ids = input_ids_question + input_ids_paragraph + [0] * padding_len\n", |
|
|
|
|
|
" # Segment token indices to indicate first and second portions of the inputs. Indices are selected in [0, 1]\n", |
|
|
|
|
|
" token_type_ids = [0] * len(input_ids_question) + [1] * len(input_ids_paragraph) + [0] * padding_len\n", |
|
|
|
|
|
" # Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]\n", |
|
|
|
|
|
" attention_mask = [1] * (len(input_ids_question) + len(input_ids_paragraph)) + [0] * padding_len\n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" return input_ids, token_type_ids, attention_mask\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"train_set = QA_Dataset(\"train\", train_questions, train_questions_tokenized, train_paragraphs_tokenized)\n", |
|
|
|
|
|
"dev_set = QA_Dataset(\"dev\", dev_questions, dev_questions_tokenized, dev_paragraphs_tokenized)\n", |
|
|
|
|
|
"test_set = QA_Dataset(\"test\", test_questions, test_questions_tokenized, test_paragraphs_tokenized)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"train_batch_size = 16\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"# Note: Do NOT change batch size of dev_loader / test_loader !\n", |
|
|
|
|
|
"# Although batch size=1, it is actually a batch consisting of several windows from the same QA pair\n", |
|
|
|
|
|
"train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True, pin_memory=True)\n", |
|
|
|
|
|
"dev_loader = DataLoader(dev_set, batch_size=1, shuffle=False, pin_memory=True)\n", |
|
|
|
|
|
"test_loader = DataLoader(test_set, batch_size=1, shuffle=False, pin_memory=True)" |
|
|
|
|
|
], |
|
|
|
|
|
"execution_count": 21, |
|
|
|
|
|
"outputs": [] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "5_H1kqhR8CdM" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"## Function for Evaluation" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "SqeA3PLPxOHu" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"def evaluate(data, output):\n", |
|
|
|
|
|
" ##### TODO: Postprocessing #####\n", |
|
|
|
|
|
" # There is a bug and room for improvement in postprocessing \n", |
|
|
|
|
|
" # Hint: Open your prediction file to see what is wrong \n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" answer = ''\n", |
|
|
|
|
|
" max_prob = float('-inf')\n", |
|
|
|
|
|
" num_of_windows = data[0].shape[1]\n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" for k in range(num_of_windows):\n", |
|
|
|
|
|
" # Obtain answer by choosing the most probable start position / end position\n", |
|
|
|
|
|
" start_prob, start_index = torch.max(output.start_logits[k], dim=0)\n", |
|
|
|
|
|
" end_prob, end_index = torch.max(output.end_logits[k], dim=0)\n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" # Probability of answer is calculated as sum of start_prob and end_prob\n", |
|
|
|
|
|
" prob = start_prob + end_prob\n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" # Replace answer if calculated probability is larger than previous windows\n", |
|
|
|
|
|
" if prob > max_prob:\n", |
|
|
|
|
|
" max_prob = prob\n", |
|
|
|
|
|
" # Convert tokens to chars (e.g. [1920, 7032] --> \"大 金\")\n", |
|
|
|
|
|
" answer = tokenizer.decode(data[0][0][k][start_index : end_index + 1])\n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" # Remove spaces in answer (e.g. \"大 金\" --> \"大金\")\n", |
|
|
|
|
|
" return answer.replace(' ','')" |
|
|
|
|
|
], |
|
|
|
|
|
"execution_count": 22, |
|
|
|
|
|
"outputs": [] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "rzHQit6eMnKG" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"## Training" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "3Q-B6ka7xoCM" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"num_epoch = 1 \n", |
|
|
|
|
|
"validation = True\n", |
|
|
|
|
|
"logging_step = 100\n", |
|
|
|
|
|
"learning_rate = 1e-4\n", |
|
|
|
|
|
"optimizer = AdamW(model.parameters(), lr=learning_rate)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"if fp16_training:\n", |
|
|
|
|
|
" model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader) \n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"model.train()\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"print(\"Start Training ...\")\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"for epoch in range(num_epoch):\n", |
|
|
|
|
|
" step = 1\n", |
|
|
|
|
|
" train_loss = train_acc = 0\n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" for data in tqdm(train_loader):\t\n", |
|
|
|
|
|
" # Load all data into GPU\n", |
|
|
|
|
|
" data = [i.to(device) for i in data]\n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" # Model inputs: input_ids, token_type_ids, attention_mask, start_positions, end_positions (Note: only \"input_ids\" is mandatory)\n", |
|
|
|
|
|
" # Model outputs: start_logits, end_logits, loss (return when start_positions/end_positions are provided) \n", |
|
|
|
|
|
" output = model(input_ids=data[0], token_type_ids=data[1], attention_mask=data[2], start_positions=data[3], end_positions=data[4])\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" # Choose the most probable start position / end position\n", |
|
|
|
|
|
" start_index = torch.argmax(output.start_logits, dim=1)\n", |
|
|
|
|
|
" end_index = torch.argmax(output.end_logits, dim=1)\n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" # Prediction is correct only if both start_index and end_index are correct\n", |
|
|
|
|
|
" train_acc += ((start_index == data[3]) & (end_index == data[4])).float().mean()\n", |
|
|
|
|
|
" train_loss += output.loss\n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" if fp16_training:\n", |
|
|
|
|
|
" accelerator.backward(output.loss)\n", |
|
|
|
|
|
" else:\n", |
|
|
|
|
|
" output.loss.backward()\n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" optimizer.step()\n", |
|
|
|
|
|
" optimizer.zero_grad()\n", |
|
|
|
|
|
" step += 1\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" ##### TODO: Apply linear learning rate decay #####\n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" # Print training loss and accuracy over past logging step\n", |
|
|
|
|
|
" if step % logging_step == 0:\n", |
|
|
|
|
|
" print(f\"Epoch {epoch + 1} | Step {step} | loss = {train_loss.item() / logging_step:.3f}, acc = {train_acc / logging_step:.3f}\")\n", |
|
|
|
|
|
" train_loss = train_acc = 0\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" if validation:\n", |
|
|
|
|
|
" print(\"Evaluating Dev Set ...\")\n", |
|
|
|
|
|
" model.eval()\n", |
|
|
|
|
|
" with torch.no_grad():\n", |
|
|
|
|
|
" dev_acc = 0\n", |
|
|
|
|
|
" for i, data in enumerate(tqdm(dev_loader)):\n", |
|
|
|
|
|
" output = model(input_ids=data[0].squeeze().to(device), token_type_ids=data[1].squeeze().to(device),\n", |
|
|
|
|
|
" attention_mask=data[2].squeeze().to(device))\n", |
|
|
|
|
|
" # prediction is correct only if answer text exactly matches\n", |
|
|
|
|
|
" dev_acc += evaluate(data, output) == dev_questions[i][\"answer_text\"]\n", |
|
|
|
|
|
" print(f\"Validation | Epoch {epoch + 1} | acc = {dev_acc / len(dev_loader):.3f}\")\n", |
|
|
|
|
|
" model.train()\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"# Save a model and its configuration file to the directory 「saved_model」 \n", |
|
|
|
|
|
"# i.e. there are two files under the direcory 「saved_model」: 「pytorch_model.bin」 and 「config.json」\n", |
|
|
|
|
|
"# Saved model can be re-loaded using 「model = BertForQuestionAnswering.from_pretrained(\"saved_model\")」\n", |
|
|
|
|
|
"print(\"Saving Model ...\")\n", |
|
|
|
|
|
"model_save_dir = \"saved_model\" \n", |
|
|
|
|
|
"model.save_pretrained(model_save_dir)" |
|
|
|
|
|
], |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"outputs": [] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "kMmdLOKBMsdE" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"## Testing" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "U5scNKC9xz0C" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"print(\"Evaluating Test Set ...\")\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"result = []\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"model.eval()\n", |
|
|
|
|
|
"with torch.no_grad():\n", |
|
|
|
|
|
" for data in tqdm(test_loader):\n", |
|
|
|
|
|
" output = model(input_ids=data[0].squeeze().to(device), token_type_ids=data[1].squeeze().to(device),\n", |
|
|
|
|
|
" attention_mask=data[2].squeeze().to(device))\n", |
|
|
|
|
|
" result.append(evaluate(data, output))\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"result_file = \"result.csv\"\n", |
|
|
|
|
|
"with open(result_file, 'w') as f:\t\n", |
|
|
|
|
|
"\t f.write(\"ID,Answer\\n\")\n", |
|
|
|
|
|
"\t for i, test_question in enumerate(test_questions):\n", |
|
|
|
|
|
" # Replace commas in answers with empty strings (since csv is separated by comma)\n", |
|
|
|
|
|
" # Answers in kaggle are processed in the same way\n", |
|
|
|
|
|
"\t\t f.write(f\"{test_question['id']},{result[i].replace(',','')}\\n\")\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"print(f\"Completed! Result is in {result_file}\")" |
|
|
|
|
|
], |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"outputs": [] |
|
|
|
|
|
} |
|
|
|
|
|
] |
|
|
|
|
|
} |