Browse Source

Update L2

main
Fafa-DL 2 years ago
parent
commit
61dfd16f9a
4 changed files with 924 additions and 0 deletions
  1. +922
    -0
      2022 ML/02 What to do if my network fails to train/ML2022Spring_HW2.ipynb
  2. BIN
      2022 ML/02 What to do if my network fails to train/hw2_slides 2022.pdf
  3. BIN
      2022 ML/02 What to do if my network fails to train/theory (v7).pdf
  4. +2
    -0
      README.md

+ 922
- 0
2022 ML/02 What to do if my network fails to train/ML2022Spring_HW2.ipynb View File

@@ -0,0 +1,922 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "OYlaRwNu7ojq"
},
"source": [
"# **Homework 2 Phoneme Classification**\n",
"\n",
"* Slides: https://docs.google.com/presentation/d/1v6HkBWiJb8WNDcJ9_-2kwVstxUWml87b9CnA16Gdoio/edit?usp=sharing\n",
"* Kaggle: https://www.kaggle.com/c/ml2022spring-hw2\n",
"* Video: TBA\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mLQI0mNcmM-O",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "7d5b4d81-9438-4d50-8153-cd235c47ee21"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Wed Feb 23 14:42:18 2022 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla K80 Off | 00000000:00:04.0 Off | 0 |\n",
"| N/A 30C P8 29W / 149W | 0MiB / 11441MiB | 0% Default |\n",
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
],
"source": [
"!nvidia-smi"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KVUGfWTo7_Oj"
},
"source": [
"## Download Data\n",
"Download data from google drive, then unzip it.\n",
"\n",
"You should have\n",
"- `libriphone/train_split.txt`\n",
"- `libriphone/train_labels`\n",
"- `libriphone/test_split.txt`\n",
"- `libriphone/feat/train/*.pt`: training feature<br>\n",
"- `libriphone/feat/test/*.pt`: testing feature<br>\n",
"\n",
"after running the following block.\n",
"\n",
"> **Notes: if the google drive link is dead, you can download the data directly from [Kaggle](https://www.kaggle.com/c/ml2022spring-hw2/data) and upload it to the workspace**\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Bj5jYXsD9Ef3"
},
"source": [
"### Download train/test metadata"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "OzkiMEcC3Foq",
"outputId": "0c8644b9-8a1e-4d23-de78-bef46c22bb1f"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Requirement already satisfied: gdown in /usr/local/lib/python3.7/dist-packages (4.2.1)\n",
"Collecting gdown\n",
" Downloading gdown-4.3.1.tar.gz (13 kB)\n",
" Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Preparing wheel metadata ... \u001b[?25l\u001b[?25hdone\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from gdown) (4.62.3)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from gdown) (3.6.0)\n",
"Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.7/dist-packages (from gdown) (4.6.3)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from gdown) (1.15.0)\n",
"Requirement already satisfied: requests[socks] in /usr/local/lib/python3.7/dist-packages (from gdown) (2.23.0)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests[socks]->gdown) (1.24.3)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests[socks]->gdown) (2.10)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests[socks]->gdown) (2021.10.8)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests[socks]->gdown) (3.0.4)\n",
"Requirement already satisfied: PySocks!=1.5.7,>=1.5.6 in /usr/local/lib/python3.7/dist-packages (from requests[socks]->gdown) (1.7.1)\n",
"Building wheels for collected packages: gdown\n",
" Building wheel for gdown (PEP 517) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for gdown: filename=gdown-4.3.1-py3-none-any.whl size=14493 sha256=3e61ade9bf1f058ddfc642d7504028cd8dfce8c1ae9220b30d487f42c64b58e8\n",
" Stored in directory: /root/.cache/pip/wheels/39/13/56/88209f07bace2c1af0614ee3326de4a00aad74afb0f4be921d\n",
"Successfully built gdown\n",
"Installing collected packages: gdown\n",
" Attempting uninstall: gdown\n",
" Found existing installation: gdown 4.2.1\n",
" Uninstalling gdown-4.2.1:\n",
" Successfully uninstalled gdown-4.2.1\n",
"Successfully installed gdown-4.3.1\n",
"Downloading...\n",
"From: https://drive.google.com/uc?id=1o6Ag-G3qItSmYhTheX6DYiuyNzWyHyTc\n",
"To: /content/libriphone.zip\n",
"100% 479M/479M [00:04<00:00, 101MB/s]\n",
"feat test_split.txt train_labels.txt\ttrain_split.txt\n"
]
}
],
"source": [
"!pip install --upgrade gdown\n",
"\n",
"# Main link\n",
"!gdown --id '1o6Ag-G3qItSmYhTheX6DYiuyNzWyHyTc' --output libriphone.zip\n",
"\n",
"# Backup link 1\n",
"# !gdown --id '1R1uQYi4QpX0tBfUWt2mbZcncdBsJkxeW' --output libriphone.zip\n",
"\n",
"# Bqckup link 2\n",
"# !wget -O libriphone.zip \"https://www.dropbox.com/s/wqww8c5dbrl2ka9/libriphone.zip?dl=1\"\n",
"\n",
"!unzip -q libriphone.zip\n",
"!ls libriphone"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_L_4anls8Drv"
},
"source": [
"### Preparing Data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "po4N3C-AWuWl"
},
"source": [
"**Helper functions to pre-process the training data from raw MFCC features of each utterance.**\n",
"\n",
"A phoneme may span several frames and is dependent to past and future frames. \\\n",
"Hence we concatenate neighboring phonemes for training to achieve higher accuracy. The **concat_feat** function concatenates past and future k frames (total 2k+1 = n frames), and we predict the center frame.\n",
"\n",
"Feel free to modify the data preprocess functions, but **do not drop any frame** (if you modify the functions, remember to check that the number of frames are the same as mentioned in the slides)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IJjLT8em-y9G"
},
"outputs": [],
"source": [
"import os\n",
"import random\n",
"import pandas as pd\n",
"import torch\n",
"from tqdm import tqdm\n",
"\n",
"def load_feat(path):\n",
" feat = torch.load(path)\n",
" return feat\n",
"\n",
"def shift(x, n):\n",
" if n < 0:\n",
" left = x[0].repeat(-n, 1)\n",
" right = x[:n]\n",
"\n",
" elif n > 0:\n",
" right = x[-1].repeat(n, 1)\n",
" left = x[n:]\n",
" else:\n",
" return x\n",
"\n",
" return torch.cat((left, right), dim=0)\n",
"\n",
"def concat_feat(x, concat_n):\n",
" assert concat_n % 2 == 1 # n must be odd\n",
" if concat_n < 2:\n",
" return x\n",
" seq_len, feature_dim = x.size(0), x.size(1)\n",
" x = x.repeat(1, concat_n) \n",
" x = x.view(seq_len, concat_n, feature_dim).permute(1, 0, 2) # concat_n, seq_len, feature_dim\n",
" mid = (concat_n // 2)\n",
" for r_idx in range(1, mid+1):\n",
" x[mid + r_idx, :] = shift(x[mid + r_idx], r_idx)\n",
" x[mid - r_idx, :] = shift(x[mid - r_idx], -r_idx)\n",
"\n",
" return x.permute(1, 0, 2).view(seq_len, concat_n * feature_dim)\n",
"\n",
"def preprocess_data(split, feat_dir, phone_path, concat_nframes, train_ratio=0.8, train_val_seed=1337):\n",
" class_num = 41 # NOTE: pre-computed, should not need change\n",
" mode = 'train' if (split == 'train' or split == 'val') else 'test'\n",
"\n",
" label_dict = {}\n",
" if mode != 'test':\n",
" phone_file = open(os.path.join(phone_path, f'{mode}_labels.txt')).readlines()\n",
"\n",
" for line in phone_file:\n",
" line = line.strip('\\n').split(' ')\n",
" label_dict[line[0]] = [int(p) for p in line[1:]]\n",
"\n",
" if split == 'train' or split == 'val':\n",
" # split training and validation data\n",
" usage_list = open(os.path.join(phone_path, 'train_split.txt')).readlines()\n",
" random.seed(train_val_seed)\n",
" random.shuffle(usage_list)\n",
" percent = int(len(usage_list) * train_ratio)\n",
" usage_list = usage_list[:percent] if split == 'train' else usage_list[percent:]\n",
" elif split == 'test':\n",
" usage_list = open(os.path.join(phone_path, 'test_split.txt')).readlines()\n",
" else:\n",
" raise ValueError('Invalid \\'split\\' argument for dataset: PhoneDataset!')\n",
"\n",
" usage_list = [line.strip('\\n') for line in usage_list]\n",
" print('[Dataset] - # phone classes: ' + str(class_num) + ', number of utterances for ' + split + ': ' + str(len(usage_list)))\n",
"\n",
" max_len = 3000000\n",
" X = torch.empty(max_len, 39 * concat_nframes)\n",
" if mode != 'test':\n",
" y = torch.empty(max_len, dtype=torch.long)\n",
"\n",
" idx = 0\n",
" for i, fname in tqdm(enumerate(usage_list)):\n",
" feat = load_feat(os.path.join(feat_dir, mode, f'{fname}.pt'))\n",
" cur_len = len(feat)\n",
" feat = concat_feat(feat, concat_nframes)\n",
" if mode != 'test':\n",
" label = torch.LongTensor(label_dict[fname])\n",
"\n",
" X[idx: idx + cur_len, :] = feat\n",
" if mode != 'test':\n",
" y[idx: idx + cur_len] = label\n",
"\n",
" idx += cur_len\n",
"\n",
" X = X[:idx, :]\n",
" if mode != 'test':\n",
" y = y[:idx]\n",
"\n",
" print(f'[INFO] {split} set')\n",
" print(X.shape)\n",
" if mode != 'test':\n",
" print(y.shape)\n",
" return X, y\n",
" else:\n",
" return X\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "us5XW_x6udZQ"
},
"source": [
"## Define Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Fjf5EcmJtf4e"
},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import Dataset\n",
"from torch.utils.data import DataLoader\n",
"\n",
"class LibriDataset(Dataset):\n",
" def __init__(self, X, y=None):\n",
" self.data = X\n",
" if y is not None:\n",
" self.label = torch.LongTensor(y)\n",
" else:\n",
" self.label = None\n",
"\n",
" def __getitem__(self, idx):\n",
" if self.label is not None:\n",
" return self.data[idx], self.label[idx]\n",
" else:\n",
" return self.data[idx]\n",
"\n",
" def __len__(self):\n",
" return len(self.data)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IRqKNvNZwe3V"
},
"source": [
"## Define Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Bg-GRd7ywdrL"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class BasicBlock(nn.Module):\n",
" def __init__(self, input_dim, output_dim):\n",
" super(BasicBlock, self).__init__()\n",
"\n",
" self.block = nn.Sequential(\n",
" nn.Linear(input_dim, output_dim),\n",
" nn.ReLU(),\n",
" )\n",
"\n",
" def forward(self, x):\n",
" x = self.block(x)\n",
" return x\n",
"\n",
"\n",
"class Classifier(nn.Module):\n",
" def __init__(self, input_dim, output_dim=41, hidden_layers=1, hidden_dim=256):\n",
" super(Classifier, self).__init__()\n",
"\n",
" self.fc = nn.Sequential(\n",
" BasicBlock(input_dim, hidden_dim),\n",
" *[BasicBlock(hidden_dim, hidden_dim) for _ in range(hidden_layers)],\n",
" nn.Linear(hidden_dim, output_dim)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" x = self.fc(x)\n",
" return x"
]
},
{
"cell_type": "markdown",
"source": [
"## Hyper-parameters"
],
"metadata": {
"id": "TlIq8JeqvvHC"
}
},
{
"cell_type": "code",
"source": [
"# data prarameters\n",
"concat_nframes = 1 # the number of frames to concat with, n must be odd (total 2k+1 = n frames)\n",
"train_ratio = 0.8 # the ratio of data used for training, the rest will be used for validation\n",
"\n",
"# training parameters\n",
"seed = 0 # random seed\n",
"batch_size = 512 # batch size\n",
"num_epoch = 5 # the number of training epoch\n",
"learning_rate = 0.0001 # learning rate\n",
"model_path = './model.ckpt' # the path where the checkpoint will be saved\n",
"\n",
"# model parameters\n",
"input_dim = 39 * concat_nframes # the input dim of the model, you should not change the value\n",
"hidden_layers = 1 # the number of hidden layers\n",
"hidden_dim = 256 # the hidden dim"
],
"metadata": {
"id": "iIHn79Iav1ri"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Prepare dataset and model"
],
"metadata": {
"id": "IIUFRgG5yoDn"
}
},
{
"cell_type": "code",
"source": [
"import gc\n",
"\n",
"# preprocess data\n",
"train_X, train_y = preprocess_data(split='train', feat_dir='./libriphone/feat', phone_path='./libriphone', concat_nframes=concat_nframes, train_ratio=train_ratio)\n",
"val_X, val_y = preprocess_data(split='val', feat_dir='./libriphone/feat', phone_path='./libriphone', concat_nframes=concat_nframes, train_ratio=train_ratio)\n",
"\n",
"# get dataset\n",
"train_set = LibriDataset(train_X, train_y)\n",
"val_set = LibriDataset(val_X, val_y)\n",
"\n",
"# remove raw feature to save memory\n",
"del train_X, train_y, val_X, val_y\n",
"gc.collect()\n",
"\n",
"# get dataloader\n",
"train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
"val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)"
],
"metadata": {
"id": "c1zI3v5jyrDn",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "7fd90470-ef44-404a-a043-cb8e83fd1801"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[Dataset] - # phone classes: 41, number of utterances for train: 3428\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"3428it [00:01, 2436.46it/s]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"[INFO] train set\n",
"torch.Size([2116368, 39])\n",
"torch.Size([2116368])\n",
"[Dataset] - # phone classes: 41, number of utterances for val: 858\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"858it [00:00, 2322.86it/s]"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"[INFO] val set\n",
"torch.Size([527790, 39])\n",
"torch.Size([527790])\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CfRUEgC0GxUV",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "75dcb672-a97d-43ff-b0f1-cb23d27fe65a"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"DEVICE: cuda:0\n"
]
}
],
"source": [
"device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
"print(f'DEVICE: {device}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "88xPiUnm0tAd"
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"#fix seed\n",
"def same_seeds(seed):\n",
" torch.manual_seed(seed)\n",
" if torch.cuda.is_available():\n",
" torch.cuda.manual_seed(seed)\n",
" torch.cuda.manual_seed_all(seed) \n",
" np.random.seed(seed) \n",
" torch.backends.cudnn.benchmark = False\n",
" torch.backends.cudnn.deterministic = True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QTp3ZXg1yO9Y"
},
"outputs": [],
"source": [
"# fix random seed\n",
"same_seeds(seed)\n",
"\n",
"# create model, define a loss function, and optimizer\n",
"model = Classifier(input_dim=input_dim, hidden_layers=hidden_layers, hidden_dim=hidden_dim).to(device)\n",
"criterion = nn.CrossEntropyLoss() \n",
"optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)"
]
},
{
"cell_type": "markdown",
"source": [
"## Training"
],
"metadata": {
"id": "pwWH1KIqzxEr"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CdMWsBs7zzNs",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "cfb0046d-52a5-4a90-c073-f4091e8b230e"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"100%|██████████| 4134/4134 [00:23<00:00, 178.15it/s]\n",
"100%|██████████| 1031/1031 [00:03<00:00, 286.39it/s]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"[001/005] Train Acc: 0.335711 Loss: 2.508261 | Val Acc: 0.401501 loss: 2.131011\n",
"saving model with acc 0.402\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"100%|██████████| 4134/4134 [00:23<00:00, 175.35it/s]\n",
"100%|██████████| 1031/1031 [00:03<00:00, 281.23it/s]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"[002/005] Train Acc: 0.380122 Loss: 2.270708 | Val Acc: 0.409386 loss: 2.085472\n",
"saving model with acc 0.409\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"100%|██████████| 4134/4134 [00:23<00:00, 174.85it/s]\n",
"100%|██████████| 1031/1031 [00:03<00:00, 271.75it/s]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"[003/005] Train Acc: 0.388530 Loss: 2.227409 | Val Acc: 0.414295 loss: 2.063263\n",
"saving model with acc 0.414\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"100%|██████████| 4134/4134 [00:23<00:00, 177.08it/s]\n",
"100%|██████████| 1031/1031 [00:03<00:00, 276.20it/s]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"[004/005] Train Acc: 0.393600 Loss: 2.201657 | Val Acc: 0.417833 loss: 2.046759\n",
"saving model with acc 0.418\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"100%|██████████| 4134/4134 [00:23<00:00, 177.41it/s]\n",
"100%|██████████| 1031/1031 [00:03<00:00, 280.50it/s]"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"[005/005] Train Acc: 0.397076 Loss: 2.184807 | Val Acc: 0.420696 loss: 2.035235\n",
"saving model with acc 0.421\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"\n"
]
}
],
"source": [
"best_acc = 0.0\n",
"for epoch in range(num_epoch):\n",
" train_acc = 0.0\n",
" train_loss = 0.0\n",
" val_acc = 0.0\n",
" val_loss = 0.0\n",
" \n",
" # training\n",
" model.train() # set the model to training mode\n",
" for i, batch in enumerate(tqdm(train_loader)):\n",
" features, labels = batch\n",
" features = features.to(device)\n",
" labels = labels.to(device)\n",
" \n",
" optimizer.zero_grad() \n",
" outputs = model(features) \n",
" \n",
" loss = criterion(outputs, labels)\n",
" loss.backward() \n",
" optimizer.step() \n",
" \n",
" _, train_pred = torch.max(outputs, 1) # get the index of the class with the highest probability\n",
" train_acc += (train_pred.detach() == labels.detach()).sum().item()\n",
" train_loss += loss.item()\n",
" \n",
" # validation\n",
" if len(val_set) > 0:\n",
" model.eval() # set the model to evaluation mode\n",
" with torch.no_grad():\n",
" for i, batch in enumerate(tqdm(val_loader)):\n",
" features, labels = batch\n",
" features = features.to(device)\n",
" labels = labels.to(device)\n",
" outputs = model(features)\n",
" \n",
" loss = criterion(outputs, labels) \n",
" \n",
" _, val_pred = torch.max(outputs, 1) \n",
" val_acc += (val_pred.cpu() == labels.cpu()).sum().item() # get the index of the class with the highest probability\n",
" val_loss += loss.item()\n",
"\n",
" print('[{:03d}/{:03d}] Train Acc: {:3.6f} Loss: {:3.6f} | Val Acc: {:3.6f} loss: {:3.6f}'.format(\n",
" epoch + 1, num_epoch, train_acc/len(train_set), train_loss/len(train_loader), val_acc/len(val_set), val_loss/len(val_loader)\n",
" ))\n",
"\n",
" # if the model improves, save a checkpoint at this epoch\n",
" if val_acc > best_acc:\n",
" best_acc = val_acc\n",
" torch.save(model.state_dict(), model_path)\n",
" print('saving model with acc {:.3f}'.format(best_acc/len(val_set)))\n",
" else:\n",
" print('[{:03d}/{:03d}] Train Acc: {:3.6f} Loss: {:3.6f}'.format(\n",
" epoch + 1, num_epoch, train_acc/len(train_set), train_loss/len(train_loader)\n",
" ))\n",
"\n",
"# if not validating, save the last epoch\n",
"if len(val_set) == 0:\n",
" torch.save(model.state_dict(), model_path)\n",
" print('saving model at last epoch')\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ab33MxosWLmG",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "7ea97815-15cc-4afa-fa7e-b65460b91640"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"50"
]
},
"metadata": {},
"execution_count": 12
}
],
"source": [
"del train_loader, val_loader\n",
"gc.collect()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1Hi7jTn3PX-m"
},
"source": [
"## Testing\n",
"Create a testing dataset, and load model from the saved checkpoint."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VOG1Ou0PGrhc",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "81077fbc-a6ea-46b7-9a57-a690480fbb6b"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[Dataset] - # phone classes: 41, number of utterances for test: 1078\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"1078it [00:00, 2784.86it/s]"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"[INFO] test set\n",
"torch.Size([646268, 39])\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"\n"
]
}
],
"source": [
"# load data\n",
"test_X = preprocess_data(split='test', feat_dir='./libriphone/feat', phone_path='./libriphone', concat_nframes=concat_nframes)\n",
"test_set = LibriDataset(test_X, None)\n",
"test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ay0Fu8Ovkdad",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "a29d1dbc-3222-4cec-8f84-04475b77cceb"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"metadata": {},
"execution_count": 14
}
],
"source": [
"# load model\n",
"model = Classifier(input_dim=input_dim, hidden_layers=hidden_layers, hidden_dim=hidden_dim).to(device)\n",
"model.load_state_dict(torch.load(model_path))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zp-DV1p4r7Nz"
},
"source": [
"Make prediction."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "84HU5GGjPqR0",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "bbaaa8c5-d88c-4ef3-f7be-d75b208cd5df"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"100%|██████████| 1263/1263 [00:02<00:00, 426.93it/s]\n"
]
}
],
"source": [
"test_acc = 0.0\n",
"test_lengths = 0\n",
"pred = np.array([], dtype=np.int32)\n",
"\n",
"model.eval()\n",
"with torch.no_grad():\n",
" for i, batch in enumerate(tqdm(test_loader)):\n",
" features = batch\n",
" features = features.to(device)\n",
"\n",
" outputs = model(features)\n",
"\n",
" _, test_pred = torch.max(outputs, 1) # get the index of the class with the highest probability\n",
" pred = np.concatenate((pred, test_pred.cpu().numpy()), axis=0)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wyZqy40Prz0v"
},
"source": [
"Write prediction to a CSV file.\n",
"\n",
"After finish running this block, download the file `prediction.csv` from the files section on the left-hand side and submit it to Kaggle."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GuljYSPHcZir"
},
"outputs": [],
"source": [
"with open('prediction.csv', 'w') as f:\n",
" f.write('Id,Class\\n')\n",
" for i, y in enumerate(pred):\n",
" f.write('{},{}\\n'.format(i, y))"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "ML2022Spring - HW2.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

BIN
2022 ML/02 What to do if my network fails to train/hw2_slides 2022.pdf View File


BIN
2022 ML/02 What to do if my network fails to train/theory (v7).pdf View File


+ 2
- 0
README.md View File

@@ -31,6 +31,7 @@ ppt/pdf支持直链下载。
|2021/12/20|更新Github排版,删除repo中的ppt/pdf直接提供下载链接,总资料放入[百度云盘-提取码:sv1i](https://pan.baidu.com/s/13cxyIbvF0bEyytANLf58NQ)|
|2022/02/17|2022春季机器学习课程仅在21基础上进行小补充,UP同步更新官网补充内容|
|2022/02/21|更新Lecture 1:Introductionof Deep Learning补充内容,Github排版大更新|
|2022/02/25|更新Lecture 2:What to do if my network fails to train补充内容与HW2|

****
@@ -55,5 +56,6 @@ ppt/pdf支持直链下载。
|章节|2021前置知识|2022补充|选修|作业|
|---|---|---|---|---|
|Lecture 1|[(上)机器学习基本概念简介](https://www.bilibili.com/video/BV1Wv411h7kN?p=3)<br/>[(下)机器学习基本概念简介](https://www.bilibili.com/video/BV1Wv411h7kN?p=4)|Video:<br/>[2022-机器学习相关规定](https://www.bilibili.com/video/BV1Wv411h7kN?p=1)<br/>[2022-Colab教学](https://www.bilibili.com/video/BV1Wv411h7kN?p=5)<br/>[2022-Pytorch Tutorial 1](https://www.bilibili.com/video/BV1Wv411h7kN?p=6)<br/>[2022-Pytorch Tutorial 2](https://www.bilibili.com/video/BV1Wv411h7kN?p=7)<br/><br/>PDF:<br/>[Rules](https://speech.ee.ntu.edu.tw/~hylee/ml/ml2022-course-data/rule%20(v2).pdf)<br/>[Chinese class course intro](https://speech.ee.ntu.edu.tw/~hylee/ml/ml2022-course-data/introduction%20(v2).pdf)<br/>[Pytorch Tutorial 1](https://speech.ee.ntu.edu.tw/~hylee/ml/ml2022-course-data/Pytorch%20Tutorial%201.pdf)<br/>[Pytorch Tutorial 2](https://speech.ee.ntu.edu.tw/~hylee/ml/ml2022-course-data/Pytorch%20Tutorial%202.pdf)<br/>[Colab Tutorial](https://speech.ee.ntu.edu.tw/~hylee/ml/ml2022-course-data/Colab%20Tutorial%202022.pdf)<br/>[Environment Setup](https://speech.ee.ntu.edu.tw/~hylee/ml/ml2022-course-data/EnvironmentSetup.pdf)|[深度学习简介](https://www.bilibili.com/video/BV1Wv411h7kN?p=13)<br/>[反向传播](https://www.bilibili.com/video/BV1Wv411h7kN?p=14)<br/>[预测-宝可梦](https://www.bilibili.com/video/BV1Wv411h7kN?p=15)<br/>[分类-宝可梦](https://www.bilibili.com/video/BV1Wv411h7kN?p=16)<br/>[逻辑回归](https://www.bilibili.com/video/BV1Wv411h7kN?p=17)|[Video](https://www.bilibili.com/video/BV1Wv411h7kN?p=11)<br/>[Slide](https://speech.ee.ntu.edu.tw/~hylee/ml/ml2022-course-data/HW01.pdf)<br/>[Code](https://colab.research.google.com/drive/1FTcG6CE-HILnvFztEFKdauMlPKfQvm5Z#scrollTo=YdttVRkAfu2t)<br/>[Submission](https://www.kaggle.com/t/a3ebd5b5542f0f55e828d4f00de8e59a)|
|Lecture 2|[(一)局部最小值 (local minima) 与鞍点 (saddle point)](https://www.bilibili.com/video/BV1Wv411h7kN?p=20)<br/>[(二)批次 (batch) 与动量 (momentum)](https://www.bilibili.com/video/BV1Wv411h7kN?p=21)<br/>[(三)自动调整学习率 (Learning Rate)](https://www.bilibili.com/video/BV1Wv411h7kN?p=22)<br/>[(四)损失函数 (Loss) 也可能有影响](https://www.bilibili.com/video/BV1Wv411h7kN?p=23)|Video:<br/>[2022-再探宝可梦、数码宝贝分类器 — 浅谈机器学习原理](https://www.bilibili.com/video/BV1Wv411h7kN?p=19)<br/><br/>PDF:<br/>[Theory](https://speech.ee.ntu.edu.tw/~hylee/ml/ml2022-course-data/theory%20(v7).pdf)|[Gradient Descent (Demo by AOE)](https://www.bilibili.com/video/BV1Wv411h7kN?p=24)<br/>[ Beyond Adam (part 1)](https://www.bilibili.com/video/BV1Wv411h7kN?p=26)<br/>[ Beyond Adam (part 2)](https://www.bilibili.com/video/BV1Wv411h7kN?p=27)|Video<br/>[Slide](https://speech.ee.ntu.edu.tw/~hylee/ml/ml2022-course-data/hw2_slides%202022.pdf)<br/>[Code](https://colab.research.google.com/drive/1hmTFJ8hdcnqRz_0oJSXjTGhZLVU-bS1a?usp=sharing)<br/>[Submission](https://www.kaggle.com/c/ml2022spring-hw2)|

****

Loading…
Cancel
Save