|
@@ -0,0 +1,922 @@ |
|
|
|
|
|
{ |
|
|
|
|
|
"cells": [ |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "OYlaRwNu7ojq" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"# **Homework 2 Phoneme Classification**\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"* Slides: https://docs.google.com/presentation/d/1v6HkBWiJb8WNDcJ9_-2kwVstxUWml87b9CnA16Gdoio/edit?usp=sharing\n", |
|
|
|
|
|
"* Kaggle: https://www.kaggle.com/c/ml2022spring-hw2\n", |
|
|
|
|
|
"* Video: TBA\n" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "mLQI0mNcmM-O", |
|
|
|
|
|
"colab": { |
|
|
|
|
|
"base_uri": "https://localhost:8080/" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputId": "7d5b4d81-9438-4d50-8153-cd235c47ee21" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputs": [ |
|
|
|
|
|
{ |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"name": "stdout", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"Wed Feb 23 14:42:18 2022 \n", |
|
|
|
|
|
"+-----------------------------------------------------------------------------+\n", |
|
|
|
|
|
"| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n", |
|
|
|
|
|
"|-------------------------------+----------------------+----------------------+\n", |
|
|
|
|
|
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", |
|
|
|
|
|
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", |
|
|
|
|
|
"| | | MIG M. |\n", |
|
|
|
|
|
"|===============================+======================+======================|\n", |
|
|
|
|
|
"| 0 Tesla K80 Off | 00000000:00:04.0 Off | 0 |\n", |
|
|
|
|
|
"| N/A 30C P8 29W / 149W | 0MiB / 11441MiB | 0% Default |\n", |
|
|
|
|
|
"| | | N/A |\n", |
|
|
|
|
|
"+-------------------------------+----------------------+----------------------+\n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
"+-----------------------------------------------------------------------------+\n", |
|
|
|
|
|
"| Processes: |\n", |
|
|
|
|
|
"| GPU GI CI PID Type Process name GPU Memory |\n", |
|
|
|
|
|
"| ID ID Usage |\n", |
|
|
|
|
|
"|=============================================================================|\n", |
|
|
|
|
|
"| No running processes found |\n", |
|
|
|
|
|
"+-----------------------------------------------------------------------------+\n" |
|
|
|
|
|
] |
|
|
|
|
|
} |
|
|
|
|
|
], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"!nvidia-smi" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "KVUGfWTo7_Oj" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"## Download Data\n", |
|
|
|
|
|
"Download data from google drive, then unzip it.\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"You should have\n", |
|
|
|
|
|
"- `libriphone/train_split.txt`\n", |
|
|
|
|
|
"- `libriphone/train_labels`\n", |
|
|
|
|
|
"- `libriphone/test_split.txt`\n", |
|
|
|
|
|
"- `libriphone/feat/train/*.pt`: training feature<br>\n", |
|
|
|
|
|
"- `libriphone/feat/test/*.pt`: testing feature<br>\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"after running the following block.\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"> **Notes: if the google drive link is dead, you can download the data directly from [Kaggle](https://www.kaggle.com/c/ml2022spring-hw2/data) and upload it to the workspace**\n" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "Bj5jYXsD9Ef3" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"### Download train/test metadata" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"colab": { |
|
|
|
|
|
"base_uri": "https://localhost:8080/" |
|
|
|
|
|
}, |
|
|
|
|
|
"id": "OzkiMEcC3Foq", |
|
|
|
|
|
"outputId": "0c8644b9-8a1e-4d23-de78-bef46c22bb1f" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputs": [ |
|
|
|
|
|
{ |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"name": "stdout", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"Requirement already satisfied: gdown in /usr/local/lib/python3.7/dist-packages (4.2.1)\n", |
|
|
|
|
|
"Collecting gdown\n", |
|
|
|
|
|
" Downloading gdown-4.3.1.tar.gz (13 kB)\n", |
|
|
|
|
|
" Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", |
|
|
|
|
|
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", |
|
|
|
|
|
" Preparing wheel metadata ... \u001b[?25l\u001b[?25hdone\n", |
|
|
|
|
|
"Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from gdown) (4.62.3)\n", |
|
|
|
|
|
"Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from gdown) (3.6.0)\n", |
|
|
|
|
|
"Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.7/dist-packages (from gdown) (4.6.3)\n", |
|
|
|
|
|
"Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from gdown) (1.15.0)\n", |
|
|
|
|
|
"Requirement already satisfied: requests[socks] in /usr/local/lib/python3.7/dist-packages (from gdown) (2.23.0)\n", |
|
|
|
|
|
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests[socks]->gdown) (1.24.3)\n", |
|
|
|
|
|
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests[socks]->gdown) (2.10)\n", |
|
|
|
|
|
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests[socks]->gdown) (2021.10.8)\n", |
|
|
|
|
|
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests[socks]->gdown) (3.0.4)\n", |
|
|
|
|
|
"Requirement already satisfied: PySocks!=1.5.7,>=1.5.6 in /usr/local/lib/python3.7/dist-packages (from requests[socks]->gdown) (1.7.1)\n", |
|
|
|
|
|
"Building wheels for collected packages: gdown\n", |
|
|
|
|
|
" Building wheel for gdown (PEP 517) ... \u001b[?25l\u001b[?25hdone\n", |
|
|
|
|
|
" Created wheel for gdown: filename=gdown-4.3.1-py3-none-any.whl size=14493 sha256=3e61ade9bf1f058ddfc642d7504028cd8dfce8c1ae9220b30d487f42c64b58e8\n", |
|
|
|
|
|
" Stored in directory: /root/.cache/pip/wheels/39/13/56/88209f07bace2c1af0614ee3326de4a00aad74afb0f4be921d\n", |
|
|
|
|
|
"Successfully built gdown\n", |
|
|
|
|
|
"Installing collected packages: gdown\n", |
|
|
|
|
|
" Attempting uninstall: gdown\n", |
|
|
|
|
|
" Found existing installation: gdown 4.2.1\n", |
|
|
|
|
|
" Uninstalling gdown-4.2.1:\n", |
|
|
|
|
|
" Successfully uninstalled gdown-4.2.1\n", |
|
|
|
|
|
"Successfully installed gdown-4.3.1\n", |
|
|
|
|
|
"Downloading...\n", |
|
|
|
|
|
"From: https://drive.google.com/uc?id=1o6Ag-G3qItSmYhTheX6DYiuyNzWyHyTc\n", |
|
|
|
|
|
"To: /content/libriphone.zip\n", |
|
|
|
|
|
"100% 479M/479M [00:04<00:00, 101MB/s]\n", |
|
|
|
|
|
"feat test_split.txt train_labels.txt\ttrain_split.txt\n" |
|
|
|
|
|
] |
|
|
|
|
|
} |
|
|
|
|
|
], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"!pip install --upgrade gdown\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"# Main link\n", |
|
|
|
|
|
"!gdown --id '1o6Ag-G3qItSmYhTheX6DYiuyNzWyHyTc' --output libriphone.zip\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"# Backup link 1\n", |
|
|
|
|
|
"# !gdown --id '1R1uQYi4QpX0tBfUWt2mbZcncdBsJkxeW' --output libriphone.zip\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"# Bqckup link 2\n", |
|
|
|
|
|
"# !wget -O libriphone.zip \"https://www.dropbox.com/s/wqww8c5dbrl2ka9/libriphone.zip?dl=1\"\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"!unzip -q libriphone.zip\n", |
|
|
|
|
|
"!ls libriphone" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "_L_4anls8Drv" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"### Preparing Data" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "po4N3C-AWuWl" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"**Helper functions to pre-process the training data from raw MFCC features of each utterance.**\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"A phoneme may span several frames and is dependent to past and future frames. \\\n", |
|
|
|
|
|
"Hence we concatenate neighboring phonemes for training to achieve higher accuracy. The **concat_feat** function concatenates past and future k frames (total 2k+1 = n frames), and we predict the center frame.\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"Feel free to modify the data preprocess functions, but **do not drop any frame** (if you modify the functions, remember to check that the number of frames are the same as mentioned in the slides)" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "IJjLT8em-y9G" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"import os\n", |
|
|
|
|
|
"import random\n", |
|
|
|
|
|
"import pandas as pd\n", |
|
|
|
|
|
"import torch\n", |
|
|
|
|
|
"from tqdm import tqdm\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"def load_feat(path):\n", |
|
|
|
|
|
" feat = torch.load(path)\n", |
|
|
|
|
|
" return feat\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"def shift(x, n):\n", |
|
|
|
|
|
" if n < 0:\n", |
|
|
|
|
|
" left = x[0].repeat(-n, 1)\n", |
|
|
|
|
|
" right = x[:n]\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" elif n > 0:\n", |
|
|
|
|
|
" right = x[-1].repeat(n, 1)\n", |
|
|
|
|
|
" left = x[n:]\n", |
|
|
|
|
|
" else:\n", |
|
|
|
|
|
" return x\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" return torch.cat((left, right), dim=0)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"def concat_feat(x, concat_n):\n", |
|
|
|
|
|
" assert concat_n % 2 == 1 # n must be odd\n", |
|
|
|
|
|
" if concat_n < 2:\n", |
|
|
|
|
|
" return x\n", |
|
|
|
|
|
" seq_len, feature_dim = x.size(0), x.size(1)\n", |
|
|
|
|
|
" x = x.repeat(1, concat_n) \n", |
|
|
|
|
|
" x = x.view(seq_len, concat_n, feature_dim).permute(1, 0, 2) # concat_n, seq_len, feature_dim\n", |
|
|
|
|
|
" mid = (concat_n // 2)\n", |
|
|
|
|
|
" for r_idx in range(1, mid+1):\n", |
|
|
|
|
|
" x[mid + r_idx, :] = shift(x[mid + r_idx], r_idx)\n", |
|
|
|
|
|
" x[mid - r_idx, :] = shift(x[mid - r_idx], -r_idx)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" return x.permute(1, 0, 2).view(seq_len, concat_n * feature_dim)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"def preprocess_data(split, feat_dir, phone_path, concat_nframes, train_ratio=0.8, train_val_seed=1337):\n", |
|
|
|
|
|
" class_num = 41 # NOTE: pre-computed, should not need change\n", |
|
|
|
|
|
" mode = 'train' if (split == 'train' or split == 'val') else 'test'\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" label_dict = {}\n", |
|
|
|
|
|
" if mode != 'test':\n", |
|
|
|
|
|
" phone_file = open(os.path.join(phone_path, f'{mode}_labels.txt')).readlines()\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" for line in phone_file:\n", |
|
|
|
|
|
" line = line.strip('\\n').split(' ')\n", |
|
|
|
|
|
" label_dict[line[0]] = [int(p) for p in line[1:]]\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" if split == 'train' or split == 'val':\n", |
|
|
|
|
|
" # split training and validation data\n", |
|
|
|
|
|
" usage_list = open(os.path.join(phone_path, 'train_split.txt')).readlines()\n", |
|
|
|
|
|
" random.seed(train_val_seed)\n", |
|
|
|
|
|
" random.shuffle(usage_list)\n", |
|
|
|
|
|
" percent = int(len(usage_list) * train_ratio)\n", |
|
|
|
|
|
" usage_list = usage_list[:percent] if split == 'train' else usage_list[percent:]\n", |
|
|
|
|
|
" elif split == 'test':\n", |
|
|
|
|
|
" usage_list = open(os.path.join(phone_path, 'test_split.txt')).readlines()\n", |
|
|
|
|
|
" else:\n", |
|
|
|
|
|
" raise ValueError('Invalid \\'split\\' argument for dataset: PhoneDataset!')\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" usage_list = [line.strip('\\n') for line in usage_list]\n", |
|
|
|
|
|
" print('[Dataset] - # phone classes: ' + str(class_num) + ', number of utterances for ' + split + ': ' + str(len(usage_list)))\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" max_len = 3000000\n", |
|
|
|
|
|
" X = torch.empty(max_len, 39 * concat_nframes)\n", |
|
|
|
|
|
" if mode != 'test':\n", |
|
|
|
|
|
" y = torch.empty(max_len, dtype=torch.long)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" idx = 0\n", |
|
|
|
|
|
" for i, fname in tqdm(enumerate(usage_list)):\n", |
|
|
|
|
|
" feat = load_feat(os.path.join(feat_dir, mode, f'{fname}.pt'))\n", |
|
|
|
|
|
" cur_len = len(feat)\n", |
|
|
|
|
|
" feat = concat_feat(feat, concat_nframes)\n", |
|
|
|
|
|
" if mode != 'test':\n", |
|
|
|
|
|
" label = torch.LongTensor(label_dict[fname])\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" X[idx: idx + cur_len, :] = feat\n", |
|
|
|
|
|
" if mode != 'test':\n", |
|
|
|
|
|
" y[idx: idx + cur_len] = label\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" idx += cur_len\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" X = X[:idx, :]\n", |
|
|
|
|
|
" if mode != 'test':\n", |
|
|
|
|
|
" y = y[:idx]\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" print(f'[INFO] {split} set')\n", |
|
|
|
|
|
" print(X.shape)\n", |
|
|
|
|
|
" if mode != 'test':\n", |
|
|
|
|
|
" print(y.shape)\n", |
|
|
|
|
|
" return X, y\n", |
|
|
|
|
|
" else:\n", |
|
|
|
|
|
" return X\n" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "us5XW_x6udZQ" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"## Define Dataset" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "Fjf5EcmJtf4e" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"import torch\n", |
|
|
|
|
|
"from torch.utils.data import Dataset\n", |
|
|
|
|
|
"from torch.utils.data import DataLoader\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"class LibriDataset(Dataset):\n", |
|
|
|
|
|
" def __init__(self, X, y=None):\n", |
|
|
|
|
|
" self.data = X\n", |
|
|
|
|
|
" if y is not None:\n", |
|
|
|
|
|
" self.label = torch.LongTensor(y)\n", |
|
|
|
|
|
" else:\n", |
|
|
|
|
|
" self.label = None\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" def __getitem__(self, idx):\n", |
|
|
|
|
|
" if self.label is not None:\n", |
|
|
|
|
|
" return self.data[idx], self.label[idx]\n", |
|
|
|
|
|
" else:\n", |
|
|
|
|
|
" return self.data[idx]\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" def __len__(self):\n", |
|
|
|
|
|
" return len(self.data)\n" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "IRqKNvNZwe3V" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"## Define Model" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "Bg-GRd7ywdrL" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"import torch\n", |
|
|
|
|
|
"import torch.nn as nn\n", |
|
|
|
|
|
"import torch.nn.functional as F\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"class BasicBlock(nn.Module):\n", |
|
|
|
|
|
" def __init__(self, input_dim, output_dim):\n", |
|
|
|
|
|
" super(BasicBlock, self).__init__()\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" self.block = nn.Sequential(\n", |
|
|
|
|
|
" nn.Linear(input_dim, output_dim),\n", |
|
|
|
|
|
" nn.ReLU(),\n", |
|
|
|
|
|
" )\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" def forward(self, x):\n", |
|
|
|
|
|
" x = self.block(x)\n", |
|
|
|
|
|
" return x\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"class Classifier(nn.Module):\n", |
|
|
|
|
|
" def __init__(self, input_dim, output_dim=41, hidden_layers=1, hidden_dim=256):\n", |
|
|
|
|
|
" super(Classifier, self).__init__()\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" self.fc = nn.Sequential(\n", |
|
|
|
|
|
" BasicBlock(input_dim, hidden_dim),\n", |
|
|
|
|
|
" *[BasicBlock(hidden_dim, hidden_dim) for _ in range(hidden_layers)],\n", |
|
|
|
|
|
" nn.Linear(hidden_dim, output_dim)\n", |
|
|
|
|
|
" )\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" def forward(self, x):\n", |
|
|
|
|
|
" x = self.fc(x)\n", |
|
|
|
|
|
" return x" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"## Hyper-parameters" |
|
|
|
|
|
], |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "TlIq8JeqvvHC" |
|
|
|
|
|
} |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"# data prarameters\n", |
|
|
|
|
|
"concat_nframes = 1 # the number of frames to concat with, n must be odd (total 2k+1 = n frames)\n", |
|
|
|
|
|
"train_ratio = 0.8 # the ratio of data used for training, the rest will be used for validation\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"# training parameters\n", |
|
|
|
|
|
"seed = 0 # random seed\n", |
|
|
|
|
|
"batch_size = 512 # batch size\n", |
|
|
|
|
|
"num_epoch = 5 # the number of training epoch\n", |
|
|
|
|
|
"learning_rate = 0.0001 # learning rate\n", |
|
|
|
|
|
"model_path = './model.ckpt' # the path where the checkpoint will be saved\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"# model parameters\n", |
|
|
|
|
|
"input_dim = 39 * concat_nframes # the input dim of the model, you should not change the value\n", |
|
|
|
|
|
"hidden_layers = 1 # the number of hidden layers\n", |
|
|
|
|
|
"hidden_dim = 256 # the hidden dim" |
|
|
|
|
|
], |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "iIHn79Iav1ri" |
|
|
|
|
|
}, |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"outputs": [] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"## Prepare dataset and model" |
|
|
|
|
|
], |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "IIUFRgG5yoDn" |
|
|
|
|
|
} |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"import gc\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"# preprocess data\n", |
|
|
|
|
|
"train_X, train_y = preprocess_data(split='train', feat_dir='./libriphone/feat', phone_path='./libriphone', concat_nframes=concat_nframes, train_ratio=train_ratio)\n", |
|
|
|
|
|
"val_X, val_y = preprocess_data(split='val', feat_dir='./libriphone/feat', phone_path='./libriphone', concat_nframes=concat_nframes, train_ratio=train_ratio)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"# get dataset\n", |
|
|
|
|
|
"train_set = LibriDataset(train_X, train_y)\n", |
|
|
|
|
|
"val_set = LibriDataset(val_X, val_y)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"# remove raw feature to save memory\n", |
|
|
|
|
|
"del train_X, train_y, val_X, val_y\n", |
|
|
|
|
|
"gc.collect()\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"# get dataloader\n", |
|
|
|
|
|
"train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)\n", |
|
|
|
|
|
"val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)" |
|
|
|
|
|
], |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "c1zI3v5jyrDn", |
|
|
|
|
|
"colab": { |
|
|
|
|
|
"base_uri": "https://localhost:8080/" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputId": "7fd90470-ef44-404a-a043-cb8e83fd1801" |
|
|
|
|
|
}, |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"outputs": [ |
|
|
|
|
|
{ |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"name": "stdout", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"[Dataset] - # phone classes: 41, number of utterances for train: 3428\n" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"name": "stderr", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"3428it [00:01, 2436.46it/s]\n" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"name": "stdout", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"[INFO] train set\n", |
|
|
|
|
|
"torch.Size([2116368, 39])\n", |
|
|
|
|
|
"torch.Size([2116368])\n", |
|
|
|
|
|
"[Dataset] - # phone classes: 41, number of utterances for val: 858\n" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"name": "stderr", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"858it [00:00, 2322.86it/s]" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"name": "stdout", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"[INFO] val set\n", |
|
|
|
|
|
"torch.Size([527790, 39])\n", |
|
|
|
|
|
"torch.Size([527790])\n" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"name": "stderr", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"\n" |
|
|
|
|
|
] |
|
|
|
|
|
} |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "CfRUEgC0GxUV", |
|
|
|
|
|
"colab": { |
|
|
|
|
|
"base_uri": "https://localhost:8080/" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputId": "75dcb672-a97d-43ff-b0f1-cb23d27fe65a" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputs": [ |
|
|
|
|
|
{ |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"name": "stdout", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"DEVICE: cuda:0\n" |
|
|
|
|
|
] |
|
|
|
|
|
} |
|
|
|
|
|
], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n", |
|
|
|
|
|
"print(f'DEVICE: {device}')" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "88xPiUnm0tAd" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"import numpy as np\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"#fix seed\n", |
|
|
|
|
|
"def same_seeds(seed):\n", |
|
|
|
|
|
" torch.manual_seed(seed)\n", |
|
|
|
|
|
" if torch.cuda.is_available():\n", |
|
|
|
|
|
" torch.cuda.manual_seed(seed)\n", |
|
|
|
|
|
" torch.cuda.manual_seed_all(seed) \n", |
|
|
|
|
|
" np.random.seed(seed) \n", |
|
|
|
|
|
" torch.backends.cudnn.benchmark = False\n", |
|
|
|
|
|
" torch.backends.cudnn.deterministic = True" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "QTp3ZXg1yO9Y" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"# fix random seed\n", |
|
|
|
|
|
"same_seeds(seed)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"# create model, define a loss function, and optimizer\n", |
|
|
|
|
|
"model = Classifier(input_dim=input_dim, hidden_layers=hidden_layers, hidden_dim=hidden_dim).to(device)\n", |
|
|
|
|
|
"criterion = nn.CrossEntropyLoss() \n", |
|
|
|
|
|
"optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"## Training" |
|
|
|
|
|
], |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "pwWH1KIqzxEr" |
|
|
|
|
|
} |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "CdMWsBs7zzNs", |
|
|
|
|
|
"colab": { |
|
|
|
|
|
"base_uri": "https://localhost:8080/" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputId": "cfb0046d-52a5-4a90-c073-f4091e8b230e" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputs": [ |
|
|
|
|
|
{ |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"name": "stderr", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"100%|██████████| 4134/4134 [00:23<00:00, 178.15it/s]\n", |
|
|
|
|
|
"100%|██████████| 1031/1031 [00:03<00:00, 286.39it/s]\n" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"name": "stdout", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"[001/005] Train Acc: 0.335711 Loss: 2.508261 | Val Acc: 0.401501 loss: 2.131011\n", |
|
|
|
|
|
"saving model with acc 0.402\n" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"name": "stderr", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"100%|██████████| 4134/4134 [00:23<00:00, 175.35it/s]\n", |
|
|
|
|
|
"100%|██████████| 1031/1031 [00:03<00:00, 281.23it/s]\n" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"name": "stdout", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"[002/005] Train Acc: 0.380122 Loss: 2.270708 | Val Acc: 0.409386 loss: 2.085472\n", |
|
|
|
|
|
"saving model with acc 0.409\n" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"name": "stderr", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"100%|██████████| 4134/4134 [00:23<00:00, 174.85it/s]\n", |
|
|
|
|
|
"100%|██████████| 1031/1031 [00:03<00:00, 271.75it/s]\n" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"name": "stdout", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"[003/005] Train Acc: 0.388530 Loss: 2.227409 | Val Acc: 0.414295 loss: 2.063263\n", |
|
|
|
|
|
"saving model with acc 0.414\n" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"name": "stderr", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"100%|██████████| 4134/4134 [00:23<00:00, 177.08it/s]\n", |
|
|
|
|
|
"100%|██████████| 1031/1031 [00:03<00:00, 276.20it/s]\n" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"name": "stdout", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"[004/005] Train Acc: 0.393600 Loss: 2.201657 | Val Acc: 0.417833 loss: 2.046759\n", |
|
|
|
|
|
"saving model with acc 0.418\n" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"name": "stderr", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"100%|██████████| 4134/4134 [00:23<00:00, 177.41it/s]\n", |
|
|
|
|
|
"100%|██████████| 1031/1031 [00:03<00:00, 280.50it/s]" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"name": "stdout", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"[005/005] Train Acc: 0.397076 Loss: 2.184807 | Val Acc: 0.420696 loss: 2.035235\n", |
|
|
|
|
|
"saving model with acc 0.421\n" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"name": "stderr", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"\n" |
|
|
|
|
|
] |
|
|
|
|
|
} |
|
|
|
|
|
], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"best_acc = 0.0\n", |
|
|
|
|
|
"for epoch in range(num_epoch):\n", |
|
|
|
|
|
" train_acc = 0.0\n", |
|
|
|
|
|
" train_loss = 0.0\n", |
|
|
|
|
|
" val_acc = 0.0\n", |
|
|
|
|
|
" val_loss = 0.0\n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" # training\n", |
|
|
|
|
|
" model.train() # set the model to training mode\n", |
|
|
|
|
|
" for i, batch in enumerate(tqdm(train_loader)):\n", |
|
|
|
|
|
" features, labels = batch\n", |
|
|
|
|
|
" features = features.to(device)\n", |
|
|
|
|
|
" labels = labels.to(device)\n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" optimizer.zero_grad() \n", |
|
|
|
|
|
" outputs = model(features) \n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" loss = criterion(outputs, labels)\n", |
|
|
|
|
|
" loss.backward() \n", |
|
|
|
|
|
" optimizer.step() \n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" _, train_pred = torch.max(outputs, 1) # get the index of the class with the highest probability\n", |
|
|
|
|
|
" train_acc += (train_pred.detach() == labels.detach()).sum().item()\n", |
|
|
|
|
|
" train_loss += loss.item()\n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" # validation\n", |
|
|
|
|
|
" if len(val_set) > 0:\n", |
|
|
|
|
|
" model.eval() # set the model to evaluation mode\n", |
|
|
|
|
|
" with torch.no_grad():\n", |
|
|
|
|
|
" for i, batch in enumerate(tqdm(val_loader)):\n", |
|
|
|
|
|
" features, labels = batch\n", |
|
|
|
|
|
" features = features.to(device)\n", |
|
|
|
|
|
" labels = labels.to(device)\n", |
|
|
|
|
|
" outputs = model(features)\n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" loss = criterion(outputs, labels) \n", |
|
|
|
|
|
" \n", |
|
|
|
|
|
" _, val_pred = torch.max(outputs, 1) \n", |
|
|
|
|
|
" val_acc += (val_pred.cpu() == labels.cpu()).sum().item() # get the index of the class with the highest probability\n", |
|
|
|
|
|
" val_loss += loss.item()\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" print('[{:03d}/{:03d}] Train Acc: {:3.6f} Loss: {:3.6f} | Val Acc: {:3.6f} loss: {:3.6f}'.format(\n", |
|
|
|
|
|
" epoch + 1, num_epoch, train_acc/len(train_set), train_loss/len(train_loader), val_acc/len(val_set), val_loss/len(val_loader)\n", |
|
|
|
|
|
" ))\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" # if the model improves, save a checkpoint at this epoch\n", |
|
|
|
|
|
" if val_acc > best_acc:\n", |
|
|
|
|
|
" best_acc = val_acc\n", |
|
|
|
|
|
" torch.save(model.state_dict(), model_path)\n", |
|
|
|
|
|
" print('saving model with acc {:.3f}'.format(best_acc/len(val_set)))\n", |
|
|
|
|
|
" else:\n", |
|
|
|
|
|
" print('[{:03d}/{:03d}] Train Acc: {:3.6f} Loss: {:3.6f}'.format(\n", |
|
|
|
|
|
" epoch + 1, num_epoch, train_acc/len(train_set), train_loss/len(train_loader)\n", |
|
|
|
|
|
" ))\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"# if not validating, save the last epoch\n", |
|
|
|
|
|
"if len(val_set) == 0:\n", |
|
|
|
|
|
" torch.save(model.state_dict(), model_path)\n", |
|
|
|
|
|
" print('saving model at last epoch')\n" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "ab33MxosWLmG", |
|
|
|
|
|
"colab": { |
|
|
|
|
|
"base_uri": "https://localhost:8080/" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputId": "7ea97815-15cc-4afa-fa7e-b65460b91640" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputs": [ |
|
|
|
|
|
{ |
|
|
|
|
|
"output_type": "execute_result", |
|
|
|
|
|
"data": { |
|
|
|
|
|
"text/plain": [ |
|
|
|
|
|
"50" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"execution_count": 12 |
|
|
|
|
|
} |
|
|
|
|
|
], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"del train_loader, val_loader\n", |
|
|
|
|
|
"gc.collect()" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "1Hi7jTn3PX-m" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"## Testing\n", |
|
|
|
|
|
"Create a testing dataset, and load model from the saved checkpoint." |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "VOG1Ou0PGrhc", |
|
|
|
|
|
"colab": { |
|
|
|
|
|
"base_uri": "https://localhost:8080/" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputId": "81077fbc-a6ea-46b7-9a57-a690480fbb6b" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputs": [ |
|
|
|
|
|
{ |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"name": "stdout", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"[Dataset] - # phone classes: 41, number of utterances for test: 1078\n" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"name": "stderr", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"1078it [00:00, 2784.86it/s]" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"name": "stdout", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"[INFO] test set\n", |
|
|
|
|
|
"torch.Size([646268, 39])\n" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"name": "stderr", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"\n" |
|
|
|
|
|
] |
|
|
|
|
|
} |
|
|
|
|
|
], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"# load data\n", |
|
|
|
|
|
"test_X = preprocess_data(split='test', feat_dir='./libriphone/feat', phone_path='./libriphone', concat_nframes=concat_nframes)\n", |
|
|
|
|
|
"test_set = LibriDataset(test_X, None)\n", |
|
|
|
|
|
"test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "ay0Fu8Ovkdad", |
|
|
|
|
|
"colab": { |
|
|
|
|
|
"base_uri": "https://localhost:8080/" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputId": "a29d1dbc-3222-4cec-8f84-04475b77cceb" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputs": [ |
|
|
|
|
|
{ |
|
|
|
|
|
"output_type": "execute_result", |
|
|
|
|
|
"data": { |
|
|
|
|
|
"text/plain": [ |
|
|
|
|
|
"<All keys matched successfully>" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
"metadata": {}, |
|
|
|
|
|
"execution_count": 14 |
|
|
|
|
|
} |
|
|
|
|
|
], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"# load model\n", |
|
|
|
|
|
"model = Classifier(input_dim=input_dim, hidden_layers=hidden_layers, hidden_dim=hidden_dim).to(device)\n", |
|
|
|
|
|
"model.load_state_dict(torch.load(model_path))" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "zp-DV1p4r7Nz" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"Make prediction." |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "84HU5GGjPqR0", |
|
|
|
|
|
"colab": { |
|
|
|
|
|
"base_uri": "https://localhost:8080/" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputId": "bbaaa8c5-d88c-4ef3-f7be-d75b208cd5df" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputs": [ |
|
|
|
|
|
{ |
|
|
|
|
|
"output_type": "stream", |
|
|
|
|
|
"name": "stderr", |
|
|
|
|
|
"text": [ |
|
|
|
|
|
"100%|██████████| 1263/1263 [00:02<00:00, 426.93it/s]\n" |
|
|
|
|
|
] |
|
|
|
|
|
} |
|
|
|
|
|
], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"test_acc = 0.0\n", |
|
|
|
|
|
"test_lengths = 0\n", |
|
|
|
|
|
"pred = np.array([], dtype=np.int32)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"model.eval()\n", |
|
|
|
|
|
"with torch.no_grad():\n", |
|
|
|
|
|
" for i, batch in enumerate(tqdm(test_loader)):\n", |
|
|
|
|
|
" features = batch\n", |
|
|
|
|
|
" features = features.to(device)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" outputs = model(features)\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
" _, test_pred = torch.max(outputs, 1) # get the index of the class with the highest probability\n", |
|
|
|
|
|
" pred = np.concatenate((pred, test_pred.cpu().numpy()), axis=0)\n" |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "markdown", |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "wyZqy40Prz0v" |
|
|
|
|
|
}, |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"Write prediction to a CSV file.\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
"After finish running this block, download the file `prediction.csv` from the files section on the left-hand side and submit it to Kaggle." |
|
|
|
|
|
] |
|
|
|
|
|
}, |
|
|
|
|
|
{ |
|
|
|
|
|
"cell_type": "code", |
|
|
|
|
|
"execution_count": null, |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"id": "GuljYSPHcZir" |
|
|
|
|
|
}, |
|
|
|
|
|
"outputs": [], |
|
|
|
|
|
"source": [ |
|
|
|
|
|
"with open('prediction.csv', 'w') as f:\n", |
|
|
|
|
|
" f.write('Id,Class\\n')\n", |
|
|
|
|
|
" for i, y in enumerate(pred):\n", |
|
|
|
|
|
" f.write('{},{}\\n'.format(i, y))" |
|
|
|
|
|
] |
|
|
|
|
|
} |
|
|
|
|
|
], |
|
|
|
|
|
"metadata": { |
|
|
|
|
|
"accelerator": "GPU", |
|
|
|
|
|
"colab": { |
|
|
|
|
|
"collapsed_sections": [], |
|
|
|
|
|
"name": "ML2022Spring - HW2.ipynb", |
|
|
|
|
|
"provenance": [] |
|
|
|
|
|
}, |
|
|
|
|
|
"kernelspec": { |
|
|
|
|
|
"display_name": "Python 3", |
|
|
|
|
|
"name": "python3" |
|
|
|
|
|
} |
|
|
|
|
|
}, |
|
|
|
|
|
"nbformat": 4, |
|
|
|
|
|
"nbformat_minor": 0 |
|
|
|
|
|
} |