|
|
@@ -0,0 +1,773 @@ |
|
|
|
{ |
|
|
|
"nbformat": 4, |
|
|
|
"nbformat_minor": 0, |
|
|
|
"metadata": { |
|
|
|
"accelerator": "GPU", |
|
|
|
"colab": { |
|
|
|
"name": "HW04.ipynb", |
|
|
|
"provenance": [], |
|
|
|
"collapsed_sections": [], |
|
|
|
"toc_visible": true |
|
|
|
}, |
|
|
|
"kernelspec": { |
|
|
|
"display_name": "Python 3", |
|
|
|
"name": "python3" |
|
|
|
} |
|
|
|
}, |
|
|
|
"cells": [ |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": { |
|
|
|
"id": "zC5KwRyl6Flp" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"# Task description\n", |
|
|
|
"- Classify the speakers of given features.\n", |
|
|
|
"- Main goal: Learn how to use transformer.\n", |
|
|
|
"- Baselines:\n", |
|
|
|
" - Easy: Run sample code and know how to use transformer.\n", |
|
|
|
" - Medium: Know how to adjust parameters of transformer.\n", |
|
|
|
" - Hard: Construct [conformer](https://arxiv.org/abs/2005.08100) which is a variety of transformer. " |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": { |
|
|
|
"id": "TPDoreyypeJE" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"# Download dataset" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"metadata": { |
|
|
|
"id": "QvpaILXnJIcw" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"\"\"\"\n", |
|
|
|
" For Google drive, You can download data form any link below.\n", |
|
|
|
" If a link fails, please use another one.\n", |
|
|
|
"\"\"\"\n", |
|
|
|
"\"\"\" Download link 1 of Google drive \"\"\"\n", |
|
|
|
"# !gdown --id '1T0RPnu-Sg5eIPwQPfYysipfcz81MnsYe' --output Dataset.zip\n", |
|
|
|
"\"\"\" Download link 2 of Google drive \"\"\"\n", |
|
|
|
"# !gdown --id '1CtHZhJ-mTpNsO-MqvAPIi4Yrt3oSBXYV' --output Dataset.zip\n", |
|
|
|
"\"\"\" Download link 3 of Google drive \"\"\"\n", |
|
|
|
"# !gdown --id '14hmoMgB1fe6v50biIceKyndyeYABGrRq' --output Dataset.zip\n", |
|
|
|
"\"\"\" Download link 4 of Google drive \"\"\"\n", |
|
|
|
"# !gdown --id '1e9x-Pjl3n7-9tK9LS_WjiMo2lru4UBH9' --output Dataset.zip\n", |
|
|
|
"\"\"\" Download link 5 of Google drive \"\"\"\n", |
|
|
|
"# !gdown --id '10TC0g46bcAz_jkiMl65zNmwttT4RiRgY' --output Dataset.zip\n", |
|
|
|
"\"\"\" Download link 6 of Google drive \"\"\"\n", |
|
|
|
"# !gdown --id '1MUGBvG_JjqO0C2JYHuyV3B0lvaf1kWIm' --output Dataset.zip\n", |
|
|
|
"\"\"\" Download link 7 of Google drive \"\"\"\n", |
|
|
|
"# !gdown --id '18M91P5DHwILNyOlssZ57AiPOR0OwutOM' --output Dataset.zip\n", |
|
|
|
"\"\"\" For Google drive, you can unzip the data by the command below. \"\"\"\n", |
|
|
|
"# !unzip Dataset.zip\n", |
|
|
|
"\n", |
|
|
|
"\"\"\"\n", |
|
|
|
" For Dropbox, we split dataset into five files. \n", |
|
|
|
" Please download all of them.\n", |
|
|
|
"\"\"\"\n", |
|
|
|
"# If Dropbox is not work. Please use google drive.\n", |
|
|
|
"# !wget https://www.dropbox.com/s/vw324newiku0sz0/Dataset.tar.gz.aa?dl=0\n", |
|
|
|
"# !wget https://www.dropbox.com/s/z840g69e7lnkayo/Dataset.tar.gz.ab?dl=0\n", |
|
|
|
"# !wget https://www.dropbox.com/s/hl081e1ggonio81/Dataset.tar.gz.ac?dl=0\n", |
|
|
|
"# !wget https://www.dropbox.com/s/fh3zd8ow668c4th/Dataset.tar.gz.ad?dl=0\n", |
|
|
|
"# !wget https://www.dropbox.com/s/ydzygoy2pv6gw9d/Dataset.tar.gz.ae?dl=0\n", |
|
|
|
"# !cat Dataset.tar.gz.* | tar zxvf -\n", |
|
|
|
"\n", |
|
|
|
"\"\"\"\n", |
|
|
|
" For Onedrive, we split dataset into five files. \n", |
|
|
|
" Please download all of them.\n", |
|
|
|
"\"\"\"\n", |
|
|
|
"!wget --no-check-certificate \"https://onedrive.live.com/download?cid=10C95EE5FD151BFB&resid=10C95EE5FD151BFB%21106&authkey=ACB6opQR3CG9kmc\" -O Dataset.tar.gz.aa\n", |
|
|
|
"!wget --no-check-certificate \"https://onedrive.live.com/download?cid=93DDDDD552E145DB&resid=93DDDDD552E145DB%21106&authkey=AP6EepjxSdvyV6Y\" -O Dataset.tar.gz.ab\n", |
|
|
|
"!wget --no-check-certificate \"https://onedrive.live.com/download?cid=644545816461BCCC&resid=644545816461BCCC%21106&authkey=ALiefB0kI7Epb0Q\" -O Dataset.tar.gz.ac\n", |
|
|
|
"!wget --no-check-certificate \"https://onedrive.live.com/download?cid=77CEBB3C3C512821&resid=77CEBB3C3C512821%21106&authkey=AAXCx4TTDYC0yjM\" -O Dataset.tar.gz.ad\n", |
|
|
|
"!wget --no-check-certificate \"https://onedrive.live.com/download?cid=383D0E0146A11B02&resid=383D0E0146A11B02%21106&authkey=ALwVc4StVbig6QI\" -O Dataset.tar.gz.ae\n", |
|
|
|
"!cat Dataset.tar.gz.* | tar zxvf -" |
|
|
|
], |
|
|
|
"execution_count": null, |
|
|
|
"outputs": [] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": { |
|
|
|
"id": "v1gYr_aoNDue" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"# Data" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": { |
|
|
|
"id": "Mz_NpuAipk3h" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"## Dataset\n", |
|
|
|
"- Original dataset is [Voxceleb1](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/).\n", |
|
|
|
"- The [license](https://creativecommons.org/licenses/by/4.0/) and [complete version](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/files/license.txt) of Voxceleb1.\n", |
|
|
|
"- We randomly select 600 speakers from Voxceleb1.\n", |
|
|
|
"- Then preprocess the raw waveforms into mel-spectrograms.\n", |
|
|
|
"\n", |
|
|
|
"- Args:\n", |
|
|
|
" - data_dir: The path to the data directory.\n", |
|
|
|
" - metadata_path: The path to the metadata.\n", |
|
|
|
" - segment_len: The length of audio segment for training. \n", |
|
|
|
"- The architecture of data directory \\\\\n", |
|
|
|
" - data directory \\\\\n", |
|
|
|
" |---- metadata.json \\\\\n", |
|
|
|
" |---- testdata.json \\\\\n", |
|
|
|
" |---- mapping.json \\\\\n", |
|
|
|
" |---- uttr-{random string}.pt \\\\\n", |
|
|
|
"\n", |
|
|
|
"- The information in metadata\n", |
|
|
|
" - \"n_mels\": The dimention of mel-spectrogram.\n", |
|
|
|
" - \"speakers\": A dictionary. \n", |
|
|
|
" - Key: speaker ids.\n", |
|
|
|
" - value: \"feature_path\" and \"mel_len\"\n", |
|
|
|
"\n", |
|
|
|
"\n", |
|
|
|
"For efficiency, we segment the mel-spectrograms into segments in the traing step." |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"metadata": { |
|
|
|
"id": "cd7hoGhYtbXQ" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"import os\n", |
|
|
|
"import json\n", |
|
|
|
"import torch\n", |
|
|
|
"import random\n", |
|
|
|
"from pathlib import Path\n", |
|
|
|
"from torch.utils.data import Dataset\n", |
|
|
|
"from torch.nn.utils.rnn import pad_sequence\n", |
|
|
|
" \n", |
|
|
|
" \n", |
|
|
|
"class myDataset(Dataset):\n", |
|
|
|
" def __init__(self, data_dir, segment_len=128):\n", |
|
|
|
" self.data_dir = data_dir\n", |
|
|
|
" self.segment_len = segment_len\n", |
|
|
|
" \n", |
|
|
|
" # Load the mapping from speaker neme to their corresponding id. \n", |
|
|
|
" mapping_path = Path(data_dir) / \"mapping.json\"\n", |
|
|
|
" mapping = json.load(mapping_path.open())\n", |
|
|
|
" self.speaker2id = mapping[\"speaker2id\"]\n", |
|
|
|
" \n", |
|
|
|
" # Load metadata of training data.\n", |
|
|
|
" metadata_path = Path(data_dir) / \"metadata.json\"\n", |
|
|
|
" metadata = json.load(open(metadata_path))[\"speakers\"]\n", |
|
|
|
" \n", |
|
|
|
" # Get the total number of speaker.\n", |
|
|
|
" self.speaker_num = len(metadata.keys())\n", |
|
|
|
" self.data = []\n", |
|
|
|
" for speaker in metadata.keys():\n", |
|
|
|
" for utterances in metadata[speaker]:\n", |
|
|
|
" self.data.append([utterances[\"feature_path\"], self.speaker2id[speaker]])\n", |
|
|
|
" \n", |
|
|
|
" def __len__(self):\n", |
|
|
|
" return len(self.data)\n", |
|
|
|
" \n", |
|
|
|
" def __getitem__(self, index):\n", |
|
|
|
" feat_path, speaker = self.data[index]\n", |
|
|
|
" # Load preprocessed mel-spectrogram.\n", |
|
|
|
" mel = torch.load(os.path.join(self.data_dir, feat_path))\n", |
|
|
|
" \n", |
|
|
|
" # Segmemt mel-spectrogram into \"segment_len\" frames.\n", |
|
|
|
" if len(mel) > self.segment_len:\n", |
|
|
|
" # Randomly get the starting point of the segment.\n", |
|
|
|
" start = random.randint(0, len(mel) - self.segment_len)\n", |
|
|
|
" # Get a segment with \"segment_len\" frames.\n", |
|
|
|
" mel = torch.FloatTensor(mel[start:start+self.segment_len])\n", |
|
|
|
" else:\n", |
|
|
|
" mel = torch.FloatTensor(mel)\n", |
|
|
|
" # Turn the speaker id into long for computing loss later.\n", |
|
|
|
" speaker = torch.FloatTensor([speaker]).long()\n", |
|
|
|
" return mel, speaker\n", |
|
|
|
" \n", |
|
|
|
" def get_speaker_number(self):\n", |
|
|
|
" return self.speaker_num" |
|
|
|
], |
|
|
|
"execution_count": null, |
|
|
|
"outputs": [] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": { |
|
|
|
"id": "mqJxjoi_NGnB" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"## Dataloader\n", |
|
|
|
"- Split dataset into training dataset(90%) and validation dataset(10%).\n", |
|
|
|
"- Create dataloader to iterate the data.\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"metadata": { |
|
|
|
"id": "zuT1AuFENI8t" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"import torch\n", |
|
|
|
"from torch.utils.data import DataLoader, random_split\n", |
|
|
|
"from torch.nn.utils.rnn import pad_sequence\n", |
|
|
|
"\n", |
|
|
|
"\n", |
|
|
|
"def collate_batch(batch):\n", |
|
|
|
" # Process features within a batch.\n", |
|
|
|
" \"\"\"Collate a batch of data.\"\"\"\n", |
|
|
|
" mel, speaker = zip(*batch)\n", |
|
|
|
" # Because we train the model batch by batch, we need to pad the features in the same batch to make their lengths the same.\n", |
|
|
|
" mel = pad_sequence(mel, batch_first=True, padding_value=-20) # pad log 10^(-20) which is very small value.\n", |
|
|
|
" # mel: (batch size, length, 40)\n", |
|
|
|
" return mel, torch.FloatTensor(speaker).long()\n", |
|
|
|
"\n", |
|
|
|
"\n", |
|
|
|
"def get_dataloader(data_dir, batch_size, n_workers):\n", |
|
|
|
" \"\"\"Generate dataloader\"\"\"\n", |
|
|
|
" dataset = myDataset(data_dir)\n", |
|
|
|
" speaker_num = dataset.get_speaker_number()\n", |
|
|
|
" # Split dataset into training dataset and validation dataset\n", |
|
|
|
" trainlen = int(0.9 * len(dataset))\n", |
|
|
|
" lengths = [trainlen, len(dataset) - trainlen]\n", |
|
|
|
" trainset, validset = random_split(dataset, lengths)\n", |
|
|
|
"\n", |
|
|
|
" train_loader = DataLoader(\n", |
|
|
|
" trainset,\n", |
|
|
|
" batch_size=batch_size,\n", |
|
|
|
" shuffle=True,\n", |
|
|
|
" drop_last=True,\n", |
|
|
|
" num_workers=n_workers,\n", |
|
|
|
" pin_memory=True,\n", |
|
|
|
" collate_fn=collate_batch,\n", |
|
|
|
" )\n", |
|
|
|
" valid_loader = DataLoader(\n", |
|
|
|
" validset,\n", |
|
|
|
" batch_size=batch_size,\n", |
|
|
|
" num_workers=n_workers,\n", |
|
|
|
" drop_last=True,\n", |
|
|
|
" pin_memory=True,\n", |
|
|
|
" collate_fn=collate_batch,\n", |
|
|
|
" )\n", |
|
|
|
"\n", |
|
|
|
" return train_loader, valid_loader, speaker_num\n" |
|
|
|
], |
|
|
|
"execution_count": null, |
|
|
|
"outputs": [] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": { |
|
|
|
"id": "X0x6eXiHpr4R" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"# Model\n", |
|
|
|
"- TransformerEncoderLayer:\n", |
|
|
|
" - Base transformer encoder layer in [Attention Is All You Need](https://arxiv.org/abs/1706.03762)\n", |
|
|
|
" - Parameters:\n", |
|
|
|
" - d_model: the number of expected features of the input (required).\n", |
|
|
|
"\n", |
|
|
|
" - nhead: the number of heads of the multiheadattention models (required).\n", |
|
|
|
"\n", |
|
|
|
" - dim_feedforward: the dimension of the feedforward network model (default=2048).\n", |
|
|
|
"\n", |
|
|
|
" - dropout: the dropout value (default=0.1).\n", |
|
|
|
"\n", |
|
|
|
" - activation: the activation function of intermediate layer, relu or gelu (default=relu).\n", |
|
|
|
"\n", |
|
|
|
"- TransformerEncoder:\n", |
|
|
|
" - TransformerEncoder is a stack of N transformer encoder layers\n", |
|
|
|
" - Parameters:\n", |
|
|
|
" - encoder_layer: an instance of the TransformerEncoderLayer() class (required).\n", |
|
|
|
"\n", |
|
|
|
" - num_layers: the number of sub-encoder-layers in the encoder (required).\n", |
|
|
|
"\n", |
|
|
|
" - norm: the layer normalization component (optional)." |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"metadata": { |
|
|
|
"id": "SHX4eVj4tjtd" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"import torch\n", |
|
|
|
"import torch.nn as nn\n", |
|
|
|
"import torch.nn.functional as F\n", |
|
|
|
"\n", |
|
|
|
"\n", |
|
|
|
"class Classifier(nn.Module):\n", |
|
|
|
" def __init__(self, d_model=80, n_spks=600, dropout=0.1):\n", |
|
|
|
" super().__init__()\n", |
|
|
|
" # Project the dimension of features from that of input into d_model.\n", |
|
|
|
" self.prenet = nn.Linear(40, d_model)\n", |
|
|
|
" # TODO:\n", |
|
|
|
" # Change Transformer to Conformer.\n", |
|
|
|
" # https://arxiv.org/abs/2005.08100\n", |
|
|
|
" self.encoder_layer = nn.TransformerEncoderLayer(\n", |
|
|
|
" d_model=d_model, dim_feedforward=256, nhead=2\n", |
|
|
|
" )\n", |
|
|
|
" # self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=2)\n", |
|
|
|
"\n", |
|
|
|
" # Project the the dimension of features from d_model into speaker nums.\n", |
|
|
|
" self.pred_layer = nn.Sequential(\n", |
|
|
|
" nn.Linear(d_model, d_model),\n", |
|
|
|
" nn.ReLU(),\n", |
|
|
|
" nn.Linear(d_model, n_spks),\n", |
|
|
|
" )\n", |
|
|
|
"\n", |
|
|
|
" def forward(self, mels):\n", |
|
|
|
" \"\"\"\n", |
|
|
|
" args:\n", |
|
|
|
" mels: (batch size, length, 40)\n", |
|
|
|
" return:\n", |
|
|
|
" out: (batch size, n_spks)\n", |
|
|
|
" \"\"\"\n", |
|
|
|
" # out: (batch size, length, d_model)\n", |
|
|
|
" out = self.prenet(mels)\n", |
|
|
|
" # out: (length, batch size, d_model)\n", |
|
|
|
" out = out.permute(1, 0, 2)\n", |
|
|
|
" # The encoder layer expect features in the shape of (length, batch size, d_model).\n", |
|
|
|
" out = self.encoder_layer(out)\n", |
|
|
|
" # out: (batch size, length, d_model)\n", |
|
|
|
" out = out.transpose(0, 1)\n", |
|
|
|
" # mean pooling\n", |
|
|
|
" stats = out.mean(dim=1)\n", |
|
|
|
"\n", |
|
|
|
" # out: (batch, n_spks)\n", |
|
|
|
" out = self.pred_layer(stats)\n", |
|
|
|
" return out\n" |
|
|
|
], |
|
|
|
"execution_count": null, |
|
|
|
"outputs": [] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": { |
|
|
|
"id": "-__DolPGpvDZ" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"# Learning rate schedule\n", |
|
|
|
"- For transformer architecture, the design of learning rate schedule is different from that of CNN.\n", |
|
|
|
"- Previous works show that the warmup of learning rate is useful for training models with transformer architectures.\n", |
|
|
|
"- The warmup schedule\n", |
|
|
|
" - Set learning rate to 0 in the beginning.\n", |
|
|
|
" - The learning rate increases linearly from 0 to initial learning rate during warmup period." |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"metadata": { |
|
|
|
"id": "K-0816BntqT9" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"import math\n", |
|
|
|
"\n", |
|
|
|
"import torch\n", |
|
|
|
"from torch.optim import Optimizer\n", |
|
|
|
"from torch.optim.lr_scheduler import LambdaLR\n", |
|
|
|
"\n", |
|
|
|
"\n", |
|
|
|
"def get_cosine_schedule_with_warmup(\n", |
|
|
|
" optimizer: Optimizer,\n", |
|
|
|
" num_warmup_steps: int,\n", |
|
|
|
" num_training_steps: int,\n", |
|
|
|
" num_cycles: float = 0.5,\n", |
|
|
|
" last_epoch: int = -1,\n", |
|
|
|
"):\n", |
|
|
|
" \"\"\"\n", |
|
|
|
" Create a schedule with a learning rate that decreases following the values of the cosine function between the\n", |
|
|
|
" initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the\n", |
|
|
|
" initial lr set in the optimizer.\n", |
|
|
|
"\n", |
|
|
|
" Args:\n", |
|
|
|
" optimizer (:class:`~torch.optim.Optimizer`):\n", |
|
|
|
" The optimizer for which to schedule the learning rate.\n", |
|
|
|
" num_warmup_steps (:obj:`int`):\n", |
|
|
|
" The number of steps for the warmup phase.\n", |
|
|
|
" num_training_steps (:obj:`int`):\n", |
|
|
|
" The total number of training steps.\n", |
|
|
|
" num_cycles (:obj:`float`, `optional`, defaults to 0.5):\n", |
|
|
|
" The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0\n", |
|
|
|
" following a half-cosine).\n", |
|
|
|
" last_epoch (:obj:`int`, `optional`, defaults to -1):\n", |
|
|
|
" The index of the last epoch when resuming training.\n", |
|
|
|
"\n", |
|
|
|
" Return:\n", |
|
|
|
" :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.\n", |
|
|
|
" \"\"\"\n", |
|
|
|
"\n", |
|
|
|
" def lr_lambda(current_step):\n", |
|
|
|
" # Warmup\n", |
|
|
|
" if current_step < num_warmup_steps:\n", |
|
|
|
" return float(current_step) / float(max(1, num_warmup_steps))\n", |
|
|
|
" # decadence\n", |
|
|
|
" progress = float(current_step - num_warmup_steps) / float(\n", |
|
|
|
" max(1, num_training_steps - num_warmup_steps)\n", |
|
|
|
" )\n", |
|
|
|
" return max(\n", |
|
|
|
" 0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))\n", |
|
|
|
" )\n", |
|
|
|
"\n", |
|
|
|
" return LambdaLR(optimizer, lr_lambda, last_epoch)\n" |
|
|
|
], |
|
|
|
"execution_count": null, |
|
|
|
"outputs": [] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": { |
|
|
|
"id": "IP03FFo9K8DS" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"# Model Function\n", |
|
|
|
"- Model forward function." |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"metadata": { |
|
|
|
"id": "fohaLEFJK9-t" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"import torch\n", |
|
|
|
"\n", |
|
|
|
"\n", |
|
|
|
"def model_fn(batch, model, criterion, device):\n", |
|
|
|
" \"\"\"Forward a batch through the model.\"\"\"\n", |
|
|
|
"\n", |
|
|
|
" mels, labels = batch\n", |
|
|
|
" mels = mels.to(device)\n", |
|
|
|
" labels = labels.to(device)\n", |
|
|
|
"\n", |
|
|
|
" outs = model(mels)\n", |
|
|
|
"\n", |
|
|
|
" loss = criterion(outs, labels)\n", |
|
|
|
"\n", |
|
|
|
" # Get the speaker id with highest probability.\n", |
|
|
|
" preds = outs.argmax(1)\n", |
|
|
|
" # Compute accuracy.\n", |
|
|
|
" accuracy = torch.mean((preds == labels).float())\n", |
|
|
|
"\n", |
|
|
|
" return loss, accuracy\n" |
|
|
|
], |
|
|
|
"execution_count": null, |
|
|
|
"outputs": [] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": { |
|
|
|
"id": "F7cg-YrzLQcf" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"# Validate\n", |
|
|
|
"- Calculate accuracy of the validation set." |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"metadata": { |
|
|
|
"id": "mD-_p6nWLO2L" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"from tqdm import tqdm\n", |
|
|
|
"import torch\n", |
|
|
|
"\n", |
|
|
|
"\n", |
|
|
|
"def valid(dataloader, model, criterion, device): \n", |
|
|
|
" \"\"\"Validate on validation set.\"\"\"\n", |
|
|
|
"\n", |
|
|
|
" model.eval()\n", |
|
|
|
" running_loss = 0.0\n", |
|
|
|
" running_accuracy = 0.0\n", |
|
|
|
" pbar = tqdm(total=len(dataloader.dataset), ncols=0, desc=\"Valid\", unit=\" uttr\")\n", |
|
|
|
"\n", |
|
|
|
" for i, batch in enumerate(dataloader):\n", |
|
|
|
" with torch.no_grad():\n", |
|
|
|
" loss, accuracy = model_fn(batch, model, criterion, device)\n", |
|
|
|
" running_loss += loss.item()\n", |
|
|
|
" running_accuracy += accuracy.item()\n", |
|
|
|
"\n", |
|
|
|
" pbar.update(dataloader.batch_size)\n", |
|
|
|
" pbar.set_postfix(\n", |
|
|
|
" loss=f\"{running_loss / (i+1):.2f}\",\n", |
|
|
|
" accuracy=f\"{running_accuracy / (i+1):.2f}\",\n", |
|
|
|
" )\n", |
|
|
|
"\n", |
|
|
|
" pbar.close()\n", |
|
|
|
" model.train()\n", |
|
|
|
"\n", |
|
|
|
" return running_accuracy / len(dataloader)\n" |
|
|
|
], |
|
|
|
"execution_count": null, |
|
|
|
"outputs": [] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": { |
|
|
|
"id": "noHXyal5p1W5" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"# Main function" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"metadata": { |
|
|
|
"id": "chRQE7oYtw62" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"from tqdm import tqdm\n", |
|
|
|
"\n", |
|
|
|
"import torch\n", |
|
|
|
"import torch.nn as nn\n", |
|
|
|
"from torch.optim import AdamW\n", |
|
|
|
"from torch.utils.data import DataLoader, random_split\n", |
|
|
|
"\n", |
|
|
|
"\n", |
|
|
|
"def parse_args():\n", |
|
|
|
" \"\"\"arguments\"\"\"\n", |
|
|
|
" config = {\n", |
|
|
|
" \"data_dir\": \"./Dataset\",\n", |
|
|
|
" \"save_path\": \"model.ckpt\",\n", |
|
|
|
" \"batch_size\": 32,\n", |
|
|
|
" \"n_workers\": 8,\n", |
|
|
|
" \"valid_steps\": 2000,\n", |
|
|
|
" \"warmup_steps\": 1000,\n", |
|
|
|
" \"save_steps\": 10000,\n", |
|
|
|
" \"total_steps\": 70000,\n", |
|
|
|
" }\n", |
|
|
|
"\n", |
|
|
|
" return config\n", |
|
|
|
"\n", |
|
|
|
"\n", |
|
|
|
"def main(\n", |
|
|
|
" data_dir,\n", |
|
|
|
" save_path,\n", |
|
|
|
" batch_size,\n", |
|
|
|
" n_workers,\n", |
|
|
|
" valid_steps,\n", |
|
|
|
" warmup_steps,\n", |
|
|
|
" total_steps,\n", |
|
|
|
" save_steps,\n", |
|
|
|
"):\n", |
|
|
|
" \"\"\"Main function.\"\"\"\n", |
|
|
|
" device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", |
|
|
|
" print(f\"[Info]: Use {device} now!\")\n", |
|
|
|
"\n", |
|
|
|
" train_loader, valid_loader, speaker_num = get_dataloader(data_dir, batch_size, n_workers)\n", |
|
|
|
" train_iterator = iter(train_loader)\n", |
|
|
|
" print(f\"[Info]: Finish loading data!\",flush = True)\n", |
|
|
|
"\n", |
|
|
|
" model = Classifier(n_spks=speaker_num).to(device)\n", |
|
|
|
" criterion = nn.CrossEntropyLoss()\n", |
|
|
|
" optimizer = AdamW(model.parameters(), lr=1e-3)\n", |
|
|
|
" scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)\n", |
|
|
|
" print(f\"[Info]: Finish creating model!\",flush = True)\n", |
|
|
|
"\n", |
|
|
|
" best_accuracy = -1.0\n", |
|
|
|
" best_state_dict = None\n", |
|
|
|
"\n", |
|
|
|
" pbar = tqdm(total=valid_steps, ncols=0, desc=\"Train\", unit=\" step\")\n", |
|
|
|
"\n", |
|
|
|
" for step in range(total_steps):\n", |
|
|
|
" # Get data\n", |
|
|
|
" try:\n", |
|
|
|
" batch = next(train_iterator)\n", |
|
|
|
" except StopIteration:\n", |
|
|
|
" train_iterator = iter(train_loader)\n", |
|
|
|
" batch = next(train_iterator)\n", |
|
|
|
"\n", |
|
|
|
" loss, accuracy = model_fn(batch, model, criterion, device)\n", |
|
|
|
" batch_loss = loss.item()\n", |
|
|
|
" batch_accuracy = accuracy.item()\n", |
|
|
|
"\n", |
|
|
|
" # Updata model\n", |
|
|
|
" loss.backward()\n", |
|
|
|
" optimizer.step()\n", |
|
|
|
" scheduler.step()\n", |
|
|
|
" optimizer.zero_grad()\n", |
|
|
|
" \n", |
|
|
|
" # Log\n", |
|
|
|
" pbar.update()\n", |
|
|
|
" pbar.set_postfix(\n", |
|
|
|
" loss=f\"{batch_loss:.2f}\",\n", |
|
|
|
" accuracy=f\"{batch_accuracy:.2f}\",\n", |
|
|
|
" step=step + 1,\n", |
|
|
|
" )\n", |
|
|
|
"\n", |
|
|
|
" # Do validation\n", |
|
|
|
" if (step + 1) % valid_steps == 0:\n", |
|
|
|
" pbar.close()\n", |
|
|
|
"\n", |
|
|
|
" valid_accuracy = valid(valid_loader, model, criterion, device)\n", |
|
|
|
"\n", |
|
|
|
" # keep the best model\n", |
|
|
|
" if valid_accuracy > best_accuracy:\n", |
|
|
|
" best_accuracy = valid_accuracy\n", |
|
|
|
" best_state_dict = model.state_dict()\n", |
|
|
|
"\n", |
|
|
|
" pbar = tqdm(total=valid_steps, ncols=0, desc=\"Train\", unit=\" step\")\n", |
|
|
|
"\n", |
|
|
|
" # Save the best model so far.\n", |
|
|
|
" if (step + 1) % save_steps == 0 and best_state_dict is not None:\n", |
|
|
|
" torch.save(best_state_dict, save_path)\n", |
|
|
|
" pbar.write(f\"Step {step + 1}, best model saved. (accuracy={best_accuracy:.4f})\")\n", |
|
|
|
"\n", |
|
|
|
" pbar.close()\n", |
|
|
|
"\n", |
|
|
|
"\n", |
|
|
|
"if __name__ == \"__main__\":\n", |
|
|
|
" main(**parse_args())\n" |
|
|
|
], |
|
|
|
"execution_count": null, |
|
|
|
"outputs": [] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": { |
|
|
|
"id": "0R2rx3AyHpQ-" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"# Inference" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": { |
|
|
|
"id": "pSuI3WY9Fz78" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"## Dataset of inference" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"metadata": { |
|
|
|
"id": "4evns0055Dsx" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"import os\n", |
|
|
|
"import json\n", |
|
|
|
"import torch\n", |
|
|
|
"from pathlib import Path\n", |
|
|
|
"from torch.utils.data import Dataset\n", |
|
|
|
"\n", |
|
|
|
"\n", |
|
|
|
"class InferenceDataset(Dataset):\n", |
|
|
|
" def __init__(self, data_dir):\n", |
|
|
|
" testdata_path = Path(data_dir) / \"testdata.json\"\n", |
|
|
|
" metadata = json.load(testdata_path.open())\n", |
|
|
|
" self.data_dir = data_dir\n", |
|
|
|
" self.data = metadata[\"utterances\"]\n", |
|
|
|
"\n", |
|
|
|
" def __len__(self):\n", |
|
|
|
" return len(self.data)\n", |
|
|
|
"\n", |
|
|
|
" def __getitem__(self, index):\n", |
|
|
|
" utterance = self.data[index]\n", |
|
|
|
" feat_path = utterance[\"feature_path\"]\n", |
|
|
|
" mel = torch.load(os.path.join(self.data_dir, feat_path))\n", |
|
|
|
"\n", |
|
|
|
" return feat_path, mel\n", |
|
|
|
"\n", |
|
|
|
"\n", |
|
|
|
"def inference_collate_batch(batch):\n", |
|
|
|
" \"\"\"Collate a batch of data.\"\"\"\n", |
|
|
|
" feat_paths, mels = zip(*batch)\n", |
|
|
|
"\n", |
|
|
|
" return feat_paths, torch.stack(mels)\n" |
|
|
|
], |
|
|
|
"execution_count": null, |
|
|
|
"outputs": [] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "markdown", |
|
|
|
"metadata": { |
|
|
|
"id": "oAinHBG1GIWv" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"## Main funcrion of Inference" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
|
"cell_type": "code", |
|
|
|
"metadata": { |
|
|
|
"id": "yQaTt7VDHoRI" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"import json\n", |
|
|
|
"import csv\n", |
|
|
|
"from pathlib import Path\n", |
|
|
|
"from tqdm.notebook import tqdm\n", |
|
|
|
"\n", |
|
|
|
"import torch\n", |
|
|
|
"from torch.utils.data import DataLoader\n", |
|
|
|
"\n", |
|
|
|
"def parse_args():\n", |
|
|
|
" \"\"\"arguments\"\"\"\n", |
|
|
|
" config = {\n", |
|
|
|
" \"data_dir\": \"./Dataset\",\n", |
|
|
|
" \"model_path\": \"./model.ckpt\",\n", |
|
|
|
" \"output_path\": \"./output.csv\",\n", |
|
|
|
" }\n", |
|
|
|
"\n", |
|
|
|
" return config\n", |
|
|
|
"\n", |
|
|
|
"\n", |
|
|
|
"def main(\n", |
|
|
|
" data_dir,\n", |
|
|
|
" model_path,\n", |
|
|
|
" output_path,\n", |
|
|
|
"):\n", |
|
|
|
" \"\"\"Main function.\"\"\"\n", |
|
|
|
" device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", |
|
|
|
" print(f\"[Info]: Use {device} now!\")\n", |
|
|
|
"\n", |
|
|
|
" mapping_path = Path(data_dir) / \"mapping.json\"\n", |
|
|
|
" mapping = json.load(mapping_path.open())\n", |
|
|
|
"\n", |
|
|
|
" dataset = InferenceDataset(data_dir)\n", |
|
|
|
" dataloader = DataLoader(\n", |
|
|
|
" dataset,\n", |
|
|
|
" batch_size=1,\n", |
|
|
|
" shuffle=False,\n", |
|
|
|
" drop_last=False,\n", |
|
|
|
" num_workers=8,\n", |
|
|
|
" collate_fn=inference_collate_batch,\n", |
|
|
|
" )\n", |
|
|
|
" print(f\"[Info]: Finish loading data!\",flush = True)\n", |
|
|
|
"\n", |
|
|
|
" speaker_num = len(mapping[\"id2speaker\"])\n", |
|
|
|
" model = Classifier(n_spks=speaker_num).to(device)\n", |
|
|
|
" model.load_state_dict(torch.load(model_path))\n", |
|
|
|
" model.eval()\n", |
|
|
|
" print(f\"[Info]: Finish creating model!\",flush = True)\n", |
|
|
|
"\n", |
|
|
|
" results = [[\"Id\", \"Category\"]]\n", |
|
|
|
" for feat_paths, mels in tqdm(dataloader):\n", |
|
|
|
" with torch.no_grad():\n", |
|
|
|
" mels = mels.to(device)\n", |
|
|
|
" outs = model(mels)\n", |
|
|
|
" preds = outs.argmax(1).cpu().numpy()\n", |
|
|
|
" for feat_path, pred in zip(feat_paths, preds):\n", |
|
|
|
" results.append([feat_path, mapping[\"id2speaker\"][str(pred)]])\n", |
|
|
|
" \n", |
|
|
|
" with open(output_path, 'w', newline='') as csvfile:\n", |
|
|
|
" writer = csv.writer(csvfile)\n", |
|
|
|
" writer.writerows(results)\n", |
|
|
|
"\n", |
|
|
|
"\n", |
|
|
|
"if __name__ == \"__main__\":\n", |
|
|
|
" main(**parse_args())\n" |
|
|
|
], |
|
|
|
"execution_count": null, |
|
|
|
"outputs": [] |
|
|
|
} |
|
|
|
] |
|
|
|
} |