diff --git a/01 Introduction/Google_Colab_Tutorial.ipynb b/01 Introduction/Google_Colab_Tutorial.ipynb new file mode 100644 index 0000000..0257bc1 --- /dev/null +++ b/01 Introduction/Google_Colab_Tutorial.ipynb @@ -0,0 +1,299 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Google Colab Tutorial", + "provenance": [], + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "ca2CpPPUvO-h" + }, + "source": [ + "# **Google Colab Tutorial**\n", + "\n", + "\n", + "Should you have any question, contact TA via
ntu-ml-2021spring-ta@googlegroups.com\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xIN7RF4wjgHk" + }, + "source": [ + "

\"Colaboratory

\r\n", + "\r\n", + "

What is Colaboratory?

\r\n", + "\r\n", + "Colaboratory, or \"Colab\" for short, allows you to write and execute Python in your browser, with \r\n", + "- Zero configuration required\r\n", + "- Free access to GPUs\r\n", + "- Easy sharing\r\n", + "\r\n", + "Whether you're a **student**, a **data scientist** or an **AI researcher**, Colab can make your work easier. Watch [Introduction to Colab](https://www.youtube.com/watch?v=inN8seMm7UI) to learn more, or just get started below!\r\n", + "\r\n", + "You can type python code in the code block, or use a leading exclamation mark ! to change the code block to bash environment to execute linux code." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IrAxlhp3VBMD" + }, + "source": [ + "To utilize the free GPU provided by google, click on \"Runtime\"(執行階段) -> \"Change Runtime Type\"(變更執行階段類型). There are three options under \"Hardward Accelerator\"(硬體加速器), select \"GPU\". \r\n", + "* Doing this will restart the session, so make sure you change to the desired runtime before executing any code.\r\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "CLUWxZKbvQpx" + }, + "source": [ + "import torch\n", + "torch.cuda.is_available() # is GPU available\n", + "# Outputs True if running with GPU" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EAM_tPQAELh0" + }, + "source": [ + "**1. Download Files via google drive**\n", + "\n", + " A file stored in Google Drive has the following sharing link:\n", + "\n", + " https://drive.google.com/open?id=1duQU7xqXRsOSPYeOR0zLiSA8g_LCFzoV\n", + " \n", + " The random string after \"open?id=\" is the **file_id**
\n", + "![](https://i.imgur.com/33SW1WZ.png)\n", + "\n", + " It is possible to download the file via Colab knowing the **file_id**, using the following command.\n", + "\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "XztYEj0oD7J3" + }, + "source": [ + "# Download the file with file_id \"1duQU7xqXRsOSPYeOR0zLiSA8g_LCFzoV\", and rename it to Minori.jpg\n", + "!gdown --id '1duQU7xqXRsOSPYeOR0zLiSA8g_LCFzoV' --output Minori.jpg" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "Gg3T23LXG-eL" + }, + "source": [ + "# List all the files under the working directory\n", + "!ls" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "38dcGQujOVWM" + }, + "source": [ + "Exclamation mark (!) starts a new shell, does the operations, and then kills that shell, while percentage (%) affects the process associated with the notebook" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dOQxjfAZAsys" + }, + "source": [ + "It can be seen that `Minori.jpg` is saved the the current working directory. \r\n", + "\r\n", + "The working space is temporary, once you close the browser, the file will be gone.\r\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wLUPcHuNHF8u" + }, + "source": [ + "Clicking on the folder icon will give you the visuallization of the file structure\n", + "
\n", + "  ![image.png]()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MXp98PijHkrk" + }, + "source": [ + "There should be a file named `Minori.jpg`, if you do not see it, click the icon in the middle (refresh button)
\n", + "  ![](https://i.imgur.com/CNBTH23.png)\n", + "
\n", + "You can double click on the file to view the image.\n", + "\n", + "\n", + "   \n", + "![](https://i.imgur.com/h2PLMrq.png)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "k_gmTo9NKtu9" + }, + "source": [ + "**2. Mounting Google Drive**\n", + "\n", + " One advantage of using google colab is that connection with other google services such as Google Drive is simple. By mounting google drive, the working files can be stored permanantly. After executing the following code block, log in to the google account and copy the authentication code to the input box to finish the process." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "ImETTQKkL2l4" + }, + "source": [ + "from google.colab import drive # Import a library named google.colab\n", + "drive.mount('/content/drive', force_remount=True) # mount the content to the directory `/content/drive`" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "BmvzTF5IJ6TL" + }, + "source": [ + "from google.colab import drive\n", + "drive.mount('/content/drive')" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AkmayCmGMD03" + }, + "source": [ + "After mounting the drive, the content of the google drive will be under a directory named `MyDrive`, check the file structure for such a folder to confirm the execution of the code." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kDrO_DjBMW5D" + }, + "source": [ + "There is also an icon for mounting google drive. The icon will automatically generate the code above.\n", + "\n", + "![](https://i.imgur.com/hM9Jgi7.png) \n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UhKhwipoMvXF" + }, + "source": [ + "After mounting the drive, all the chnages will be synced with the google drive.\n", + "Since models could be quite large, make sure that your google drive has enough space. You can apply for a gsuite drive which has unlimited space using your studentID (until 2022/07). \n", + "https://www.cc.ntu.edu.tw/chinese/services/serv_i06.asp\n", + "http://www.cc.ntu.edu.tw/english/spotlight/2016/a105038.asp" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "UT0TEPRS7KF6" + }, + "source": [ + "%cd /content/drive/MyDrive \r\n", + "#change directory to google drive\r\n", + "!mkdir ML2021 #make a directory named ML2021\r\n", + "%cd ./ML2021 \r\n", + "#change directory to ML2021" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Oj13Q58QerAx" + }, + "source": [ + "Use bash command pwd to output the current directory" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "-S8l1-ReepkS" + }, + "source": [ + "!pwd #output the current directory" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qSSvrDaBiDrP" + }, + "source": [ + "Repeat the downloading process, this time, the file will be stored permanently in your google drive." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "b39YMYicASvP" + }, + "source": [ + "# Download the file with file_id \"1duQU7xqXRsOSPYeOR0zLiSA8g_LCFzoV\", and rename it to Minori.jpg\r\n", + "!gdown --id '1duQU7xqXRsOSPYeOR0zLiSA8g_LCFzoV' --output Minori.jpg" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "D0URgikZXl5I" + }, + "source": [ + "TA will provide the homework data using code similar to the code above. The data could also be stored in the google drive and loaded from there." + ] + } + ] +} \ No newline at end of file diff --git a/01 Introduction/ML2021Spring_HW1.ipynb b/01 Introduction/ML2021Spring_HW1.ipynb new file mode 100644 index 0000000..94b7712 --- /dev/null +++ b/01 Introduction/ML2021Spring_HW1.ipynb @@ -0,0 +1,874 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "ML2021Spring - HW1.ipynb", + "provenance": [], + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "mz0_QVkxCrX3" + }, + "source": [ + "# **Homework 1: COVID-19 Cases Prediction (Regression)**" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZeZnPAiwDRWG" + }, + "source": [ + "Author: Heng-Jui Chang\n", + "\n", + "Slides: https://github.com/ga642381/ML2021-Spring/blob/main/HW01/HW01.pdf \n", + "Video: TBA\n", + "\n", + "Objectives:\n", + "* Solve a regression problem with deep neural networks (DNN).\n", + "* Understand basic DNN training tips.\n", + "* Get familiar with PyTorch.\n", + "\n", + "If any questions, please contact the TAs via TA hours, NTU COOL, or email.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Jx3x1nDkG-Uy" + }, + "source": [ + "# **Download Data**\n", + "\n", + "\n", + "If the Google drive links are dead, you can download data from [kaggle](https://www.kaggle.com/c/ml2021spring-hw1/data), and upload data manually to the workspace." + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "tMj55YDKG6ch", + "outputId": "fc40ecc9-4756-48b1-d5c6-c169a8b453b2" + }, + "source": [ + "tr_path = 'covid.train.csv' # path to training data\n", + "tt_path = 'covid.test.csv' # path to testing data\n", + "\n", + "!gdown --id '19CCyCgJrUxtvgZF53vnctJiOJ23T5mqF' --output covid.train.csv\n", + "!gdown --id '1CE240jLm2npU-tdz81-oVKEF3T2yfT1O' --output covid.test.csv" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Downloading...\n", + "From: https://drive.google.com/uc?id=19CCyCgJrUxtvgZF53vnctJiOJ23T5mqF\n", + "To: /content/covid.train.csv\n", + "100% 2.00M/2.00M [00:00<00:00, 31.7MB/s]\n", + "Downloading...\n", + "From: https://drive.google.com/uc?id=1CE240jLm2npU-tdz81-oVKEF3T2yfT1O\n", + "To: /content/covid.test.csv\n", + "100% 651k/651k [00:00<00:00, 10.2MB/s]\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wS_4-77xHk44" + }, + "source": [ + "# **Import Some Packages**" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "k-onQd4JNA5H" + }, + "source": [ + "# PyTorch\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch.utils.data import Dataset, DataLoader\n", + "\n", + "# For data preprocess\n", + "import numpy as np\n", + "import csv\n", + "import os\n", + "\n", + "# For plotting\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib.pyplot import figure\n", + "\n", + "myseed = 42069 # set a random seed for reproducibility\n", + "torch.backends.cudnn.deterministic = True\n", + "torch.backends.cudnn.benchmark = False\n", + "np.random.seed(myseed)\n", + "torch.manual_seed(myseed)\n", + "if torch.cuda.is_available():\n", + " torch.cuda.manual_seed_all(myseed)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BtE3b6JEH7rw" + }, + "source": [ + "# **Some Utilities**\n", + "\n", + "You do not need to modify this part." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "FWMT3uf1NGQp" + }, + "source": [ + "def get_device():\n", + " ''' Get device (if GPU is available, use GPU) '''\n", + " return 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "\n", + "def plot_learning_curve(loss_record, title=''):\n", + " ''' Plot learning curve of your DNN (train & dev loss) '''\n", + " total_steps = len(loss_record['train'])\n", + " x_1 = range(total_steps)\n", + " x_2 = x_1[::len(loss_record['train']) // len(loss_record['dev'])]\n", + " figure(figsize=(6, 4))\n", + " plt.plot(x_1, loss_record['train'], c='tab:red', label='train')\n", + " plt.plot(x_2, loss_record['dev'], c='tab:cyan', label='dev')\n", + " plt.ylim(0.0, 5.)\n", + " plt.xlabel('Training steps')\n", + " plt.ylabel('MSE loss')\n", + " plt.title('Learning curve of {}'.format(title))\n", + " plt.legend()\n", + " plt.show()\n", + "\n", + "\n", + "def plot_pred(dv_set, model, device, lim=35., preds=None, targets=None):\n", + " ''' Plot prediction of your DNN '''\n", + " if preds is None or targets is None:\n", + " model.eval()\n", + " preds, targets = [], []\n", + " for x, y in dv_set:\n", + " x, y = x.to(device), y.to(device)\n", + " with torch.no_grad():\n", + " pred = model(x)\n", + " preds.append(pred.detach().cpu())\n", + " targets.append(y.detach().cpu())\n", + " preds = torch.cat(preds, dim=0).numpy()\n", + " targets = torch.cat(targets, dim=0).numpy()\n", + "\n", + " figure(figsize=(5, 5))\n", + " plt.scatter(targets, preds, c='r', alpha=0.5)\n", + " plt.plot([-0.2, lim], [-0.2, lim], c='b')\n", + " plt.xlim(-0.2, lim)\n", + " plt.ylim(-0.2, lim)\n", + " plt.xlabel('ground truth value')\n", + " plt.ylabel('predicted value')\n", + " plt.title('Ground Truth v.s. Prediction')\n", + " plt.show()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "39U_XFX6KOoj" + }, + "source": [ + "# **Preprocess**\n", + "\n", + "We have three kinds of datasets:\n", + "* `train`: for training\n", + "* `dev`: for validation\n", + "* `test`: for testing (w/o target value)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TQ-MdwpLL7Dt" + }, + "source": [ + "## **Dataset**\n", + "\n", + "The `COVID19Dataset` below does:\n", + "* read `.csv` files\n", + "* extract features\n", + "* split `covid.train.csv` into train/dev sets\n", + "* normalize features\n", + "\n", + "Finishing `TODO` below might make you pass medium baseline." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "0zlpIp9ANJRU" + }, + "source": [ + "class COVID19Dataset(Dataset):\n", + " ''' Dataset for loading and preprocessing the COVID19 dataset '''\n", + " def __init__(self,\n", + " path,\n", + " mode='train',\n", + " target_only=False):\n", + " self.mode = mode\n", + "\n", + " # Read data into numpy arrays\n", + " with open(path, 'r') as fp:\n", + " data = list(csv.reader(fp))\n", + " data = np.array(data[1:])[:, 1:].astype(float)\n", + " \n", + " if not target_only:\n", + " feats = list(range(93))\n", + " else:\n", + " # TODO: Using 40 states & 2 tested_positive features (indices = 57 & 75)\n", + " pass\n", + "\n", + " if mode == 'test':\n", + " # Testing data\n", + " # data: 893 x 93 (40 states + day 1 (18) + day 2 (18) + day 3 (17))\n", + " data = data[:, feats]\n", + " self.data = torch.FloatTensor(data)\n", + " else:\n", + " # Training data (train/dev sets)\n", + " # data: 2700 x 94 (40 states + day 1 (18) + day 2 (18) + day 3 (18))\n", + " target = data[:, -1]\n", + " data = data[:, feats]\n", + " \n", + " # Splitting training data into train & dev sets\n", + " if mode == 'train':\n", + " indices = [i for i in range(len(data)) if i % 10 != 0]\n", + " elif mode == 'dev':\n", + " indices = [i for i in range(len(data)) if i % 10 == 0]\n", + " \n", + " # Convert data into PyTorch tensors\n", + " self.data = torch.FloatTensor(data[indices])\n", + " self.target = torch.FloatTensor(target[indices])\n", + "\n", + " # Normalize features (you may remove this part to see what will happen)\n", + " self.data[:, 40:] = \\\n", + " (self.data[:, 40:] - self.data[:, 40:].mean(dim=0, keepdim=True)) \\\n", + " / self.data[:, 40:].std(dim=0, keepdim=True)\n", + "\n", + " self.dim = self.data.shape[1]\n", + "\n", + " print('Finished reading the {} set of COVID19 Dataset ({} samples found, each dim = {})'\n", + " .format(mode, len(self.data), self.dim))\n", + "\n", + " def __getitem__(self, index):\n", + " # Returns one sample at a time\n", + " if self.mode in ['train', 'dev']:\n", + " # For training\n", + " return self.data[index], self.target[index]\n", + " else:\n", + " # For testing (no target)\n", + " return self.data[index]\n", + "\n", + " def __len__(self):\n", + " # Returns the size of the dataset\n", + " return len(self.data)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AlhTlkE7MDo3" + }, + "source": [ + "## **DataLoader**\n", + "\n", + "A `DataLoader` loads data from a given `Dataset` into batches.\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "hlhLk5t6MBX3" + }, + "source": [ + "def prep_dataloader(path, mode, batch_size, n_jobs=0, target_only=False):\n", + " ''' Generates a dataset, then is put into a dataloader. '''\n", + " dataset = COVID19Dataset(path, mode=mode, target_only=target_only) # Construct dataset\n", + " dataloader = DataLoader(\n", + " dataset, batch_size,\n", + " shuffle=(mode == 'train'), drop_last=False,\n", + " num_workers=n_jobs, pin_memory=True) # Construct dataloader\n", + " return dataloader" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SGuycwR0MeQB" + }, + "source": [ + "# **Deep Neural Network**\n", + "\n", + "`NeuralNet` is an `nn.Module` designed for regression.\n", + "The DNN consists of 2 fully-connected layers with ReLU activation.\n", + "This module also included a function `cal_loss` for calculating loss.\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "49-uXYovOAI0" + }, + "source": [ + "class NeuralNet(nn.Module):\n", + " ''' A simple fully-connected deep neural network '''\n", + " def __init__(self, input_dim):\n", + " super(NeuralNet, self).__init__()\n", + "\n", + " # Define your neural network here\n", + " # TODO: How to modify this model to achieve better performance?\n", + " self.net = nn.Sequential(\n", + " nn.Linear(input_dim, 64),\n", + " nn.ReLU(),\n", + " nn.Linear(64, 1)\n", + " )\n", + "\n", + " # Mean squared error loss\n", + " self.criterion = nn.MSELoss(reduction='mean')\n", + "\n", + " def forward(self, x):\n", + " ''' Given input of size (batch_size x input_dim), compute output of the network '''\n", + " return self.net(x).squeeze(1)\n", + "\n", + " def cal_loss(self, pred, target):\n", + " ''' Calculate loss '''\n", + " # TODO: you may implement L2 regularization here\n", + " return self.criterion(pred, target)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DvFWVjZ5Nvga" + }, + "source": [ + "# **Train/Dev/Test**" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MAM8QecJOyqn" + }, + "source": [ + "## **Training**" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "lOqcmYzMO7jB" + }, + "source": [ + "def train(tr_set, dv_set, model, config, device):\n", + " ''' DNN training '''\n", + "\n", + " n_epochs = config['n_epochs'] # Maximum number of epochs\n", + "\n", + " # Setup optimizer\n", + " optimizer = getattr(torch.optim, config['optimizer'])(\n", + " model.parameters(), **config['optim_hparas'])\n", + "\n", + " min_mse = 1000.\n", + " loss_record = {'train': [], 'dev': []} # for recording training loss\n", + " early_stop_cnt = 0\n", + " epoch = 0\n", + " while epoch < n_epochs:\n", + " model.train() # set model to training mode\n", + " for x, y in tr_set: # iterate through the dataloader\n", + " optimizer.zero_grad() # set gradient to zero\n", + " x, y = x.to(device), y.to(device) # move data to device (cpu/cuda)\n", + " pred = model(x) # forward pass (compute output)\n", + " mse_loss = model.cal_loss(pred, y) # compute loss\n", + " mse_loss.backward() # compute gradient (backpropagation)\n", + " optimizer.step() # update model with optimizer\n", + " loss_record['train'].append(mse_loss.detach().cpu().item())\n", + "\n", + " # After each epoch, test your model on the validation (development) set.\n", + " dev_mse = dev(dv_set, model, device)\n", + " if dev_mse < min_mse:\n", + " # Save model if your model improved\n", + " min_mse = dev_mse\n", + " print('Saving model (epoch = {:4d}, loss = {:.4f})'\n", + " .format(epoch + 1, min_mse))\n", + " torch.save(model.state_dict(), config['save_path']) # Save model to specified path\n", + " early_stop_cnt = 0\n", + " else:\n", + " early_stop_cnt += 1\n", + "\n", + " epoch += 1\n", + " loss_record['dev'].append(dev_mse)\n", + " if early_stop_cnt > config['early_stop']:\n", + " # Stop training if your model stops improving for \"config['early_stop']\" epochs.\n", + " break\n", + "\n", + " print('Finished training after {} epochs'.format(epoch))\n", + " return min_mse, loss_record" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0hSd4Bn3O2PL" + }, + "source": [ + "## **Validation**" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "yrxrD3YsN3U2" + }, + "source": [ + "def dev(dv_set, model, device):\n", + " model.eval() # set model to evalutation mode\n", + " total_loss = 0\n", + " for x, y in dv_set: # iterate through the dataloader\n", + " x, y = x.to(device), y.to(device) # move data to device (cpu/cuda)\n", + " with torch.no_grad(): # disable gradient calculation\n", + " pred = model(x) # forward pass (compute output)\n", + " mse_loss = model.cal_loss(pred, y) # compute loss\n", + " total_loss += mse_loss.detach().cpu().item() * len(x) # accumulate loss\n", + " total_loss = total_loss / len(dv_set.dataset) # compute averaged loss\n", + "\n", + " return total_loss" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "g0pdrhQAO41L" + }, + "source": [ + "## **Testing**" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "aSBMRFlYN5tB" + }, + "source": [ + "def test(tt_set, model, device):\n", + " model.eval() # set model to evalutation mode\n", + " preds = []\n", + " for x in tt_set: # iterate through the dataloader\n", + " x = x.to(device) # move data to device (cpu/cuda)\n", + " with torch.no_grad(): # disable gradient calculation\n", + " pred = model(x) # forward pass (compute output)\n", + " preds.append(pred.detach().cpu()) # collect prediction\n", + " preds = torch.cat(preds, dim=0).numpy() # concatenate all predictions and convert to a numpy array\n", + " return preds" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SvckkF5dvf0j" + }, + "source": [ + "# **Setup Hyper-parameters**\n", + "\n", + "`config` contains hyper-parameters for training and the path to save your model." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "NPXpdumwPjE7" + }, + "source": [ + "device = get_device() # get the current available device ('cpu' or 'cuda')\n", + "os.makedirs('models', exist_ok=True) # The trained model will be saved to ./models/\n", + "target_only = False # TODO: Using 40 states & 2 tested_positive features\n", + "\n", + "# TODO: How to tune these hyper-parameters to improve your model's performance?\n", + "config = {\n", + " 'n_epochs': 3000, # maximum number of epochs\n", + " 'batch_size': 270, # mini-batch size for dataloader\n", + " 'optimizer': 'SGD', # optimization algorithm (optimizer in torch.optim)\n", + " 'optim_hparas': { # hyper-parameters for the optimizer (depends on which optimizer you are using)\n", + " 'lr': 0.001, # learning rate of SGD\n", + " 'momentum': 0.9 # momentum for SGD\n", + " },\n", + " 'early_stop': 200, # early stopping epochs (the number epochs since your model's last improvement)\n", + " 'save_path': 'models/model.pth' # your model will be saved here\n", + "}" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6j1eOV3TOH-j" + }, + "source": [ + "# **Load data and model**" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "eNrYBMmePLKm", + "outputId": "fcd4f175-4f7e-4306-f33c-5f8285f11dce" + }, + "source": [ + "tr_set = prep_dataloader(tr_path, 'train', config['batch_size'], target_only=target_only)\n", + "dv_set = prep_dataloader(tr_path, 'dev', config['batch_size'], target_only=target_only)\n", + "tt_set = prep_dataloader(tt_path, 'test', config['batch_size'], target_only=target_only)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Finished reading the train set of COVID19 Dataset (2430 samples found, each dim = 93)\n", + "Finished reading the dev set of COVID19 Dataset (270 samples found, each dim = 93)\n", + "Finished reading the test set of COVID19 Dataset (893 samples found, each dim = 93)\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "FHylSirLP9oh" + }, + "source": [ + "model = NeuralNet(tr_set.dataset.dim).to(device) # Construct model and move to device" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sX2B_zgSOPTJ" + }, + "source": [ + "# **Start Training!**" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "GrEbUxazQAAZ", + "outputId": "f4f3bd74-2d97-4275-b69f-6609976b91f9" + }, + "source": [ + "model_loss, model_loss_record = train(tr_set, dv_set, model, config, device)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Saving model (epoch = 1, loss = 74.9742)\n", + "Saving model (epoch = 2, loss = 50.5313)\n", + "Saving model (epoch = 3, loss = 29.1148)\n", + "Saving model (epoch = 4, loss = 15.8134)\n", + "Saving model (epoch = 5, loss = 9.5430)\n", + "Saving model (epoch = 6, loss = 6.8086)\n", + "Saving model (epoch = 7, loss = 5.3892)\n", + "Saving model (epoch = 8, loss = 4.5267)\n", + "Saving model (epoch = 9, loss = 3.9454)\n", + "Saving model (epoch = 10, loss = 3.5560)\n", + "Saving model (epoch = 11, loss = 3.2303)\n", + "Saving model (epoch = 12, loss = 2.9920)\n", + "Saving model (epoch = 13, loss = 2.7737)\n", + "Saving model (epoch = 14, loss = 2.6181)\n", + "Saving model (epoch = 15, loss = 2.3987)\n", + "Saving model (epoch = 16, loss = 2.2712)\n", + "Saving model (epoch = 17, loss = 2.1349)\n", + "Saving model (epoch = 18, loss = 2.0210)\n", + "Saving model (epoch = 19, loss = 1.8848)\n", + "Saving model (epoch = 20, loss = 1.7999)\n", + "Saving model (epoch = 21, loss = 1.7510)\n", + "Saving model (epoch = 22, loss = 1.6787)\n", + "Saving model (epoch = 23, loss = 1.6450)\n", + "Saving model (epoch = 24, loss = 1.6030)\n", + "Saving model (epoch = 26, loss = 1.5052)\n", + "Saving model (epoch = 27, loss = 1.4486)\n", + "Saving model (epoch = 28, loss = 1.4069)\n", + "Saving model (epoch = 29, loss = 1.3733)\n", + "Saving model (epoch = 30, loss = 1.3533)\n", + "Saving model (epoch = 31, loss = 1.3335)\n", + "Saving model (epoch = 32, loss = 1.3011)\n", + "Saving model (epoch = 33, loss = 1.2711)\n", + "Saving model (epoch = 35, loss = 1.2331)\n", + "Saving model (epoch = 36, loss = 1.2235)\n", + "Saving model (epoch = 38, loss = 1.2180)\n", + "Saving model (epoch = 39, loss = 1.2018)\n", + "Saving model (epoch = 40, loss = 1.1651)\n", + "Saving model (epoch = 42, loss = 1.1631)\n", + "Saving model (epoch = 43, loss = 1.1394)\n", + "Saving model (epoch = 46, loss = 1.1129)\n", + "Saving model (epoch = 47, loss = 1.1107)\n", + "Saving model (epoch = 49, loss = 1.1091)\n", + "Saving model (epoch = 50, loss = 1.0838)\n", + "Saving model (epoch = 52, loss = 1.0692)\n", + "Saving model (epoch = 53, loss = 1.0681)\n", + "Saving model (epoch = 55, loss = 1.0537)\n", + "Saving model (epoch = 60, loss = 1.0457)\n", + "Saving model (epoch = 61, loss = 1.0366)\n", + "Saving model (epoch = 63, loss = 1.0359)\n", + "Saving model (epoch = 64, loss = 1.0111)\n", + "Saving model (epoch = 69, loss = 1.0072)\n", + "Saving model (epoch = 72, loss = 0.9760)\n", + "Saving model (epoch = 76, loss = 0.9672)\n", + "Saving model (epoch = 79, loss = 0.9584)\n", + "Saving model (epoch = 80, loss = 0.9526)\n", + "Saving model (epoch = 82, loss = 0.9494)\n", + "Saving model (epoch = 83, loss = 0.9426)\n", + "Saving model (epoch = 88, loss = 0.9398)\n", + "Saving model (epoch = 89, loss = 0.9223)\n", + "Saving model (epoch = 95, loss = 0.9111)\n", + "Saving model (epoch = 98, loss = 0.9034)\n", + "Saving model (epoch = 101, loss = 0.9014)\n", + "Saving model (epoch = 105, loss = 0.9011)\n", + "Saving model (epoch = 106, loss = 0.8933)\n", + "Saving model (epoch = 110, loss = 0.8893)\n", + "Saving model (epoch = 117, loss = 0.8867)\n", + "Saving model (epoch = 118, loss = 0.8867)\n", + "Saving model (epoch = 121, loss = 0.8790)\n", + "Saving model (epoch = 126, loss = 0.8642)\n", + "Saving model (epoch = 130, loss = 0.8627)\n", + "Saving model (epoch = 137, loss = 0.8616)\n", + "Saving model (epoch = 139, loss = 0.8534)\n", + "Saving model (epoch = 147, loss = 0.8467)\n", + "Saving model (epoch = 154, loss = 0.8463)\n", + "Saving model (epoch = 155, loss = 0.8408)\n", + "Saving model (epoch = 167, loss = 0.8354)\n", + "Saving model (epoch = 176, loss = 0.8314)\n", + "Saving model (epoch = 191, loss = 0.8267)\n", + "Saving model (epoch = 200, loss = 0.8212)\n", + "Saving model (epoch = 226, loss = 0.8190)\n", + "Saving model (epoch = 230, loss = 0.8144)\n", + "Saving model (epoch = 244, loss = 0.8136)\n", + "Saving model (epoch = 258, loss = 0.8095)\n", + "Saving model (epoch = 269, loss = 0.8076)\n", + "Saving model (epoch = 285, loss = 0.8064)\n", + "Saving model (epoch = 330, loss = 0.8055)\n", + "Saving model (epoch = 347, loss = 0.8053)\n", + "Saving model (epoch = 359, loss = 0.7992)\n", + "Saving model (epoch = 410, loss = 0.7989)\n", + "Saving model (epoch = 442, loss = 0.7966)\n", + "Saving model (epoch = 447, loss = 0.7966)\n", + "Saving model (epoch = 576, loss = 0.7958)\n", + "Saving model (epoch = 596, loss = 0.7929)\n", + "Saving model (epoch = 600, loss = 0.7893)\n", + "Saving model (epoch = 683, loss = 0.7825)\n", + "Saving model (epoch = 878, loss = 0.7817)\n", + "Saving model (epoch = 904, loss = 0.7794)\n", + "Saving model (epoch = 931, loss = 0.7790)\n", + "Saving model (epoch = 951, loss = 0.7781)\n", + "Saving model (epoch = 965, loss = 0.7771)\n", + "Saving model (epoch = 1018, loss = 0.7717)\n", + "Saving model (epoch = 1168, loss = 0.7653)\n", + "Saving model (epoch = 1267, loss = 0.7645)\n", + "Saving model (epoch = 1428, loss = 0.7644)\n", + "Saving model (epoch = 1461, loss = 0.7635)\n", + "Saving model (epoch = 1484, loss = 0.7629)\n", + "Saving model (epoch = 1493, loss = 0.7590)\n", + "Finished training after 1694 epochs\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 295 + }, + "id": "hsNO9nnXQBvP", + "outputId": "1626def6-94c7-4a87-9447-d939f827c8eb" + }, + "source": [ + "plot_learning_curve(model_loss_record, title='deep model')" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEWCAYAAAB8LwAVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3dd3hUVfrA8e87k0kPIYGAkNC7oCBFUWzYxV5RF1fXVVZd2+7ay/5QV1ddK7oCroJrAbGLCq4VFQWUXiT0klATSEJ6ZjLn98e9CZNkUkhmMsnwfp4nD3duO++9Ifedc86954oxBqWUUqo6R6gDUEop1TJpglBKKeWXJgillFJ+aYJQSinllyYIpZRSfmmCUEop5ZcmCBUQInKCiKwNdRwthYiMEpH1IlIgIhc2YP3XReQfzRFbcxGRuSJyfQPXNSLSO9gxqYOjCSIMiMgWETktlDEYY340xvQLZQwtzCPAS8aYeGPMx6EORqnG0AShGkREnKGOoama+Ri6AaubsTylAk4TRBgTEYeI3CsiG0Vkr4i8KyLJPsvfE5FdIpInIj+IyECfZa+LyCQRmS0ihcBou6Zyp4issLeZKSLR9voni0imz/a1rmsvv1tEdorIDhG5vq4mBhFJFpFp9ro5IvKxPf9aEZlXbd3K/fg5hjvt43X6rH+RiKxoyPnyE9cNIrJBRPaJyCwR6WzP3wj0BD61m5ii/Gx7lIgsEZF8EZkJRFdbfq6ILBORXBH5WUSO9FnWWUQ+EJEsEdksIrf5LJsgIu/b5zvfLmNwHcdgRORmuzksX0QeFZFedpn77XMQWd8x28tOF5F0+/f9EiDVyrpORNbYv8P/iUi32uJSLYQxRn9a+Q+wBTjNz/zbgQVAGhAFTAFm+Cy/Dkiwlz0PLPNZ9jqQB4zC+iIRbZfzC9AZSAbWADfa658MZFaLqbZ1zwJ2AQOBWOAtwAC9azm+z4GZQBLgAk6y518LzKu2buV+ajmGjcDpPuu/B9zbkPNVrZxTgGxgqL3ui8AP9f1O7GWRwFbgL/bxXAq4gX/Yy48C9gDHAE7gGnt/UfZxLAb+bu+nJ7AJONPedoK9r0vtfd8JbAZctcRigE+ANvbvoxT4xt5vIvAbcE19xwy0B/J9yv0L4AGut5dfAGwABgARwIPAz/5+b/rTcn5CHoD+BOCXWHuCWAOc6vO5k33xiPCzblv7jzTR/vw68Iafcsb5fH4KmGxPn0zNBFHbulOBf/os613bBcKO2Qsk+Vl2LfUniOrH8A9gqj2dABQC3Rpxvl4DnvL5HG+v272u34m97ERgByA+837mQIKYBDxabZu1wElYSWNbtWX3AdPs6QnAAp9lDmAncEItsRhglM/nxcA9Pp+fAZ6v75iB31crV4BMDiSIOcAfq8VV5HPuNUG0wB9tYgpv3YCP7GaKXKwLYDnQUUScIvKE3ZyyH+uCBtY3wQoZfva5y2e6COsiUZva1u1cbd/+yqnQBdhnjMmpY526VN/3dOBiu9nnYmCJMWarvazW8+Vnv52xagEAGGMKgL1AagNi6gxsN/aV0bbVZ7ob8LeKOOxYutjbdQM6V1t2f7UYK4/ZGOPFulB3pna7faaL/Xz2/b3VdsxVfqf2sfme+27ACz4x78NKIg05XypEIkIdgAqqDOA6Y8xP1ReIyNVY1f7TsJJDIpBD1XbjYA31uxOrGadClzrWzQCSRaStMSa32rJCrCYqAETkMD/bVzkGY8xvIrIVOBu4Cith+Jbl93z5sQProldRdhzQDtjegG13AqkiIj5JoitW81dFHI8ZYx6rvqGIHAtsNsb0qWP/XXzWd2Cd6x0NiKs+dR3zzmrlClV/rxXH9HYA4lDNRGsQ4cMlItE+PxHAZOCxis5AEUkRkQvs9ROw2pv3Yl1kH2/GWN8F/iAiA0QkFniothWNMTuxmideFpEkEXGJyIn24uXAQBEZYneAT2hg+dOx+htOxOqDqFDX+apuhn0MQ+zayOPAQmPMlgaUPx+rff42+3guBo72Wf4f4EYROUYscSJyjogkYPXr5IvIPSISY9cEB4nICJ/th4nIxfb/gTuwfs8LGhBXfeo65s+xfhcV5d4G+CbsycB9Yt8IISKJInJZAGJSQaQJInzMxmoOqPiZALwAzAK+FJF8rIvEMfb6b2A1F2zH6ogMxAWkQYwxc4CJwHdYHZcVZZfWssnVWG3d6Vidt3fY+1mH9bzB18B6YF4t21c3A6s9/1tjTLbP/LrOV/Vj+BorsX2A9e25F3BFQwo3xpRhNW9di9XUMhb40Gf5IuAG4CWsWt0Ge12MMeXAucAQrM7nbOBVrBpghU/sfeZgnbuLjTHuhsRWT9y1HrN9Hi8DnsD60tEH+Mln24+AJ4F37CbNVVi1ONWCSdVmUKWan4gMwLpgRBljPKGOpzUTkQlYnb3jQh2Lav20BqFCQqznD6JEJAnrm+WnmhyUalmCmiDEelhqpf3Az6JglqVanT9hNRdtxLpT6KbQhqOUqi6oTUwisgUYXq2dVymlVCugTUxKKaX8CnYNYjPWnRQGmGKMecXPOuOB8QBxcXHD+vfv36iy1mbn4PSW07tD+/pXVkqpMLF48eJsY0xKMPYd7ASRaozZLiIdgK+AW40xP9S2/vDhw82iRY3rqjh52vu0KSxg1i3XNi5YpZRqhURksTFmeDD2HdQmJmPMdvvfPcBHVH0YKMCFgak6eKRSSqkmCFqCsJ/+TKiYBs7Autc9OOVhMJoflFIqYII5FlNHrIHPKsqZboz5IliFiTGgNQillAqYoCUIY8wmoNYXlQSaGDCiCUKpQ43b7SYzM5OSkpJQhxJU0dHRpKWl4XK5mq3MsBnNVYxXm5iUOgRlZmaSkJBA9+7dkTD9kmiMYe/evWRmZtKjR49mKzdsnoMQwIgDHVtKqUNLSUkJ7dq1C9vkACAitGvXrtlrSeGTIIwJ2ssLlFItWzgnhwqhOMawSRBU9EFoDUIppQIibBKEYPQ5CKVUs8vNzeXll18+6O3GjBlDbm71lyS2LOGTIIyxOyK0BqGUaj61JQiPp+7R62fPnk3btm2DFVZAhM9dTOhtrkqp5nfvvfeyceNGhgwZgsvlIjo6mqSkJNLT01m3bh0XXnghGRkZlJSUcPvttzN+/HgAunfvzqJFiygoKODss8/m+OOP5+effyY1NZVPPvmEmJiYEB9ZOCUI48Xg1BqEUoewXY8/Tuma9IDuM2pAfw67//5alz/xxBOsWrWKZcuWMXfuXM455xxWrVpVeTvq1KlTSU5Opri4mBEjRnDJJZfQrl27KvtYv349M2bM4D//+Q+XX345H3zwAePGhf6lgGGUIKzbXJVSKpSOPvroKs8qTJw4kY8++giAjIwM1q9fXyNB9OjRgyFDhgAwbNgwtmzZ0mzx1iWMEoQ9FpPWIJQ6ZNX1Tb+5xMXFVU7PnTuXr7/+mvnz5xMbG8vJJ5/s91mGqKioymmn00lxcXGzxFqfMPrKbfQ2V6VUs0tISCA/P9/vsry8PJKSkoiNjSU9PZ0FCxY0c3RNE0Y1CB3uWynV/Nq1a8eoUaMYNGgQMTExdOzYsXLZWWedxeTJkxkwYAD9+vVj5MiRIYz04IVPgsC6zVXHdFVKNbfp06f7nR8VFcWcOXP8LqvoZ2jfvj2rVh14E8Kdd94Z8PgaK2yamKyhNoTiZctCHYpSSoWFsEoQXofg3b8/1KEopVRYCJsE4bBfGKSjuSqlVGCETYIQrxevCHg1QSilVCCETYJwGGM9KKc1CKWUCoiwSRAVfRBKKaUCI7wShDhAXxuklAqxCRMm8PTTT4c6jCYLmwTh9HoxWoNQSqmACZsEIcbqpDbl5aEORSl1CHrsscfo27cvxx9/PGvXrgVg48aNnHXWWQwbNowTTjiB9PR08vLy6NatG16vF4DCwkK6dOmC2+0OZfh+hc2T1BWd1HunTCbxnHNCHY5SKgQeWp/JqoLADnQ3KD6GR/uk1bnO4sWLeeedd1i2bBkej4ehQ4cybNgwxo8fz+TJk+nTpw8LFy7k5ptv5ttvv2XIkCF8//33jB49ms8++4wzzzwTl8sV0LgDIWwShNUHIZSuWxfqUJRSh5gff/yRiy66iNjYWADOP/98SkpK+Pnnn7nssssq1ystLQVg7NixzJw5k9GjR/POO+9w8803hyTu+oRNgnB4vfpGOaUOcfV9029OXq+Xtm3bsszP8D/nn38+999/P/v27WPx4sWccsopIYiwfmHUB2Eod4TN4SilWpETTzyRjz/+mOLiYvLz8/n000+JjY2lR48evPfeewAYY1i+fDkA8fHxjBgxgttvv51zzz0Xp9MZyvBrFTZXVKsGETaHo5RqRYYOHcrYsWMZPHgwZ599NiNGjADg7bff5rXXXmPw4MEMHDiQTz75pHKbsWPH8tZbbzF27NhQhV2v8GlisvsglFIqFB544AEeeOCBGvO/+OILv+tfeumlLX7suLD5yi3G6HMQSikVQGGTIBzGaz9JrZRSKhDC5oqqYzEpdehq6U01gRCKYwybBFE5mqtS6pASHR3N3r17wzpJGGPYu3cv0dHRzVpu+HRSV7wPQil1SElLSyMzM5OsrKxQhxJU0dHRpKU173Me4ZMgjD4op9ShyOVy0aNHj1CHEZbCpk1GvEY7qZVSKoCCfkUVEaeILBWRz4JZjtUHoTUIpZQKlOb4yn07sCbYhYjx6lAbSikVQEG9oopIGnAO8GowywEdrE8ppQIt2F+5nwfuBry1rSAi40VkkYgsaspdCA5jMA6HvnBUKaUCJGgJQkTOBfYYYxbXtZ4x5hVjzHBjzPCUlJTGl2ffA621CKWUCoxg1iBGAeeLyBbgHeAUEXkrWIVF9+oFaIJQSqlACVqCMMbcZ4xJM8Z0B64AvjXGjAtWeRWNS/qwnFJKBUbY3PbjsDsfvHonk1JKBUSzPEltjJkLzA1mGQ5j9YPrw3JKKRUYYXM1dWgntVJKBVTYJIiKu5i0D0IppQIjbBJExYEY7YNQSqmACJur6YEaRNgcklJKhVTYXE2dFZ3U+lY5pZQKiLBJEBQWAtpJrZRSgRI2CUK8epurUkoFUthcTSsShNYglFIqMMIuQehtrkopFRhhkyAqHpTToTaUUiowwuZq6igvB7QGoZRSgRI2CUK8VoIw2kmtlFIBETZX08oH5fQ5CKWUCoiwSRDOcquTutzhDHEkSikVHsInQdhNTNpJrZRSgRE2V9P4wYMBKNcEoZRSARE2V9PYXj0BbWJSSqlACZsEEWG/k7rcGTaHpJRSIRU2V1OnfReT1iCUUiowwihBWP+WOzVBKKVUIIRNgqhsYtJOaqWUCoiwuZpW1Bu0iUkppQIjjBJERSe1JgillAqEsEkQ0e3bA+DRBKGUUgERNgkipls3QPsglFIqUMLmahphj9GnfRBKKRUYYZQgrAyhfRBKKRUY4ZcgtIlJKaUCImyupgcShNYglFIqEMImQTgrm5jC5pCUUiqkwuZqqp3USikVWGGUIKwM4dUEoZRSARE2CUKbmJRSKrDC5mqqndRKKRVYQUsQIhItIr+IyHIRWS0iDwerLABnRR+EPgehlFIBERHEfZcCpxhjCkTEBcwTkTnGmAXBKMwhgsPr1ecglFIqQIKWIIwxBiiwP7rsHxOs8gCc5eXaxKSUUgES1K/bIuIUkWXAHuArY8xCP+uMF5FFIrIoKyurSeU5veXaSa2UUgES1KupMabcGDMESAOOFpFBftZ5xRgz3BgzPCUlpUnlOb1erUEopVSAHFSCEBGHiLQ52EKMMbnAd8BZB7vtwahoYnJv3x7MYpRS6pBQb4IQkeki0kZE4oBVwG8iclcDtksRkbb2dAxwOpDe1IDr4vR6KXc68JaWBrMYpZQ6JDSkBnG4MWY/cCEwB+gBXN2A7ToB34nICuBXrD6IzxodaQM4veX6RjmllAqQhtzF5LJvU70QeMkY4xaReu9GMsasAI5qaoAHQ+9iUkqpwGlIDWIKsAWIA34QkW7A/mAG1ViVndQmqHfTKqXUIaHeGoQxZiIw0WfWVhEZHbyQGs9Z7rFuc9UEoZRSTdaQTurb7U5qEZHXRGQJcEozxHbQ9DZXpZQKnIY0MV1nd1KfASRhdVA/EdSoGsnqg9AahFJKBUJDEoQ9DB5jgDeNMat95rUoTq8XT0Qwh5dSSqlDR0MSxGIR+RIrQfxPRBIAb3DDahyXx41bE4RSSgVEQ66mfwSGAJuMMUUi0g74Q3DDapxIt5uyCFeow1BKqbDQkLuYvCKSBlwl1kt5vjfGfBr0yBoh0uOmKDqGosVLiOrTJ9ThKKVUq9aQu5ieAG4HfrN/bhORx4MdWGO4PB7cES6ynnsu1KEopVSr15AmpjHAEGOMF0BE/gssBe4PZmCN4fK4cbsiQFpkH7pSSrUqDR3Nta3PdGIwAgkEl8dDWYQLcWk/hFJKNVVDahD/BJaKyHdYt7eeCNwb1KgayeW27mLyFheHOhSllGr1GtJJPUNE5gIj7Fn3GGN2BTWqRoq0+yC8BQX1r6yUUqpOtSYIERlabVam/W9nEelsjFkSvLAax+VxU6bNS0opFRB11SCeqWOZoQWOxxTpcePW5yCUUiogak0QxpgWOWJrXVxuN16HwxqPSSmlVJOE1ZXU5fEA6NPUSikVAGGWINwAuGNiQhyJUkq1fmGVICLtGoRbR/tWSqkmqzVBiMg4n+lR1ZbdEsygGquyBqEjuiqlVJPVVYP4q8/0i9WWXReEWJqsIkGUuSJDHIlSSrV+dSUIqWXa3+cWIdJtJwitQSilVJPVlSBMLdP+PrcIFXcx6bMQSinVdHV91e4vIiuwagu97Gnszz2DHlkjRFb0Qbi0BqGUUk1V15V0QLNFESAHOqm1BqGUUk1V15PUW30/268aPRHYZoxZHOzAGsPl1gfllFIqUOq6zfUzERlkT3cCVmHdvfSmiNzRTPEdFL3NVSmlAqeuTuoexphV9vQfgK+MMecBx9BCb3OteFBOR3RVSqmmqytBuH2mTwVmAxhj8gFvMINqrMSjhgBWH4R7+/YQR6OUUq1bXW0xGSJyK9Z7IIYCXwCISAzQIr+iR1qvzbZeGlRSEuJolFKqdaurBvFHYCBwLTDWGJNrzx8JTAtyXI0SaazHM8oiIsC0yEc1lFKq1ajrLqY9wI1+5n8HfBfMoBqrym2umiCUUqpJ6nrl6Ky6NjTGnB/4cJrmwJPUERhNEEop1SR19UEcC2QAM4CFtNDxl3xFGIPD68XtcrXQwUCUUqr1qCtBHAacDlwJXAV8DswwxqxujsAaxRhcHrf1oFy5J9TRKKVUq1ZrJ7UxptwY84Ux5hqsjukNwNyGvgtCRLqIyHci8puIrBaR2wMUc62iBw0k0u3GHeHClLfIO3GVUqrVqPORYxGJAs7BqkV0ByYCHzVw3x7gb8aYJSKSACwWka+MMb81Id46RR9+OK4it30XkyYIpZRqirqG2ngDmI/1DMTDxpgRxphHjTENegLNGLPTGLPEns4H1gCpAYi5VrEjRuDyeHC7XGRNrP6OI6WUUgejrucgxgF9gNuBn0Vkv/2TLyL7D6YQEekOHIXV2V192XgRWSQii7Kysg5mtzVEJCUR6XHjjoigcN68Ju1LKaUOdXU9B1FX8mgwEYkHPgDuMMbUSCzGmFeAVwCGDx/e5HuPXG6PDvetlFIBEJAkUBsRcWElh7eNMR8Gs6wKLrsGoZRSqmmCliBERIDXgDXGmGeDVU51Lo+bMldkcxWnlFJhK5g1iFHA1cApIrLM/hkTxPIA67WjZVqDUEqpJgvaldQYM48QPH0dU1JCbnxicxerlFJhJ6h9EKEQX1xEYUxMqMNQSqlWL+wSRFxxEQUxsQA6YJ9SSjVBWCaIougYvCI65LdSSjVB2CWI+OIijMNBUVQ0eHW4DaWUaqywSxBxxUUAFMbEaoJQSqkmCLsEEW8niILYOPZ/+VWIo1FKqdYrbBNEYUwMO+68M8TRKKVU6xV+CaLIrkHExIU4EqWUat3CLkEk5ecCsK9N2xBHopRSrVvYJYjkvDwcXi9ZScmhDkUppVq1sEsQEd5yvA4Hb5xzSahDUUqpVi3sEoRSSqnACMsEcfG3c4grKgx1GEop1aqFZYKILS2hODoGHWhDKaUaL+wShKNNG2KLi/E6HJRERoU6HKWUarXCLkFEdEghtrQYgOLoaDZfelmII1JKqdYp7BKEIyqa2GIrQRRFxVCyalWII1JKqdYp7BIEIsSVWAmi0H4vhFJKqYMXfgnC4aBtfh4A+9roq0eVUqqxwi9BCKTk7AMgq631NHX2pEmhjEgppVql8EsQQPL+XBzl5WQltwMg64WJIY5IKaVan7BLEILgNIaOOdlsT+kY6nCUUqrVCrsEETNsGAA9dmSyuXOXyvnZkydj3O5QhaWUUq1O2CWIDn/9CwDdd2SQ0bETbqcTgKznX2DvtNdDGJlSSrUuYZcgJCICgLQ9uyh3RjDrxNMql2U9+2yowlJKqVYn7BJEhbTdOwF48+yLq8zf+X8TyJn5bihCUkqpViUsE0SXKZM5YtM6APIS2uBxOCuX5c6cya7/+79QhaaUUq1GWCYIV+fOVT7vbJ8SokiUUqr1CssEYYw10Pejk58B4IGb7wplOEop1SqFZYKoMGTdbwBkdOxcY1nxypXNHY5SSrUqYZ0g4ouLKqc3pnalXKTy85bLLmfD6Wfg3rmzzn14S0vJ/+67oMWolFItVXgmCJ9Xyf3zpScBuP7BJ3nmdzdUWc2dkcGOe+8DoGTtOtb0H0DBTz9VWWf3E0+QedPNFK9YEdyYlVKqhQnLBOGIiqycHrl6GbF2TWLOqNE11i1auJC8Tz8j5603Acj/+usqy93bMgAoz8sLVrhKKdUihWWCiOzevcrnu9+cUjn97bCRNdbfcddd5L73vv+dVTRLGX3DtVLq0BK0BCEiU0Vkj4iE/JVuJy39pXL60etvZ/SkGXx/1NF+1xWffgp7RjBDU0qpFiuYNYjXgbOCuP+D8uHdf6ryecL4v1Dg541zOdNnkDNjBmv6D2DDaadT+OOP1gJjKElPJ/+bbyhJT6+xXfHKlazpP4CyzO01lpVt24bxegNzIEop1UyCliCMMT8A+4K1/3rZYzJVSMrfzzc3X1Vl3tTzLqcgOqbGprsefgQAd2Zm5byMP93I5gsvIvPPt7D5wovIuPEmto0fjycnh/y5c9ly2eUAFM6bV2VfJWvXsfGMM9n72ms1yvHk5LD1mmvx7AvOaXLv3kN5bm5Q9t2cipYuxb17T6jDUOqQE/I+CBEZLyKLRGRRVlZWwPYbmZpaY57DGP791EOVnz8afSbnPTe1UfsvmDuXwh9+ZNN557Pz/gcq53v27GZN/wHsesROMtutGkXWMwcGCixasoTSzZtZf+xxFC1cyPrjRlXZt7eoiEDYcNJJrKu274YqWbeOrddci7ekpEHre7KyWHv0MeTPnVv5oGKgbL3yKjadd15A96mUql/IE4Qx5hVjzHBjzPCUlMANidHhnnv8zj988wb+O+GvVebdc8s9PHDj3yhxRVaZ/8cHnuSpcePrLKc8O5tynxpA9svW601zplvNVPnfflO5bOu1f2D9SSez9arfsensMTX2Vbx8Obsef5y1Q4dRvHIl+d9+h2fvXoqWLMUYQ/Yr/8Gzbx/7Z8/GvXs3RUuW1H0SAHyatowxmPLyA+WtWs2a/gMo3bS5RhPY7sf/SdHChRQvXVpvEaa8nJz33sO7fz+ZN95EzhtvHFjm9bL9zrsoXras/ljrOoz9+5u0vVLq4Emgv+1V2blId+AzY8yghqw/fPhws2jRooCVv/uJJ9n3+uv+lyW144rHX6ox/62/30Fmh8P4+KQzWHDEUABmPHArh+3LDlhc/sQMHkzx8uWVn9vfegvZLx6IL+WOO8h6/vka2w1IX8Oa/gMAiB05km6vT6tcVjF/QPoaALbffTf7Z33KgPQ1GGNIH3B4jX0B5H/7LZk3/xmAjvfdS/I11wCw6x+Psf/TT+m7cAHesjLE4UAiItjz9NPsffVAE1r86NF0mfQyAJ59+6rUkCrKqG7fG29Qsno1nZ98ssay6sehlDpARBYbY4YHY98hr0EEU4r98iB/OubsZdbfrq8xf9wjz3PvLfdWJgeAKx97kY2pXXn31HNY0m/gQcdRFhHBgoFD6lzHNzkAVZID4Dc5AGw47fTK6aIFC9j58MMUr15N/jcHai6F8+dTnpvL/lmfVs4r8fPgX96nn1K0ZEllciiIjmH3P5+oXJ7z1luVz4OsPXIw6YOOsNabV/XhQt9aS/W7wPx18INVY8n7ZBZlGRl+l6tDmzGG/Z7y+ldUARXM21xnAPOBfiKSKSJ/DFZZtXFERpJ40UW1Lk8oKuTbm67kwdderHdfN977GJMuHcff7niQC5+awuhJM/jfMSdULp9y4ZWMnjSDfQmJbG/focq2Yx9/iftuuYeVPftWmZ8XF1/5xrvZx57s9xmN+pRs305RVHTl59wZ77DlkkvJ/PMtlfO2/eE61o08tsp2przmXVU77rqbrVf9DoB3Tz2H856byqL+R7D199ew4fQzKtfbW61WVurnou8tK/M7jMnmC2v/fQBsPP0M3Lt317mOL8/evWRPnox7j9WJvffVVyndsKHB26vW4T+ZWfT9cSXbS8pCHcohJZh3MV1pjOlkjHEZY9KMMTVv42kGCaedWudyAU5d9DNf3PZ7Hnr1hcqnrqvz+NwVlZfQBoAnrr2Z6Wecz8VPTOKdM88H4NY7JzDu0ReYN9iq8W1v34HchEQAbrvrYR6/9mYANndK48Kn/8PD198OwL9+/ycetaf92dS5C5MvvqpyFBGPw8nCgYN5cey1nPP8NModVX+Vvww4kh3tO/DLgCNr7OvnEcew9aqrasz3NenScQAsOvwIvs8vZr/PXUR7nniSgphYyiJclPxmDYhYHBVFQXQMfx9/B7cdPoI1f7qJ304/kyl33k9ufEKdZVUojI6hLMJF7syZAJRu2FDlCXb3zp0s37OP7fmFAGQvXcYXl19F1vMvsOPOuzBuN3uefoZN51od2uX5+fWOteVPWeZ2ygusMr7KzmNa+sPqmMIAABy/SURBVGaM283/svPIaOQFal5OPhuKGtbhH0jGGEo3b27SPubl5PPwhu3kuj0BiurgfbTbuhtvV2nLeq+8MSasazdB7YM4WIHug6hQ0YZ9MDwOJ/++7Go+PvnMRpd74wdvMfmScTXmj/npW2aPOsXvNg9PeZaPTj6T/ls2snjAEQxev4Yrv5zFJU9OBuDjO8eTlZTM1yNGMfOMA3f2nDn/e85c8AMZHTqxsUs3Zp14oOnpqPRVPPvCYwB8Pmo0T48bz7+feogl/QYRVVbKuq49ufPtVyh3OLnnlnvY3qETOW0Sa8R21RefMCx9JX+740EA+m/ewCT7rrDRk2ZUWfeSb+ewrWNnfh04mN4ZW3jl8fv4bthIHr3+dm5cvpBxQwfx5etvcvVjE3DHxTHsp1UUR8cQV1TIne9MZfXdD5C+4BeemPMBLw0bxW3vvk50aSmnTJqBy+MhKTaaPWXWBesvb79KSVQUV103jskzPya6rJQLH34Ix9jLiNmymY2pXTnthWeJGWQ1D7p372bDSSfTbvx4du3YxZFPP4m3pASJjEQcDtb0H0BU//60GTOGQX1GAPD4gm+4f+SppERGMOuZ/yP5D9eSaN9ZtTy/iFx3OSclJ1BU7iXKITjtprWMm27GnZnBcX99FIBdo4fgNQaHCJ6cHCQiAmdC/Qn0PxlZPLpxB+tPPIIoR8O/1+V+8CE7H3iArtOmEnesVYtclFfInjI3Y1LaNmgfh31n3WAwOjmBGYN7NbjsQFhTUMyA+BhO/TWd1QUl/G94XwYn1Hx+KRQ+2p3DTb9t5U9dUpiSkcWSYw+nc3Rk/RsGWDD7IA6JBLF+9Cl4GvFNsoIBfjjqaH49fDBruvdmU1rXwAWnGqR3xhY2dOl+0NtdPftD3hxzMe1z93FuUR47jxzM/7L3c9nXn1PqimTWSaczcOM6Vvfqy1MTH+eHR5/ks6y6x9366s/jmHbeZSSPv4HrUtszdL5Vi/pb9448s8VqHhvTPpFn+ndh2UmjmTtsJFMutpru3h/Si0uXbWTy4d1IPv8cHrn+Nk4/9STaTXiIc/v3pO0991Lq9ZIQ4USA7aVuMkvKuHCp1Wz2aO9Udpe5uahjEv1io3k5Yw/Hto1nRGIcuW4PmSVlLMsv5spOyQjwxuTX6PHGNAb9+WbajbNiqLjgP9CzE6e2a8M3e/fz2KadPNU3jQs7JpH53PMsT0hiyLgrGRAfU7k+VE1w+z3lZJSUkeP2sCq/mPFdUnCIUOb1MjUzm2tS2xPj9J/MPs/KZUFuAY/0Tq0cvcBrDBklZex1e/hm7366xURx25ptHBEfw85SN9luD58O7cPwNrGICLtK3RwW5arcZ1G5F5cILoewMLeAonIvc7LzKPF6mTigW43yT01uQ7RPfFuLS2nviiDL7SFChCuWb+Ta1PZcn+b/7spB81aR7VOruqlLCg/26oxThJ2lZfx72x7+3qszb+/cxxHxMQxPjKuy/ZysXOZk59WI7WBpgmiikjVr2HzRxfWveJAyOhyGJyKCRf2P4KK5X/KfC6/g3dPPDXg56tAQ4fFUacpsiuPaxvNzbkGVefFOB8ckxvPNvobfMtwlOtJvs9rFHZOYtScHj8/l46FenVmYW8CXew/s3yXCsDax9IuL5r879pJgvOTLgYvyuSmJnJvSlht/29qgeK7u3I43d+yt/PxU3zSGJcZx6q9rK+fFOIRib9Xr2k1dUli6v4gFeYX8vnM73vDZx/09O/Ht3v0syCv0W+aDPTvx6/5Cbu3akc+ycjkvpS05nnLGrdhUY90rOyXTPy6aGTv3kV5YtUlxfFoKr2Rm8fnQPty1NoPf7OVzhvXlqDaNrxVpgmii0s2b/T530JzKIiJwer04vV7y4uKJKy4mu20SXoeDRf2PYODm9QCs69qDdnm5bOmUytC1q5k3eDi9M7fSO2ML848YSnbbJEauXEry/lx2tu9IdtskostK2ZTalYKYONwRTtZ260lEeTkjflvB6p592dS5C8YhJBQV4nK76Zy9h3lDRnDv6y/z6oVXkN02uUqsx6xcSrnTwdC1q3nloqs4ct0a/vXi4zx/xXXMGTWa+KIC4ouK6LEjg/lHDmPkyiWk7dnFL4cPZlsn6wHFuOIiCu2hTMTr5cwFP9Apew/Tzr/c7/k558dvuOKrz7j6kecO6rweuX4NK/pYTYgutxu3y1XPFkq1PB0jI1g+qkFPA9SgCaKJSjdtZtOYqgmi3fV/rHLvvmoe5Q4HexPb0j43B4ef/3sVc7wOBw6vF68InogIIjweCmLj+Hb4cRy+eQN9t21CgHIRtnc4jK67d1Zu745wEelx4xVhSb9BtMvLocvunZRERRFXXMTyPgNIyd1HfFEhxVHRuDwe2hbsJ71bL7rvzCSupLgynoKYWGJKS9iV3J744mK+GXEcQ9NXURgTS3r3npy49FcKYmLJTWjD3KEjSd6fi9fhoOf2bfw4ZAQ5bRJJ27OL0Yvm43U4mHDDHVz/yUy+GXEcO1I68uBrL5LTJpHPR51Cl907iD3+eMpKyyiPiyN//Xq+6juIrju3s61TKqf+Mo+5I47jsAgnd3RI4KP8UhJ27uB7VxzHrVxCSZ8+LEvuQLvSEkoN7HM4ueybz0nds4unr/4TSe4yclz+28gTCgvIj4uvMk/s83lMYhxvH9mTC5duwCXChQ43/5fX+M7i1Jxstie1rzG/rr65cHdL1w482Kvmmy8bQhNEE3mLi1l71NAq81Kff57td9wR8LKUak7RRxxBSQNfn9tmzNm0u+GGyubWjn9/iN2PPFq5fOthnem2a0flZ38PJhYu/IVt9oOT/X9bTfrhB54L6vHJx+TvyWbJg3/nuAkPkjB6NJ7sbDadcy6HPfwwuR98QP68eUSlpSH/eIyi++9j4OzPkcjIyoc2B6SvoWTtWjL+fAtd3nuX8gSrn6AwI5NtZ51Fp8lTaHvCqMqbAADyPeVklXnoGRtlxegpJ89TTkG5l7LduxnQri27o2LI8ZTj9hpW5uYzOD6GQW3j2VlYTHRUFO0iI9jvKeefm3ZyfVp7fsopYGB8DNFOBwPirNvIyw14jGFdUQmDE2IpKveSVeZm6f4iDo+PoXdsFA4RssrcFJV7iXU62FZcRrvICGIcDtq6nGwuLqVXTDQRAgvzCkmxl3WKctUcSbqBNEEEwN6p04jq1ZPogQPJm/UpiRecz/pRxwelLKXCQdfXXyf26BEUL1mCt6gI43ZXeb6m6+vT2HbtH6psEzdqFIU+b2Xs9Pjj7Lz//lrLSH1xIp4dOyofyHQmJVGekwNA56efpnT9euJPOhH39u3suOtuogcOpPv77yEieIuK2PPc83T4yx1sOO10kq8eR/ubbqrct7esjLVHDrbKefYZYoYNx9WxA+tGHU95Xh7JV1/NvmnT6L9iORLZuLuPSjdtIrJHj3ov7sbrpfCnn4gdOhSJjW10MvBHE0SQFC1ZWu/zAEqp0Eu8+GLyPvwQAGf79kT37UPhz/OthU4n2GOM9V+5gp0TJpD3wYd+9+M7NI2vNmPOJvXZZ/1sUbuK60fSVVeRNO53RPXsyf6vviIiOZnYYcOqrLvvv/+tMipB/5UrcO/cSWTXpt8RGcwEEZhbJlqp2KFHEdmrF660VFKffhpEWDd8RKjDUkpVU5EcwBogszDbZ2w0nwEo04+o+WCor9qeido/ew4JZ5yJIyGe6P79WX/CiXR76y0ccXFsvuACABJOP42oAQMQl4u2F11E4TzrXTE506eTM306aS//m+233gZA2uRJJJx8cuX+y7Zuq1LezgcfJO+TWfT+/ntcHauOvNCSHNI1CH8K588n4083Ysr0kX6lVNN1mz6dfW++Qf6cL2os6/HJJ0T36+tnq4bTwfqaUdyxx9Jv2VL6r15FG30HgVKqibZedZXf5FAh94MPKVpS/7D6oXBINzHVRg5iKAOllGqsiuYraJnD2euVsC72jQZRfXoD0H/VSpKuuhKA1IkvVK7Wb8VyOj70oN9dOJOSiOrTp85iIlJSqvzniPdpu4w+su42VaWUChatQdQh+ZpryP/6G7pOnUqE/ba7jg89RMf776+8Fa/Lf17BERlJ7FFH1dje96KfPXkK5fn72fdazVectru+6kjoXSZPwpOTA14vEe3a+e1Y6/7eu+TMnEn53n10mfQyBd9/T/aUVyiu5y1zzsTEKiOkKqVUbbSTOoAqLuQpt99G1gsT/VYZfd/k1m/5MjxZ2bhSOyMitb45rWJ+r6+/IqJDBzw7dxLZzf8AX+7de9h07rkkX3tNlZcOdX7yCVydO7Pr0X9Qum5d5fy+CxdQ+NNPtBkzpkoiSrnjDkrWptfZdtplymQiOnYkskcP1g6u+kIkZ/v2xB13bJWXFFXG8vTT7Ljzzlr3q9ShqLFNTNpJ3UpEHW5dYNvfdFOtv2zfB2QcUVFEph0YzbLnnNl0m/52jW36zP+ZPvN+JDItDUdkZK3JAcDVsQP9fv2FlD//mb6//kK/pUvoOXs2iRdcQOyIEcSfWnUoA2diIm3sYUiS7GdCUp99hnY3XE/ac88xIH0N3We+Q9f//rfqdu3bE3/SSUT374/YA8xF9uxZuTzp8stJfeopen399YHjjY1lQPoaEs89hw533+03/k7/eNTvfIDYY46p8jnhrLPoM+9HXKmptW7TUGkvv0xvn/eHK6U0QQRUtzfeoOfs2fWu12bM2X7nR/XoQezQoTXmRyQlEdG+5tg19XEmJOCIiSGqZ4/KeSm33kqf+T/7Xb/jgw/Qb+kS2owZg9hvugPrfdlxxxzNgPQ19F/zG93fnUnPjz+qXC5OJ2mTJ9Htv6+TeOkl1ky7oz8yLZW4E6w376U+d+BBpPgT/D/FHtmjR5XPEZ06kXLH7fRbuoSu06bSf/Uqen/zNUm/v5rDHnyAiPbt6f3N1373VRlfbCw9PztQk+kyZTI9Pqr6IFXCKaNxde5M/Ck1xwKKSEmpERdA7PDhdSY0f6IH19+n5Eys+S4OpUJBm5gOUbU1ZzVV3iefsOOee+n8r3+ReF7dQ58Xr1pNRLtkStasQSIiyBj/J/r+spDCefNwdelKZLeuONu0aVC5FcfTb+kSTFkZG045FYmJoTw7m06P/YO2l1xC6ebNuDp1whEdXWO76ufBlJeTPnBQjWX7v/qK7bfeRpcpk4k/6aTK+ZvHjqVk9W+kPvcspqiI/XO+oGDu3Mrlvb7+moLvviOyZw8y/mi9C739LbcQM/hIMm4YX7le/9WrEKeTgnk/kXH9gXemS2Rkrc/mRA8aRMmqVQ06T6rlaolNTJogDlGbLx9LyYoVAU8QxhiKly4j5qghAR1vpj5r+g+gzTnnkPrM01Yc5eUgUu8ty3UlSk9ODo7YWBxRUY2KqWIohtTnn6fNWQfeTFiydi15s2bR4a9/RZxOtvxuHMWLF1cmhwq7HnmUnOnTOezRR2h7wQXsuPc+9ts11E6P/YOdD9hv9vttNVvHXU3xkiX0+PADNl9s1eIOe/hhYocPI7J7d0xZGWUZGUR2707+nDlEDRjA5vMvoKGiDh+AOCMoWbmSxEsurnUoC189Z3/OpjHnVH5uf8stZL/0Uh1bHNo0QdRDE0TzMeXlYExl/0Fr19CEUF2walIVPDk5RCQl1bmOKSvDk5uLq0PDh1zwFhaydthwYoYMofs7MzBlZZSkpxNz5JENPqayrVvxlpZWJoqKcYokNpaI5GTcmZl0f3cm3qIi4kaOxBhD2caNRPXuXePOuoQzziD/yy/pOm0qxuslftQoALInTyZ+9CmVTwv7btd12lS2/eG6OmPs8+MP4HSy4bTTMUU13xfv6toV97ZtfrZsfVpigtA+iEOUOJ1hkxzAPp4W+IBjfckBrOajg0kOAI64OLq8+ipdJk+q3EeM/cxM/Mkn42zXrt59RHbrRnTfvvRbuoS+C6yB7/otW0rf+T+D/cXRmZRE3MiRVhkiRPXuXWUfSb+/mva33ELaxBcYkL6GuGOPrUwOAO1vvLHKUBKpzz4DWE1uccceS+qLE0kaN46E008j8ZIDb32MPfpo+i1fRkRKChHJyfT9+Sf6Lvr1wH4mvkDHvz9E7y//R6cn/un3+Pr8+EPldK8v5pD28r/pcNedJJx1lt/1u731ZpXPcccdB0DyNb8HoPu7M/1uB9Dx/vvp/v77ALQdOxaasfYcTFqDUIe0rH//m8iuXUnUYVWqyHlnJrsmTKDfksU4Ymu+DrOxNbb6mLIya7+1vBmwZO06nIltcB122IFtjKHol1/JuOmmKrUM35Fbq387z587l8wbb6qyLkDxihXsuOtuOj/7DK6OHcn98CPa3XB9ZXOp8XjYP2cOO+66mz4//kBESgpFS5cSM6TuJlVvYSEFP84jfvTJGLebnLfexhEbQ9LVV5P/xRdE9upFdN/GjcmkTUxKKdUAJenpbL7wIqL69qXnrE/qbG4r3bQZU1KMIzaWyO7dmznSwNHhvpVSqgGi+/en5+efEWE32fX6Yg7G4/G7ru/t38o/TRBKqbAS1atX5XRrrhm0BC2vV08ppVSLoAlCKaWUX5oglFJK+aUJQimllF+aIJRSSvmlCUIppZRfmiCUUkr5pQlCKaWUX5oglFJK+aUJQimllF9BTRAicpaIrBWRDSJybzDLUkopFVhBSxAi4gT+DZwNHA5cKSKHB6s8pZRSgRXMGsTRwAZjzCZjTBnwDtDwdxwqpZQKqWCO5poKZPh8zgSOqb6SiIwHKt7aXiAiaxtZXnsgu5HbNgeNr2k0vsZrybGBxtdU/YK145AP922MeQV4pan7EZFFwXppRiBofE2j8TVeS44NNL6mEpGgvWUtmE1M24EuPp/T7HlKKaVagWAmiF+BPiLSQ0QigSuAWUEsTymlVAAFrYnJGOMRkVuA/wFOYKoxZnWwyiMAzVRBpvE1jcbXeC05NtD4mipo8YkxJlj7Vkop1Yrpk9RKKaX80gShlFLKr1afIEI1nIeIdBGR70TkNxFZLSK32/OTReQrEVlv/5tkzxcRmWjHuUJEhvrs6xp7/fUick2A43SKyFIR+cz+3ENEFtpxzLRvIEBEouzPG+zl3X32cZ89f62InBnA2NqKyPsiki4ia0Tk2JZ0/kTkL/bvdpWIzBCR6FCePxGZKiJ7RGSVz7yAnS8RGSYiK+1tJoqIBCC+f9m/3xUi8pGItK3vvNT2N13buW9KfD7L/iYiRkTa259bxPmz599qn8PVIvKUz/zgnz9jTKv9wer83gj0BCKB5cDhzVR2J2CoPZ0ArMMaUuQp4F57/r3Ak/b0GGAOIMBIYKE9PxnYZP+bZE8nBTDOvwLTgc/sz+8CV9jTk4Gb7Ombgcn29BXATHv6cPu8RgE97PPtDFBs/wWut6cjgbYt5fxhPei5GYjxOW/XhvL8AScCQ4FVPvMCdr6AX+x1xd727ADEdwYQYU8/6ROf3/NCHX/TtZ37psRnz++CdTPNVqB9Czt/o4GvgSj7c4fmPH9Bv5AG8wc4Fvifz+f7gPtCFMsnwOnAWqCTPa8TsNaengJc6bP+Wnv5lcAUn/lV1mtiTGnAN8ApwGf2f9xsnz/YyvNn/4Eca09H2OtJ9XPqu14TY0vEugBLtfkt4vxxYCSAZPt8fAacGerzB3SvdgEJyPmyl6X7zK+yXmPjq7bsIuBte9rveaGWv+m6/u82NT7gfWAwsIUDCaJFnD+si/ppftZrlvPX2puY/A3nkdrcQdjNCUcBC4GOxpid9qJdQEd7urZYg3kMzwN3A177czsg1xjj8VNWZRz28jx7/WDF1wPIAqaJ1QT2qojE0ULOnzFmO/A0sA3YiXU+FtNyzl+FQJ2vVHs6WHECXIf1zbox8dX1f7fRROQCYLsxZnm1RS3l/PUFTrCbhr4XkRGNjK9R56+1J4iQE5F44APgDmPMft9lxkrVIbmPWETOBfYYYxaHovwGiMCqTk8yxhwFFGI1kVQK8flLwhpcsgfQGYgDzgpFLA0VyvNVHxF5APAAb4c6lgoiEgvcD/w91LHUIQKrFjsSuAt492D7NpqitSeIkA7nISIurOTwtjHmQ3v2bhHpZC/vBOypJ9ZgHcMo4HwR2YI1ku4pwAtAWxGpeEDSt6zKOOzlicDeIMaXCWQaYxban9/HShgt5fydBmw2xmQZY9zAh1jntKWcvwqBOl/b7emAxyki1wLnAr+zk1hj4ttL7ee+sXphfQFYbv+dpAFLROSwRsQXrPOXCXxoLL9gtQa0b0R8jTt/B9tG1pJ+sLLrJqxfckWHzMBmKluAN4Dnq83/F1U7DZ+yp8+haqfXL/b8ZKy2+CT7ZzOQHOBYT+ZAJ/V7VO2outme/jNVO1nftacHUrUzbBOB66T+EehnT0+wz12LOH9YIw+vBmLtMv8L3Brq80fNNuqAnS9qdrKOCUB8ZwG/ASnV1vN7Xqjjb7q2c9+U+Kot28KBPoiWcv5uBB6xp/tiNR9Jc52/gF2EQvWDdbfBOqye+weasdzjsarzK4Bl9s8YrLa+b4D1WHcfVPznEawXKG0EVgLDffZ1HbDB/vlDEGI9mQMJoqf9H3mD/R+m4u6IaPvzBnt5T5/tH7DjXstB3plRT1xDgEX2OfzY/oNrMecPeBhIB1YBb9p/jCE7f8AMrP4QN9Y3yz8G8nwBw+1j3Qi8RLUbCBoZ3wasi1rF38jk+s4LtfxN13bumxJfteVbOJAgWsr5iwTesve7BDilOc+fDrWhlFLKr9beB6GUUipINEEopZTySxOEUkopvzRBKKWU8ksThFJKKb80QagWS0Taicgy+2eXiGz3+VznSJQiMlxEJjagjJ8DF3GNfbcVkZuDtX+lgk1vc1WtgohMAAqMMU/7zIswB8aWaXHsMbo+M8YMCnEoSjWK1iBUqyIir4vIZBFZCDwlIkeLyHx7wL+fRaSfvd7JcuAdGBPssfbnisgmEbnNZ38FPuvPlQPvp3i7YswbERljz1tsj/P/mZ+4BorIL3btZoWI9AGeAHrZ8/5lr3eXiPxqr/OwPa+7T5lr7Bhi7WVPiPXOkRUi8nT1cpUKpoj6V1GqxUkDjjPGlItIG+AEY4xHRE4DHgcu8bNNf6yx9ROAtSIyyVhjLPk6CmsIgx3AT8AoEVmENaTzicaYzSIyo5aYbgReMMa8bTd/ObGGvhhkjBkCICJnAH2Ao7Ge1J0lIidijRjbD+vJ3p9EZCpws4hMwxoiu78xxojPy3aUag5ag1Ct0XvGmHJ7OhF4z34L13NYF3h/PjfGlBpjsrEGtOvoZ51fjDGZxhgv1rAQ3bESyyZjzGZ7ndoSxHzgfhG5B+hmjCn2s84Z9s9SrGET+mMlDIAMY8xP9vRbWEO55AElwGsicjFQVEvZSgWFJgjVGhX6TD8KfGe385+HNSaSP6U+0+X4rz03ZB2/jDHTgfOBYmC2iJziZzUB/mmMGWL/9DbGvFaxi5q7NB6s2sb7WKOhftHQeJQKBE0QqrVL5MCwxdcGYf9rgZ5y4B3TY/2tJCI9sWoaE7HeLngkkI/VpFXhf8B19jtEEJFUEelgL+sqIsfa01cB8+z1Eo0xs4G/YL31TKlmowlCtXZPAf8UkaUEoU/Nbiq6GfhCRBZjXfTz/Kx6ObBKRJYBg4A3jDF7gZ9EZJWI/MsY8yXW+8Hni8hKrJpBRQJZC/xZRNZgjWo7yV72mYisAOZhvV9cqWajt7kqVQ8RiTfGFNh3Nf0bWG+MeS6A+++O3g6rWiCtQShVvxvsmsFqrCatKSGOR6lmoTUIpZRSfmkNQimllF+aIJRSSvmlCUIppZRfmiCUUkr5pQlCKaWUX/8PzL7S6cvqpWAAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 350 + }, + "id": "3iZTVn5WQFpX", + "outputId": "a2d5e118-559d-45c6-b644-6792af54663d" + }, + "source": [ + "del model\n", + "model = NeuralNet(tr_set.dataset.dim).to(device)\n", + "ckpt = torch.load(config['save_path'], map_location='cpu') # Load your best model\n", + "model.load_state_dict(ckpt)\n", + "plot_pred(dv_set, model, device) # Show prediction on the validation set" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAU0AAAFNCAYAAACE8D3EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOydd3iUZdaH7zPpgTRqaEFBFBUDKtZ16YpYABEUdRUVxbq2Za2rsq5tv8W69rVhb4uKCioIAXVBBMRIFQSMlJAAaYT0PN8fZ8ZMYgITzKSe+7rmmvaWZ6L+PM+p4pzDMAzDCAxPQy/AMAyjKWGiaRiGUQtMNA3DMGqBiaZhGEYtMNE0DMOoBSaahmEYtcBE0wgIETlARJyIhDbAvTeJyLD6vm99U/VvLCKzRGTCflwnSUR2i0hI3a/SMNFsRIjIeBH5RkTyRSTD+/pqEZGGXtve8P4H6nuUi0iB3/sLanmtl0Xk3mCt9fciIheLSJn3t+WKyHIROSMY93LOjXDOTQtgTZX+p+KcS3POtXbOlQVjXS0dE81Ggoj8BXgM+BeQCHQErgT+AITXcE6jsCS8/4G2ds61BtKAM/0+e913XENYqUFiofe3xgMvAO+ISELVg5rR7zX8MNFsBIhIHHAPcLVz7j3nXJ5TvnPOXeCcK/Ie97KIPC0iM0UkHxgsIoeKSIqIZIvIShEZ6XfdFBG5zO/9xSLyld97JyJXisg67/lP+qxaEQkRkakiskNENgCn78fvGiQim0XkFhFJB16quga/dRwkIpOAC4CbvZbcR36H9RORVBHJEZG3RSSymvtFeH9HH7/P2nst3w5Vjj1IROZ7r7dDRN6u7e9zzpUDLwJRQE8RmSIi74nIayKSC1wsInEi8oKIbBORLSJyr+9/dvv6G1fzz+9yEVktInkiskpEjhKRV4Ek4CPv3+zmarb5nUVkhojsEpH1InK53zWniMg7IvKK97orRaR/bf8WLQkTzcbBCUAE8GEAx54P3AfEAN8AHwGfAx2APwOvi8ghtbj3GcAxQDJwDjDc+/nl3u+OBPoDY2txTX8SgTZAd2DS3g50zj0HvA78n9dKPdPv63OAU4EDvWu9uJrzi4DpwHlVzpvvnMuocvg/0L9bAtAV+HfgP0nxitJlwG5gnffjUcB7qBX6OvAyUAochP4tT/GeA7X4G4vIOGAKcBEQC4wEdjrnLqSydf9/1Zz+FrAZ6Oy9x/0iMsTv+5HeY+KBGcATAf4JWiQmmo2DdsAO51yp7wMR+Z/XaioQkQF+x37onPvaa+X0A1oDDzrnip1zc4GPqSwa++JB51y2cy4NmOe9JqjYPOqc+8U5twt4YD9/Wzlwt3OuyDlXsJ/XAHjcObfVu5aP/NZZlTeA8X7vz/d+VpUSVMg7O+cKnXNfVXNMTRwvItlAOvq3Pss5l+P9bqFz7gPvP59Y4DTgBudcvle4H/FbX23+xpeh/zP51rsLWe+c+3lfCxWRbqiL5xbv71wOPI+Kr4+vnHMzvT7QV4G+Af4dWiQmmo2DnUA7fx+Yc+5E51y89zv/f06/+L3uDPzi/Q/Ux89Al1rcO93v9R5UhH+9dpXr7g+ZzrnC/TzXn5rWWZV5QLSIHCciB6Di+n41x90MCLDYuyW9tBZrWeSci3fOtXPOHe+cm+P3nf/frDsQBmzz/g8wG3gW3RVA7f7G3YCfarFGH52BXc65vCr38f93pOrfNtL8sTVjf5jGwUKgCN3a/Xcfx/q3pdoKdBMRj59wJgE/el/nA9F+xyfWYk3b0P9QfSTV4lx/qrbRqrQmEam6pt/Vdss5VyYi76AW4Hbg4yqC4TsuHd0eIyInAXNEZIFzbv3vuT+V1/8L+s+1nf8uwo/a/I1/AXoGcM+qbAXaiEiM398hCdiyl3OMvWCWZiPAOZcN/B14SkTGikiMiHhEpB/Qai+nfoNaBjeLSJiIDALORP1TAMuBMSISLSIHARNrsax3gOtEpKs3MnxrLX9WTXwPHC4i/bzBnClVvt8O9Pid93gDOBcNKlW3NUdExolIV+/bLFR4yqs7dn9xzm1D/aYPiUis959pTxEZ6D2kNn/j54HJInK0KAeJSHfvdzX+zZxzvwD/Ax4QkUgRSUb/PXitDn5ii8REs5HgdeDfhG4bt3sfzwK3oP/SV3dOMSqSI4AdwFPARc65Nd5DHgGKvdeahgYmAuU/wGeoyC1DAyy/G+fcj2imwBw0eFLVl/gCcJh3O/vBft7jG9Si7QzM8n3ujS7/0fv2GOAbEdmNBj+ud85t8B63UmqZX7oXLkJTxlah4vwe0Mn7XcB/Y+fcu2gA8A0gD/gADbCB+kL/5v2bTa7m9POAA1Cr833UxzynmuOMABBrQmwYhhE4ZmkahmHUgqCJptd/slhEvvdud/7u/fxlEdkoWn623Ou3MwzDaBIEM3peBAxxzu0WkTDgKxHx+Zf+6px7L4j3NgzDCApBE02nztLd3rdh3oc5UA3DaNIE1afpra1dDmQAs71RTYD7ROuIHxGRiGCuwTAMoy6pl+i5iMSjqQ5/Ritc0tE0jOeAn5xz91RzziS8tcqtWrU6unfv3kFfp2EYzZiCAsjKguJiNu3pwM6CaFqxtHC3c1G1uUy9pRyJyF3AHufcVL/PBgGTnXN77UfYv39/t2TJkiCv0DCMZktqKkydSklsWy5KuYS3ViZzT8IjvJR1U+YG5zrs+wIVBDN63t5rYSIiUcDJwBoR6eT9TIDRwIpgrcEwDAOA6dMpjm3H+LmX89bKZP45bDZ39nqLqMplxgERzOh5J2Cat3egB3jHOfexiMwVkfZos4TlaKNdwzCMuiM1FaZPh7Q0SEqiaOkKxq29l4/W9eaR4Z9yw/GLYFs/QhYvrrUGBjN6nor2Caz6+ZBqDjcMw6gbvFtxEhKga1cKMnczZv4NfJrXm6dO+5irjvG6+iIjyYPc2l7euhwZhtG8mD5dBTMhgfziMEbNupi5eQfwfNcpTDwoC8rjICcHsrLIgMzaXt7KKA3DaF6kpUFcHHlF4Zz2xgXM23QAL498n4lHL1cx3bxZnydPZg/UujG2WZqGYTQvkpLISS9gxCeXs3hLF14fM53xXb6EhH4wZcrvvryJpmEYzYqsYeMYPjqS77I68/bZ73B254WanzmxNu1ka8ZE0zCMZsOOHXDynw9nVW450895izOj5kBCkgpmcnKd3MNE0zCMZkFGBgwbBj/+CB/O8HDqqeejc/XqFhNNwzCaPNu2wdChsGkTfPKJvg4WJpqGYTRpNm+GIUNg61aYNQsGDtz3Ob8HE03DMJosmzapYO7cCZ9/DieeGPx7mmgahtEk+eknFczcXJg9G449tn7ua6JpGEaTY+1a9VsWFMAXX8BRRwV4YpWa9GioVVs4MNE0DKOJsWqVWpjl5ZCSAkccEeCJVWrSycqiE3Ss7f2tjNIwjCZDaioMGgQitRRMqFSTjscDCQmUQllt12CiaRhGk2DZMhg8GMLDYf58OOywWl7AW5PuT7mJpmEYzZHFi9WH2bo1LFgABx+8HxdJStLuRn54IKS2lzHRNAyjUfP111rp06aNCmaPHvt5oTFjtAY9K0sdollZhJpoGobRbEhNJeXilxk+qJBO4TuY/8xqunf/HddLTobJkyu1h9sG22t7mXobrPZ7sMFqhtGMqZIGxJgxAMz5yyxGptzIgfHZzBn1bzoV/6yiV0eNNwBEZKlzrn9tzrGUI8MwGo5q0oCYOpVZO4/lrHk3cXDbXcy56BU6tAqBrAQV1zoUzf3BRNMwjIbDPw0IICGBGRv6MG7WBA7vmMnsC1+lbbS3uXpcnFqjDYz5NA3DaDiqpAG9t+owzp55Kf2i1vLFyMcrBBM08p2U1ACLrIyJpmEYDYdfGtAbPxzB+PfGcmzHn5l97gskFGytFOkmK+tXf2dDYqJpGEbD4U0Dmva/g7jw/bM4qfNPfDbwAWJvuPQ3ke66DgLtL+bTNAyj4UhO5vke9zPp1S4M7byKDyd8QPS5f64Qx0YgklUx0TQMo8F48km49u9dOfVUmD79cKKiDq/4srpUpEYgorY9NwyjQXjkEbj2Whg5Ej74AKL8m7T5UpGysiqlIpGa2mDr9WGiaRhGvfPgg3DTTXD22fDuuxARUeWAajoSkeDN02xgTDQNw6hX7rkHbrsNzjsP3npLuxb9hmo6ElmepmEYLQrn4G9/g7vvhosugldfhdCaoirVdCSyPE3DMFoMzsHNN8N998Fll8FLL0HI3voLVdORqNnnaYpIpIgsFpHvRWSliPzd+/mBIvKNiKwXkbdFpDrj3DCMZoJzcMMNGse5+mp49ll1U+6VajoStYQ8zSJgiHNut4iEAV+JyCzgJuAR59xbIvIMMBF4OojrMAyjgSgvh2uugWeegRtvhIce0lEVAZGc3ChEsipBE02nPed2e9+GeR8OGAKc7/18GjAFE03DaHaUlcGkSfDii3DLLfDAA3sRzEaak1kdQfVpikiIiCwHMoDZwE9AtnOu1HvIZqBLMNdgGEYdkZoKU6bApZfq815yJktL4eKLVTDvuisAwWykOZnVEVTRdM6VOef6AV2BY4HegZ4rIpNEZImILMnMzAzaGg3DCIBaCFtJCVxwAbz2Gtx7L/z97/vYkjfinMzqqJfouXMuG5gHnADEi4jPLdAV2FLDOc855/o75/q3b9++PpZpGEZNBChsxcVw7rnwzjvwr3/BHXcEcO1GnJNZHcGMnrcXkXjv6yjgZGA1Kp5jvYdNAD4M1hoMw6gjAhC2wkJ1Rb7/Pjz2mAa7A6IR52RWRzCj552AaSISgorzO865j0VkFfCWiNwLfAe8EMQ1GIZRW6oLyiQl6Zbc12EdKglbQQGMHg2ffw5PPw1XXlmL+40Zo1t9UCHOydF7TZxYd7+pDrHBaoZhVOA/s8dfwEaOhBkzfvv55Mnk90zmzDMhJQWef17jRPt13waInttgNcMwfh/VzOwBYMUK3W/7C9vEieQdmMzpI3Q2+SuvwJ/+tJ/3baQ5mdVhomkYRgVpaRod98fnu6wibDk5cOop8O238MYbGgBqCVjtuWEYFQQYlNm1C4YNg6VLtbVbSxFMMNE0DMOfABpl7NgBQ4dWuCHPOqsB19sAmGgahlHBPhplbN8OgwbBmjUaFzrjjIZdbkNgPk3DMCpTQ1Bm61a1MNPS4JPHf2LIwlfh7cZfK17XmKVpGMY++eUXGDhQjc9Pn1zPkC//3mRqxesaszQNoyUTQH7kpk0wZAjszCzj83Ne4oTHntQZFUceWVFSCXqdFmBtmqVpGC2VAJpwrF8PAwZA1o5S5vzxHk5o/YN2FXYOFi6E9HQ9sBHXitc1ZmkaRkvC37LcsAE6d/5tIrvXYlyzRn2YRUUw7/zn6ReRrcfEx2vdZGSkRoQSExt1rXhdY5amYbQUqlqWGRla6eOzFuFXi3HFCo2Sl5ZqeWS/4sUVDTsOPVS7czgH2dmNan5PfWCiaRgthart3Tp00Oc1ayqOycnh+/BjGDxYv5o/H/r0oXLSe8eOcOKJFU0yG9H8nvrAtueG0RwIpOFF1RLJQw/VovGMDE1kz8lh6fo4Tv7qclrFwty50KuX99iqnYjCw+GQQ1qUWPowS9MwmjqBdlX3txbT02H1asjPhz17IDWVRflHMPTLu4lNCGXBAj/BhEY9HbK+MUvTMJo6NXUmqpoC5LMWMzPVl+nxQKtWcMQRfLXzUEa8eQkdO4Uwd24NMZ0m1IkomJilaRhNnUDHRfisxa1bNcITHw9/+AMpeUczfPZf6FKWxvyRD5GU3TKS1PcXE03DaOrUZlxEcjL06AFjx8KgQcz+pTenfXwVB4RvI6XnZXQpS2tR1T37g4mmYTR1AuhMVAmvyM5c14szZ15Fr8jNpBx6FYntyxr9JMjGgImmYTR1ahukGTOGD74/kNFvncvhkeuZ220C7bcs1+BQSormYLaQ6p79wQJBhtFc2L4dvvsOli2DlSuhXTst56mSgvTu2mTO//oIju60hU8jLiR+58/QqRO0aaOVPgsWaHcOo1rM0jSMpk5qqg4YT0mBsDAoKdGxkJ9+qu/9UpBefx3Gj4fjjxc+X9WV+FOOhfbtITq68jV9ievGbzBL0zCaOtOna4J6bCxERenryEgoLtZ5FK1bQ0YGL58zk0t/PIKBR+fz0R8fp/UN69Uy7dNH05BycjTq3q+fWqhGtZilaRhNnbQ0FbnISH1fWAgREfq8YQMUFPBcySVcsvZWhrX9jk/aX0zr/O2aCB8erjmbhx4Ko0ZpwXlkZItpvrE/mGgaRlMiNRWmTNHh4lOm6PukJM27XL9e68gLC7XSp7gYoqJ4Ytf5XLH+r5zWZiEzDrie6J2/VNSfH3mkXnfZssAi74aJpmE0GWoql4yJ0W5De/aoEHo8kJcHpaU8LDfx5/XXMyp2LtPHvE5kSV7lrXdiojbMLCpq8eWRgWI+TcNoKjz1FKxdqxZkSIh+tmePRrv79lUr85dfoKwMQkN5gNu4fctfGJcwm9cP/BthS4rVd9mmTeXrRkbC6NFquRr7xETTMJoCqakwZ44KnsejMyhA04rS01UMS0ogJgbXqjX37LqGKfk3c36bWUxr/1dCQ6IhJEx9mFlZsG4d9OypwZ+sLJg4sUF/XlPCRNMwmgLTp0Pbtvp6504N9JSUaBAoLEz9mIArLOKOort5IP86Lg59lef3XEVIXhy4Es3FHD4ccnNhyxa9RlKSCqZtxwPGRNMwmgJpaZoKtGiRWofl5eq3BO1UVFqKEw+Ti+/n4dLrmOR5nqfjbseTWwxduqjPsndv9WF26KBb8hdfbNjf1EQJWiBIRLqJyDwRWSUiK0Xkeu/nU0Rki4gs9z5OC9YaDKPZkJSkQnfIIerTLCjQzz0eKCvDiYfr3GM8XHod13qe5Jk2t+MJ9WjSukjFPB9oUfN8gkEwLc1S4C/OuWUiEgMsFZHZ3u8ecc5NDeK9DaN54euFuWqVWpY7d+qMnpAQysvhKp7mOSZxEw8zlZuRPRH6fadO6r+MjNQouy+lyOfDDKTju1GJoFmazrltzrll3td5wGqgS7DuZxjNnoIC+PFHFT3nwOOhrLiUiUVP8ZybxG3yIFOZjLhytUZLSmDXLo2079qlwaOvv4aRI1UYA+34blSiXnyaInIAcCTwDfAH4FoRuQhYglqjWfWxDsNolOzL2vOJ25o1KoDeuvDSkAgmFD3HG5zPFJnCXXIv0qqVlkKmp0NoaMVW3jn1bYaEwIwZcPDBgXd8NyoR9OR2EWkN/Be4wTmXCzwN9AT6AduAh2o4b5KILBGRJZmZmcFepmE0DPuy9lJT4brrYMkS2Ljx14YcJaXC+UUv8gbncz+3cXfYA0hkhG7Ds7w2SESEimZYmNall3stUF+/zEA7vhuVCKpoikgYKpivO+emAzjntjvnypxz5cB/gGOrO9c595xzrr9zrn/79u2DuUzDaDiqjtX1bwLsE9SMDM3HLC2F/HyKQqIZJ+/xLufwEDdxm/xTG3VER6u/s7xcLcqoKBXONm30u/x8FUWfMNam47vxK0HbnouIAC8Aq51zD/t93sk5t8379ixgRbDWYBiNEv/t+HffwbF+doNvSuTWrdrqrXNnTREqKIDQUAqLPZxd/BYz3Qj+HTGZa8ufAAmDbt1UXAsLdfvu8WhpJag/Mzxct+u9e1cIY9WxvJboHhDBtDT/AFwIDKmSXvR/IvKDiKQCg4Ebg7gGw2hcVN2Oh4drGeT27SqYCxeqeHXurCK4YoX2u8zKYk9JGCPdh8x0I3g25CqujX1F2761a6eJ7wkJuj1v1Up9mImJOjytqAh274YjjlDL09eQw8by7hdBszSdc18B1XUynRmsexpGo6dq8OXII2H+fO0yFBmpVqJzcNhhanFmZ8PGjewuj+bM0g+Yzx95kUu5JOR1aNcTundXq7RvX228kZNTYaEWFOj7rl3VAt2xQ+/rXwFkY3lrjVUEGUZ9kpamIubD12Xom290G925swpmx476/ddfk7txJ6cV/peF5X15NeQSLmgzCyReLdEePeDOO9Ui9fkpDzxQhdDjt5EsL1dr0ppy/G5MNA2jPklK0u2xz9KEii5DUPm7jh3JTuzNqSsns6Q8mbciL2Zc5/9BWbRakSLqyxw7Vh8+pkz57T0swFNnmGgaRn2yr+DL7bdrx6KiInYVteKUn54i1R3GezGXMjpmLuwuVaEMCVGxrW4shQV4goo1ITaM+qS64MvIkerrvOsurfjZuZPMLcUMXvs0K0oO4f2+Uxgd/bmmHIWEwLZt6qPs0qV669ECPEHFLE3DqEsCqeX2D774oukJCWoNipCeGcLQohlscF2YEXcRp2ydX1H2WFKiwnn44ZpCVNNYCgvwBA2zNA2jrtifWm7/aHpuLlt2xzEw/xM2FXdmZsdLOSX6Kw3ilJTAWWdp4KdTJy2DNOuxQTBL0zDqikBquataosuX//pdWkQvhqT/k+2uPZ/Fj+ekrlthU7HmVmZn6/Mhh6hY+q776KPWnaieMdE0jLrCl060fbvmWObkaM23Tzz9t+I+S3TjRmjVio3tjmHIqqfIcmHMjh3L8TGrISRRSyDDw/V8X44l/PY6U6dWWJ7W7i2omGgaRl2RlKS9K1es0Mh2bKwKZ3a2Cpn/YLS4OC1pPPxw1i3NZciWCeSXRvDFCbdy9KpvIMRb3dO/v/ou/bfiU6bUbNHC3gXV+N2YaBpGXTFmDFx4YUWn9MJCnRYZEQHjxmkqUefOWvK4cyd88AFrIvsxZPMrlISFMm/4P+nbrw3c+J/KyepVLcWqCfJQ0YTD2r0FHRNNw6grkpO1GicrS4eX+cbsRkWpYIaFaX25c7BjBytKezM081XEI6QkTeBwyoF+GuTxT1avii9Bvri4wg0QHg5HHbV3QTXqBBNNw6hL+vVTQSsqgk8/1XZs+fmaY9m9O/z8M6SlsTzqBIZlvUU4Rcztcgm9o9IhK07PveMOzcEsLq7e0hwzRo9Zv15HWISFqUj/8osmvufkWDVQELGUI8OoS8aMgZ9+0iYc+fkqaEVFal0WF0P37iwpTmbIjreJ9hSwoN3Z9O6YpdZgbq4es369toyrKW0pOVlFNTZWU5Gio2HgQJ1j7lzFHKDy8orXNeVzGrXGLE3DqGvS07WjkG9i5AEHqH8zM5OFMadwaskTtAnLY16PyzigfCdEJur3cXG63Y6JUfH0NSWG3/oki4t1hnl1TTkmT64cPbe55nWKiaZh1BW+lKI9e3SbvHmzbo1//hm6deNLGcBpG54mMTKLud0vpduR7WBLoh7jnPokFy1S69R/DEV1PsnqGn/4tuFWDRRUbHtuGHXF9Onqu8zO1lG7BQUaBCosZO6PXTk1YxpdkkKYv74r3d6eqgGfhAQVzD59tNlweDjk5Wk6ko/qfJJjxtg2vIEIyNIUke5AL+fcHBGJAkK9Y3kNo2UQSML48uWwYYMKp4iKWXk5nzGc0QVv0DN+J18s6qKtMjsnV18ldNRRGtCJiNDza+pQ5GvKYdvwemefoikilwOTgDboFMmuwDPA0OAuzTAaCdVV8lSXMJ6drT5Gj0fHTBQU8HH+YM4ufZNDIzYwu/PVtN/+GHTcSwMP3/0CEUPbhjcIgVia16ATI78BcM6tE5EOQV2VYTQmAk0Yj4/X7uve/Mz33WjOLX2OviEr+Oywm2kTtiew6hwTw0ZNID7NIudcse+NiIQCLnhLMoxGRiDzwd97r2J7npvL2zuGMi73eY4O+Z457c+nTVmmbr1943mNJksgluZ8EbkdiBKRk4GrgY+CuyzDaERUF6n+6SfYsgUuvVQnPc6bpz7I0lJeKz+fCSX/4cSQb5jZ+lxiunSEY47RuT/l5Vad08QJxNK8FcgEfgCuQKdJ/i2YizKMeiE1VZtfXHqpPtfU97JqpHrdOh2127mz+jjnzdOATUgIL0ZcxUUlzzNQFvBp2EhiusVDWZnmX6anW3VOM2CfoumcK3fO/cc5N845N9b72rbnRtOmNg2Dq46PWLtWq3BWr9aZ5bm5EBbGM3kXMDH3EU6O/IqP20ygVWmORtLDwjR3c/58tVAtLahJE0j0fCPV+DCdcz2CsiLDqA9q2w3IF5xJTVXRa9NGLch166CkhMdLruR69winR33Be+2vJnLHThXLQYMq99bs1s2CPE2cQHya/f1eRwLj0PQjw2i61KYbkH8K0IYNmrC+Z4+WSoaEMNVzM38te5CzeJ+3Yq4jvKhEyxx79VI/pm+Gua/M0WjSBLI93+n32OKcexQ4vR7WZhjBIylJrT9/qvM3Vt3GZ2Ronfgvv0BeHvdlTuKvZQ9yDu/wdsRFhOdkaspR+/b6/OGHkJKi3dzNn9ks2KdoishRfo/+InIlVrNuNHUCLUP038ZnZKiFmZGB253P3UW387eSu/lTyJu8HnMlYT26QatW2nAjNFSj6j5/ZkqKWqnmz2zyBCJ+D/m9LgU2AecEZTWGUV9ULUOMiNDgTtVBZb5tfHq6RsxDQ3F5u7nN3cc/uZVLPNP4T+xkQtq0UVE96CDt2p6To2JZXq6+z9hYbedm/swmzz5F0zk3uD4WYhj1jn9wx1cm2aFD5TLJ8HD47DPtVFRejiss4i9uKo9wI1fyNE+WX4OntLWKZEkJHHkkfPON+kcjItT/OWiQ+TObETWKpojctLcTnXMP1/1yDKMBqCmS/tRTmsCekQFZWZQXl3Idj/Ek13Idj/Oo3IiABn1KS+HQQyExUQWzoKDC4gTzZzYj9ubTjNnHY6+ISDcRmSciq0RkpYhc7/28jYjMFpF13ueEfV3LMIJK1TLJ9HQtiXznHW3xlplJeXEpV/AMT3Itk/kXj3I9Iqjv0jn1XfraufXurcEiX5qRtW1rVtRoaTrn/v47r10K/MU5t0xEYoClIjIbuBj4wjn3oIjcilYc3fI772UY+4+vTLKoCJYs0YBNeLj6Irdto6yolImel5lWfiF3cC//4E61MEE7GoWEqL8yJESv06EDHH44rFxZYcFa27ZmQyDJ7ZHAROBwNE8TAOfcpXs7zzm3DdjmfZ0nIquBLsAoYJD3sGlACiaaRkMyZgzcfrtW63ireygtheJiSmIcCYsAACAASURBVMuEi9w03nTncU/oPdxZenfFeWFh0Lq1JroffHDlwNLBB8Ott5pQNkMCiZ6/CqwBhgP3ABcAq2tzExE5ADgSbS/X0SuoAOlAx9pcyzDqnORkrdTJzNR8yvBwCA+nJLeA80te5j3G8QC3cqs8rM2FRdTCDA1Vq9Lj0UDR9OnVNyc2mhWBNOw4yDl3J5DvnJuGJrYfF+gNRKQ18F/gBudcrv933hr2auvYRWSSiCwRkSWZmZmB3s4w9o+iIh1U1qMHlJZSRARjy9/mPcbxMDdyK//UCDioD7NVKzj+eH1dVgbHHrv3+nWj2RCIaJZ4n7NFpA8QBwTUhFhEwlDBfN0552siuF1EOnm/7wRkVHeuc+4551x/51z/9u3bB3I7w9h/fBVCe/ZQsLuMs7Y/w4yyM3hCruXG8KfU+nROrcyYGLVMly5VP+agQdCpU4X/0vplNmsCEc3nvBHuO4EZwCrgn/s6SUQEeAFYXSU9aQYwwft6AvBhrVZsGMHAO698z5YsRpa/z6flJ/OcXME1oc+pSHo8mk7Ur59ao2efrULZtm1FbTnUXL9uNBsC8Wm+5JwrA+YDtels9AfgQuAHEVnu/ex24EHgHRGZCPyMVRcZjYHkZHYnHsQZhbfzZclxvNT2r0wI/wjK26g1GRqq1T6FhbodT0mBnTvVD+rLzwTLx2wBBCKaG0XkU+BtYG6gvTSdc19BRWZGFWwom9GoyM2F0z66hkUlPXi12x2cH/c57HDqpxSBAw+sKI2EiqYc6enaJ3PAAE1mr25ypNGsCGR73huYgw5Y2yQiT4jIScFdlmHUH1lfreTkwzbzTcYBvNXzb5zfZb7mamZlqThGRWnzjT17tDQyPl5PDAmBwYM1gf2bb9Sfua+haUaTJ5Da8z3AO+iWOgF4DN2qhwR5bYYRdHYuWMnJZ7ViRXYi753yHKN+fAuWbdXoeFSURtXbtNEZPzExsGiRlk3Gx2udeWKibs83b9aRGUazJ6AWbyIyEDgXOBVYgvkhjcZI1XnhVXMmq3yfMXAcw8a348ecNnw4/i1G9MqAnW11y11crBZkp05qUW7eDD17wujRvx2yZn7MFkUg/TQ3ATcAXwJHOOfOcc79N9gLM4xasa+ZP1W+37a5jEFnxbN+Zzwfj3+dEb3W63FlZeqrjIqCAw7Qip/ISA34+IQ4kD6cRrMlEEszuWpSumE0OvY188fv+y25MQx5709syYthVtx4Bq7bCjFHaepQXJz6LouKKncqCgursFz9yyWTkqyuvIURiE/TBNNo/Oxr5o/3+5/XFTPkvXPJLI7ms7bn8Yfo7yC3laYQDRigjTfS0uCoo1Q8MzNVMO+8s0IYfX04jRaJja0wmge+TkU1+RqTktiwLJvBn95MTmkrZrcdz3EhS6BY4LjjVCgXL4ZRo+Css2DFCv1s8GCrJzcqYaJpNA32FeQZM0Z9lqAWZk5OpZzJdUePZ/D9CRSURzA3ZjRHFS3RTkYxMbBxI5x2WuUI+Nix9fv7jCaDdW43Gj/+4yj8gzz+OZF78TWuXg1DJvWmVHKY12sSyRsXakQ8Pl7LIzdsgPXrtZ2bYeyDvVmavu7shwDHoDXjAGcCi4O5KMOoxFNPwdq1mgYUF6ed0RMS9PPExMrW55QpFVbpo4/yQ0R/hr49CU94KCkTX+PwpZt0HnlGhgomqM9y5Urtf2kY+2CfndtFZAFwlHMuz/t+CvBJvazOMFJTYc4cTTCPjdWI9sKFKnzLl8Ppp1e2PkeOhBkzICGB70L6c/JLFxJRtou5A6ZwiHOwdavmXvomTBYUaIlkt27mtzQCIhCfZkeg2O99MdY42Kgvpk/XTkKgNeBRUfp64UIVvqopRk88AX37snhPH4a/cgGxLpu5XS+i50/rYVu0RsR37VJfZq9eWs0THl45gGQYeyEQ0XwFWCwi73vfj0bHVBhG8ElL03ZsixZBfr7WgOfna4eNE06ofGxcHGzZwv8OvIBT37iIdp6dzEu6mO5ZqdoLs3t3Df5kZmrEvGfP3wSMDGNfBJKneZ+IzAL+6P3oEufcd8FdlmF48aUSHXwwfPmlVuyEhalALl0KmzbpZ3Fx0LkzC1qN4LTXL6JzzG7mJl5C191r1UKNjtbnTp10PvmWLdp8w5LTjVoSaMpRNJDrnHtJRNqLyIHOuY3BXJhhABWpRL5gj4j2tOzYUa3PggIV1OxsvljblTN3TKW7/MJcRtFp08aKGvJ27dRK3bZNRTYnB264wcTSqDWB1J7fjU6LvM37URjwWjAXZRi/4kslKi7Wh8/STE3VEseSEsjL49PyUzgj4wV6ejaSkjieThG79LjychXW/Hy1SouK1EcaHm7zfIz9IhBL8yx0kuQyAOfcVu8cc8Ooe2pKYh81Cn78UVODIiMr5vWEhfFRx8sYO+cKDovayOyIkbTrFA9R3pzLzZs1Yv7zzxW16SEh2tYtIqKiNt0wAiQQ0Sx2zjkRcQAi0irIazJaKjUlsY8cqelBs2aphdi1q46fKCpievSfOPfzKziy03Y+a30hCdnZEJlYcc0uXVQkMzO1e1F8vOZ5JiaqFWrzfIxaEohoviMizwLxInI5cCnwfHCXZbRIqutU9OOPcM01GsgpKtKUoZwc6NKFt8rP4U+//Itjo1cy68w3iFtYoHPICwsrUpMKC1UoO3WCvn2tD6bxuwkkej5VRE4GctHqoLucc7ODvjKj5VG1U1F6Oixbpj5J0O24xwPR0bySOYJL8h7ipOhlfNzqXGK+idWAUEaGPnyjrPLyNLXo4os16R2qrU03jEAJJBD0T+fcbOfcX51zk51zs0VknyN8DaPW+GaP+1izRgM9Irodj40FEV7IH8/FeY8zqNViZrafQMyw4+DMM6FzZ01aP+IIzeNMS9PAUdeuKqiTJ6uluXmzzfMx9ptABqudXM1nI+p6IYbxm67oGRkqlh6PPkdE8HTon7ms6EmGh6fwcavxtDqxrwqix6NCGBenDThKSzWZffBgDfj4OiBNmQIvvqjPJpjGflCjaIrIVSLyA9BbRFL9HhuBH+pviUaLwZde5LMGO3TQgWbh4VBYyGM5F3N13j85M/wzPjhrGlGd4nUWuY/t2+GHHzRS3qaNfrZokfpCExLUZ2oYv5O9+TTfAGYBDwD+7V/ynHO7groqo3mxr16YNdG7t1buHH00/7fgeG7Jv4sx4R/z5pD/EB4eB8cfr9t5X3Bn9eqKMRVpaZqaFBOj2/wBAyxSbtQJNVqazrkc59wmdGTvLufcz865n4FSETmuvhZoNHH2NfBsb8dGRIBz/GPHldySfxfj23/BW8NfIvy4I9Uivfrqytv5tDTd0kdFaZpRSYlan+npFik36oxAUo6eBo7ye7+7ms8Mo3r2NfBsL8e6+ATuyvgz9347nAsvhJdeGkpIyNCK41NTNRVp1iwN/OTmqmC2bw87d6ofVEQj6BYpN+qIQAJB4pwvfwOcc+XYmAwjUNLSNDjjj//As+qO3b4dNy+FW5/pzr3fDmdi0mxeekmNx1/xWaW5uboNT0zUA0pL1dr0+TSLi6FVK4uUG3VGIOK3QUSuQ61LgKuBDcFbktGsCA+Hzz6r3HXd112oKklJsG4d7ocV3JhxG49ljOeqtu/wRNz9eFa+Uln0fFbp99+rdRkVpSlJJSVac56Vpffq3Fmj6yaYRh0RiKV5JXAisAXYDBwHTArmooxmQmqqBnJyc1XI9uyB+fPhp580GJSaqqk/o0fDoEEwfz7lc+ZyTdotPLZ9PNeHP82T2RfgWbMKhg+H996ruLbPKs3JUUsT1NosLdXSycRErQAKDdV7GUYdEUhFUAYwvh7WYjQ3pk+HHj00qLN6tQpcbKyOlgDdXpeWal6lx0NZGVyx+yFe2HM+N4dM5cHivyIi4AmFHTvg+uv1vLFjK/psxsVptDwqSgWyRw/1Y4JaotYr06hjxM9dWfkLkZudc/8nIv8GfnOQc+66vV5Y5EXgDCDDOdfH+9kU4HIg03vY7c65mftaZP/+/d2SJUv2dZjRGPBPL/ruOzj2WK379lFerjmYERFaIvnzzxASQlliFy75eQqv7jyNOyP/xd8Lb0FCPOqnLC/Xc2NjtdonJaXCp1laqjPKPR497ogj9BzzYRoBICJLnXP9a3PO3izN1d7n/VWrl4En0HEZ/jzinJu6n9c0GjOpqXD77WpVbt+uzTJWr4Y//hH+8Ac9JidHBXP27F+DNSXlIVy08jbeKjqNe9o+xp35dwFOxc/3P/WwMM3B3LJF3/uP7N2zB7KztTFHr16B54Eaxn6wt2mUH3mf92sekHNugYgcsH/LMpokTz2l1qMv3ccndPPn6za6SxfdUkdH/zosrTgihvMyHmV60en8M+Zebu4zB76P1PPKyirKKEtL9X1JiYpzcnLFwzDqkb2VUX4kIjNqevyOe17rLcd8UURqHAEoIpNEZImILMnMzKzpMKMxsWiR5kR6PCpwZWVqLZaVwVdfVTTJKC6Gfv0o2lPG2KznmF54Oo+0vpObQx5SYW3fXh++HCNfx3aPR6dHWsd1owHZW/R8KvAQsBEoAP7jfewGftrP+z0N9AT6Adu8168W59xzzrn+zrn+7du338/bGfWKiApcSYn6F30BGdDUI1+TjKQkCkJjGJU9jY/yBvNUhyncEPui1pr36gX3368150lJeh3n9PxBg3SKpNWRGw3I3rbn8wFE5KEqjtKPRGS//JzOue2+1yLyH+Dj/bmO0Ug5/ngNyvgoLa0I4uzYoaI5Zgz5p57NyLM8zEs/lOfP+ICJPbIga2jl4M3BB6swvv665loeeqimEUHNyfGGUQ8EktzeSkR6OOc2AIjIgcB+jbwQkU7OuW3et2cBK/Z2vNFE8EXMt23TLXRJiW6nffgsznXryHvgCc5Y9zBfZbRi2uj3uTDhY0ioZoyuv78yK8s6rhuNhkBE80YgRUQ2AAJ0B67Y10ki8iYwCGgnIpuBu4FBItIPTWHaFMh1jEaO/1yf5GRYvx7Wrq3YVsOvvsmcOd8yIu9tFhdE8fqDaYy/ZQywj8Rz3whfsI7rRqMgkOT2T0WkF9Db+9Ea51xRAOedV83HL9RyfUZjp2pDjkMP1bk+PsH0eECErNIYhme+zneuD293uYmzV2ZBapVcyppayPlSi3yfW8K60YDsUzRFJBq4CejunLtcRHqJyCHOOfNHGpXn+qSnay24SIVolpezo7wNJ/MZq9yhTG9zOWcetAkS+lbudFTTJEqfn9NE0mgkBLI9fwlYCpzgfb8FeBcL4hipqVoCuWiRRr4zMrTap7RUv/d4yHDtGeY+50cO5sPYizjVzYbe5/w2mFObFnKG0YAE0rCjp3Pu/4ASAOfcHtS3abRkfJZhly6agL5tG6xbp3Xg3qmR28o7MMjNZT0H8UnIKE5t/ZXWhicm/jaYU5sWcobRgAQimsUiEoW3/lxEegL79GkazZzp0zVCvnYtbN0KGzeqhSkCERFsjujJQOaTRhKzwkczNH6p1o4ffXRFt3X/7kNVJ1GCRcmNRkkg2/O7gU+BbiLyOvAH4OJgLspoZFQXoFm+HFauVPGLiFBr07st31TShSGln7GTNnzOcE5030LrztoMePt26Nfvt8Eci5IbTYS9iqaIeIAENC/keHRbfr1zbkc9rM1oDNQUoNm8WftkRkRojXl4OBQX85PrwZDSOeS6GObIKRwjSyDpQDj//AohrK6hhkXJjSbCXkXTOVfubRH3DvBJPa3JaEzUFKDJzdWSydDQX8sc13oOZWjJLAqIYm7YqRxZvhQiIjXtKCOjoqKnpuCORcmNJkAg2/M5IjIZeBvI931oY3ybEVW33336aDmkf09Mf+LidKsdHq6WY1ERq9yhDCl9h3KEFAZxRMh6bR4cEaHHrVmjomnBHaOJE4honut9vsbvMwf0qPvlGPVO1e33jz/CK6/ACSdAz57qt1ywQJtldOyoPsllyzRKXl4O4eGkevoxbONzhLhSUjyncFibdAiJ06BQeTns3q1beLDgjtHk2Wf03Dl3YDUPE8zmgv/22+PRSHhsrDb79XjgyCP1uGXLNK0oJUW35gMHwjHHsKysL4M3vkC4lDA/bhSHddihlmVISEWSe26uXrO6qLlhNDECqQiKRCdQnoRamF8CzzjnCoO8NqM+qFrRs2aNCl16ekVnoQED4JtvYPFiFb8jj4TERBZv6cLwjAuIDc1m3hVv0+OHaNjZTgW3pESj6dHRem2fMFtwx2jiBLI9fwXIA/7tfX8+8CowLliLMuoR34CyoiJYuFAtxIICFb3XX4fWrXXaY1KSjpPo2BHWrOHrLwoZsWEy7UN3MTfqNLovjoLu3dVSLS+vaAknoj7Se+4xsTSaBYGIZh/n3GF+7+eJyKpgLcioZ3z5kWvXatAmMhIyM3W6Y0lJRR5mly7w5ZdQUECKG8gZhU/QRbbyRdhIunYWyC2BH37QaHpEhG7te/TQZsLh4VYOaTQbAqkIWiYix/veiMhx7P+wNaOx4cuPzM7Wrfr27RoZ9/XDDAnR0RM//giFhcwpOJHTCv9Ld34mxTOUriUb1ZocOFC37tu3Q+/ecO65cMYZaplaxNxoRgRiaR4N/E9EfP/WJwFrReQHwDnnzHxoDkRGqnW5bZtaizt36ta8dWuNfmdnM6t4CGfxPgezjjkMpUNZJhCigigCw4frNr9fP2sabDRbAhHNU4O+CqN+qKlf5fTpFbmZod5/JUJCdDRux46wbRszyk5nHG9yOCuZ7TmVtm6HhgV988b/9z+9xvHH65YerBzSaJYE0oT45/pYiBFk9tavMi1NczJjY+Hbb7XdW3S0Vvzs2MF77mzOK5vGUSzjU8/pJLhdFf0yy8p0frmIiu6rr+rnVg5pNFMCsTSN5sDe+lX6IugdO6ofcvt2Dfqkp/NG1gguKnme4+RbZrUaS2xpPhR6BTM8XC3NuDh9+EZegImk0Wwx0WwppKVpVU5Kim6Z4+LgkEP08xtu+G2HoT17mNbrXi5ZNIkBof/jY86kdX6Obt9DQvSRkAB//CMcfvhvh58ZRjPFRLM5Up3vMjwc5s/XLXhsrOZiLligUe+qHYa2buU/7W/nikWXMjTiaz4sP5PokCIIi1T/pS+iftJJcNBBFZU+5rc0WgAmms2N6mrJL7xQrcfdu9VSjIysON43XtfXYSg1lSdP+4Rrt1zGiMh5TG91IZGecChy6uNs21Z9mCedpNt581saLQwTzeaGv+8yPV0bbohAYaGWRGZkaNJ6YqKmBhX5NeFPTeWRS3/gpi23MTJ+Pu+E/ImI3Tu1Eqh1a93ed++uAlxUBFOmNNjPNIyGwkSzqVN1K758eYXFt2aNWpWRkSqgrVrpIypKuxZlZUGnTr9e6sGbMrht6QWc3WMZb0TfRPguoCAE8vIgJkaPLSzUih/LuzRaKCaaTZnq0og2btRt9J49WtbYurX6MH3f5+frdwUFOkHyvvsgNZV7rt/J3SnDOK/9HF4Z9j6hnmM1/Sg3Vy3Ttm01Up6bq35M61RktFACKaM0GitV27r5xHPhQi2LbN1aBXLzZvVDguZVRkToa+dwn8/mbycv4u6UwVzU6r+8Gv9nQhd9pd+fcQaMHKlJ62FhKp6DBqnQmv/SaKGYpdmU8W/r5mP3bhXP+Hi1JouKVDDXrtXPo6K0wXBiIm7tj9x8dyRT90zissSPebbNFDxb09UCXbWqoi/miy+aSBqGFxPNpowvKd0/PzIzE7p1U4sQ1Je5erWOrTjggF97ZDoHNywYw+N7xnJ1p/f5d6/H8UhbEKfCu3UrDB5sUXHDqIKJZlOmurG3YWHaxs1HYqJux0Wgb19ISKB823au/mgEz24byY3hT/BQm2cRaavHt22ruZjnnGPRccOoBvNpNmV8SekJCeq3TEiAO+/UXMysLO1YNGsWfPKJJqNv2EDZmnVc9u4pPLttJLdGP85DHf+FbNmsXY2cqxBeC/QYRrUEzdIUkReBM4AM51wf72dt0KmWBwCbgHOcc1nBWkOLoLqxtwcfDE8/DTNnqtXYujWkpVEaFsUl/7uK17JO467uLzOl/5fIuniIjNC0orIyFcw777QtuWHUQDAtzZf5bVu5W4EvnHO9gC+87426JjlZrcbWrTVQlJhISXkIF3w3mdd2nca9kffy9x3XIvPmakQ8JkbF8pxztEvR2LEN/QsMo9ESNEvTObdARA6o8vEoYJD39TQgBbglWGtocfgnus+cqf7JqCiKy0MZv/5u3t8zhH/xVya7f4ML0YBPfr5u53v3ruivaRhGjdR3IKijc26b93U60LGe79988SW6l5Wpf3PXLtixg8JfMhlb+iafFAzhsbC/cF3Zo1DmUQuzvFzzOz0erfTx9dc04TSMGmmw6LlzzomIq+l7EZkETAJIspK9mruu+5g+XQXzq69gxw4oKaGASEbnvcrnDOVpruTKkmf12JAQffbNJQ8L00dCgg1AM4x9UN/R8+0i0gnA+5xR04HOueecc/2dc/3bt29fbwtslPisyKysyl3XU1MrjklLg2XLNC+zoIB8ojmdT5jNybzApVzJsxXHlpaqYIaG6sO5ikbCNgDNMPZKfYvmDGCC9/UE4MN6vn/TpLpySZ9V6CMpCX7WySR5rhUjmMV8BvIKF3EpL1W+nnO6NS8rU/EUUZ+mDUAzjH0SzJSjN9GgTzsR2QzcDTwIvCMiE4GfgXOCdf9mRXXlkoWF8OGHFdv1Pn2grIzs4mhGMJNvOYY3OJ9zeafyeb5BaKBWZuvWWikUEWGNhA0jAIIZPT+vhq+GBuuezQ6fH3PZMu2LedRR2vh3+3b4/HPtZuT7bskSdnU8lOG//Ifv6cu7jOMsPqh8PV/QJzJS04yGDNFzu3VTy9VKJg1jn1gZZWPFv+3bccfpaIqUFBgwQMfl7tqlYhcXB4WF7Fi7k5Oz3mMV3ZjOGM7gk99e0+PRrXhYmLaLO/hguPVWE0rDqAUmmo2VqtMjBw7UphuLF2tTjq5doV07ALaHdGbolqn8VNiFGR0nMXx7NYIJGjWPjYVjj4X77zexNIz9wESzsVKdHzMyUi1MP7YWtWVo6sOkFXXgk/gLGNJ9M7TqoXXnhYV6kG9yZGwsDBsGV19tgmkY+4mJZmPFv+1bero2FhaBzp1VONPS+KW0E0PSHiO9OIFPu17OH6NWQlGkim23blrxs2OHdlvv3BleecXE0jB+J9blqLEyZoyK5rp18Omn8Msvaj127AgnncSmmCMY8NOLZBTH8fkh1/HH5BztfxkRUWFhtm6tx/fqBaNHm2AaRh1glmZjw7/yp6AAvv9eR1fExmrQZ+1a1h8+iiHZM8gjlC9Of4T+R3WDMdfr+XfcAevXay4maPeinj2t1Zth1BEmmo0J/4h5Xh589plajaGhutUuLmZN+cEMXX0tRSHhzLvsNfo9c1fla9x3n7aFW7RIhXPgQPNhGkYdYqLZmPBFzLdtgzlztKlGaKhanIWFrIg5gWG738M5SDnhFvpcfcVvr5GcrKJpGEZQMNFsTPgi5h9+qL5J53SaZGgo33uOZFjuB4R5ypgbP4be+cXwaEH1zTsMwwgaFghqTCQlaf13bq6KZnQ0lJWxlKMYXPwZkRQyv9uF9G6bqWJaU/MOwzCCholmY6JPH636KSzUVKHSUhaFnsTQ0s+IJZcFMWfQK3qL+jt37dIqoaKi3zbvMAwjaJhoNhZSU2HGDDj8cG2gUVzMV9l9OLnkE9p5drGg1QgOPK4DbNmird26dFFf58KFKrLW0s0w6gXzaTYW/MsmDz6YeSnCGQv+SjfZzBc9JtElXGDpUvVzxsdrww0fy5fDiBENt3bDaEGYaDYW/MomP/+pJ6O+Hk+Ptll8cfT9JHZIgoS+MH++WplbtkBUFLRpoyK6a5flYRpGPWGi2dD4ktm/+w6+/ZaZhUMY89M5HBL9M3MGTKX9zk0qmAkJamEWFKi45uVpOlJ4OJx8skXPDaOeMNEMFvua6eM7xpfMftBBfPBxKOcU/osjon7i8wOupO13v0D37loJBNpdfeFCtTJDQrRNXFYWXHVV/f8+w2ihWCAoGAQy0wcqhqF9/z3vzo5nXOErHBWayhetR9E2Ml9Tjlav1sqg7dshMRFOOEGtzfR03a5HRzfMbzSMFoqJZjAIZKYPqOjNns3rS3szPucZjg9dyudx5xAfXazVQNHRWnOem6upSNu26bY8P1+7ro8cqfmclqdpGPWGiWYwSEur2FL7qDrpMTUVUlN5OWsUF+5+igHyJbM8pxPrcrRBR2RkRSu4gQNVPBcv1iDQ8cdr1/W9CbJhGEHBfJrBwL8X5vbtusXOyFDL8cordbZPairP5Z3HFaVPcLJnDh9wFtHFuyHbO8PHOU1cP/JI3ZYPHw6bN+v1qzYnttG7hlFvmGgGgz594B//0K307t0qns7Bzp26JR8wgCfWDOPPJQ9yWshn/FfGEkkheELVx1lermWSJ52kggmVx+v6BNmHjd41jHrDtud1ja+yp0sXFcmcHN1Sl5RAhw4QG8tDc/ry5/wHGRX2CdPlbCLDy9UKDQ/Xx+DBGh2PiFABzcrSx5gxFc2Js7J++51hGEHHLM26xhcR37JFBTA+Xssc09OhTRse2HUFt2+axLiE2bwuEwnLLYLQaBVVgB49oH9/PTchoSJlyX+87uTJldOZbPSuYdQbJpp1TVqa+h4jI6FVKxXDyEhcbh73rDuPKXmTOL/DbKaNep/Q2VFQHKkC2bq1Hj9woFqn/frBlCnV3yM52UTSMBoIE826JDUVNmzQERXx8brl3rkTV1LKHZ4HeCDvei6OeIPnY+8h5Ps4Dej076/VQG3bqlCGh+t2e+LEhv41hmFUg4lmXeFLaO/SBVat0lzKvDxcWTmTd9/Nw2XXMynsJZ7u+TCesEg9Jz4e7rxTX/u225062XbbMBoxJpp1hX9CkL8bMQAAD8lJREFUe0kJfPEF5Tl5XO8e4Ynya7g25Gkej7sbiUpSn2dZGWzcCBMmwKhR1n3dMJoIFj2vK/wT2g87jPKOnbiKp3ii/Br+EvMcj0fejOTm6HEeD2zapMGh/Hzrvm4YTQizNKtjb802/L8LD9eqnaIi9WVu2wYFBZRty+CydbfwspvAbREPcV+vN5GfQqDMo/mXO3dqOlFpqSa6+3Iup083a9MwGjkmmlXx7zzk32xj8mT93vddWJgmqoN2G2rVCubPp7RNByZk/Is33NlMkSncFfEosiWy4vrl5ZqCFBKiCe+R3u+sqscwmgQNIpoisgnIA8qAUudc/4ZYR7UWpb9vEipbgb73CQnaQCM2Vj9buxaAkraJXLD5Qd4tP5v7PX/jNnkQyr1t3MrKVGhbt1bhBOjYUaPmYFU9htFEaEhLc7BzbkeD3b0mizI3V8UsJUWFLC4ODjmkwgr01X3n5FSIZk4ORfmlnJv+bz4sH8FD8f/gprDntI68pES34SEhKpgDB+rzggX6+SGHVFT1WJqRYTR6Wu72vCaLctUqtRxjY/VRUKACN3Cg1oH76r7j4vQ7oDCmPWevvZWZxUP4d+RfuTb6DQiLroiSt22rJZTXXgsrVqgADxqk2/PiYkszMowmREOJpgM+FxEHPOuce67eV+A3k+dX4uK0wYbIb4/ftUs/nz1bRbBbN1i5kj3lkYzOf5TZBYfxbNQNTOr8MeR6rcvQUBXe/v3VJ5qcDGPH1s/vMwwjKDRUytFJzrmjgBHANSIyoOoBIjJJRJaIyJLMzMy6X0FSkm6x/cnJ0VESAwboc26uPvfpA8uWabR86FA9dvlydvfuz+m732LOlt68eMiDTBq4VoW4TRs9Lzpax1X4BNMwjCaPOOcadgEiU4DdzrmpNR3Tv39/t2TJkrq9sb9PMy5OBTMrS6PgvmYZPmbN0me/Mbm56Xs4bea1LNySxCuvwAVH1HA9E0zDaLSIyNLaBqLrfXsuIq0Aj3Muz/v6FOCe+l4HycnVdwv68Ue4/XYVPI9HfZHZ2XDaab+eml0YyakfT2Tp1k689TaMGwdQw/VMMA2jWdEQPs2OwPuifsNQ4A3n3KcNsI7fdgtKTYVp0zTiHR6uQZpt23SbvXs3ALsKojjl1QtJ3d6Bd895j9Hjxtd8PcMwmh31LprOuQ1A3/q+b0BMn65jKTp0UJ8kaIQ8Px9WrCAzujvDZlzJ2p1teX/gY5x++8kNu17DMOodqz33Jy1NSyIj/Sp4IiMhNJT0Tkcy6MMb+HFnG2ac9xanP3qyWZWG0QJpuXma1ZGUBCtXapmjz9IsLGSLdGXIkqlsLmjHzNkwePCFDbtOwzAaDBNNf8aMgaVLYf16TTwH0na1Zsi259hOAp89uZ6T5r8Gr1bTyMMwjBaBbc/9SU6G++7Tap2SEjbmtWNgxjvsCOnI7Gc3ctLcezSq7l92ae3cDKNFYZZmVZKT4emnWffRGob8qRP5RSF8MeF5jv56ec2NPMzaNIwWg1ma1bD6g7UMPLcjhcUhzJswjaOjVmn5ZGFh5QOtnZthtDhatqVZTWu4FZ5khv6pCyLlpFw8jcM7ZAIJWm++fLk21/Bh7dwMo8XRckUzNRXuuEPzMouKYOVKls/dxbAfHibClTB3wisc0iGr4vh+/eCLL9SX6V8mae3cDKNF0XJF8+mnNUoeGgq5uSzJ6cUpuX+ndWQ2cy97g4NkA+BXfx4ZCcOGqS/TyiQNo8XSckVz7lxt97Z7NwvlRE4tnE4byWZe2JkcMOBOmOFtEGLNNwzD8KNlBoJSU3VbvmcPC2QgpxR8QAcyWBB3JgeEb9VGwSNHwvffw5tv6vPIkSaYhmG0UEtz+nTo0oW5P3blzJL/kuTZzBfhp9G5cBsceqgGfDZsgL59tbdmTg7MmAEHH2zCaRgtnJZpaaal8VnS5Zxe8j4Hys+kRI6gc2iG+jd799ZWcL6cTI+n4rVvuJphGC2WFmlpflw4jLM/O4dD4zYzO3wc7V0ehMVoWlFICMTHqy/TH8vJNAyDFmhpvv8+jHl3PMkJacw991nan3US9OihPTNPOkmDPf36VT8Kw3IyDaPF06IszbffhgsugGOO8fDp1D3EzY6AtBIdY1G1+cZU7/QNy8k0DMOPFiOar70GEybAiSfCzJkQE9MH/tCn+oNrGoVhQSDDaPG0CNF88cX/b+/eY7Oq7ziOvz9iAacY5yCEbW4oW1S2OGTVuHnZxbk48AKRMEacE+cQAsKIGCHIvCQkyJx4YWrAcRlz6pQZSueYykTjTBREKBcF8bJshMmM04GTDtrv/vj9CofyPH36tGvP76nfV9L0POc5p+fDj/bbc+n5Hrj66tC8aMWK8Oy0kvzRFc65Arr8Oc377w87iRdcALW1rSyYzjlXRJcumnffDePHw9ChsHx5uNbjnHPt0WWL5u23w+TJMHx4ODWZfeyPc861VZcsmrNmwfXXw8iR4Yp59+55J3LOdRVdqmiawU03wY03wuWXw4MPQlVV3qmcc11Jl7l6bgbTp8Ntt8GYMbBgQbi5xznn/p+6xJ6mGVx3XSiY48bBAw94wXTOdYyKL5qNjXDttTB3LkyaBPfeG3psOOdcR6jow/PGRrjmmrBnOXUqzJkDUt6pnHNdWcXukzU0wFVXhYI5Y4YXTOdc56jIPc39++GKK0JT9VtvhZkz807knPu4yGVPU9KFkrZK2i5pWjnr7tsHo0aFgjl7thdM51zn6vSiKakb8Evge8BA4AeSBrZm3fp6GDECli2DO+6AG27oyKTOOXe4PPY0zwS2m9mbZvZf4GHg0lIrffRRuCWypgbmzYMpUzo8p3POHSaPovkZ4G+Z13+P84pqbAwPg1y5EubPhwkTOjSfc84VleyFIEljgbEAPXqcxr59sGhRaCTsnHN5yWNPcwdwQub1Z+O8Q5jZfDOrNrPq+voqli71gumcy5/MrHM3KB0JbAPOJxTLNcBoM9vcwjr/BP4K9Abe7YycbZByNkg7X8rZwPO1R8rZAE42s17lrNDph+dmtl/SROBPQDdgYUsFM67TB0DSWjOr7oSYZUs5G6SdL+Vs4PnaI+VsEPKVu04u5zTN7AngiTy27Zxz7VGxt1E651weKq1ozs87QAtSzgZp50s5G3i+9kg5G7QhX6dfCHLOuUpWaXuazjmXq4oomu1p8NEZJL0taaOk9W25GtcBeRZK2iVpU2be8ZKekvR6/PzJhLLdLGlHHL/1kobklO0ESc9I2iJps6TJcX4qY1csXyrj11PSS5I2xHy3xPknSnox/vw+IqnTH3XYQrbFkt7KjN2gkl/MzJL+IPxZ0hvASUB3YAMwMO9czTK+DfTOO0cmz3nAYGBTZt4cYFqcngbcllC2m4GpCYxbP2BwnO5F+HvigQmNXbF8qYyfgGPidBXwInAW8DtgVJx/PzA+oWyLgRHlfK1K2NNsU4OPjzMzew54r9nsS4ElcXoJMKxTQ0VFsiXBzHaa2bo4vRt4ldAXIZWxK5YvCRbsiS+r4ocB3wYei/NzGb8WspWtEopm2Q0+cmDAk5JejvfMp6ivme2M0/8A+uYZpoCJkuri4Xsuh79ZkvoDpxP2SJIbu2b5IJHxk9RN0npgF/AU4SjxfTPbHxfJ7ee3eTYzaxq7WXHs5krqUerrVELRrATnmNlgQo/QCZLOyztQSywco6T0ZxP3AQOAQcBO4Bd5hpF0DLAM+KmZ/Tv7XgpjVyBfMuNnZg1mNojQU+JM4JS8sjTXPJukLwPTCRnPAI4HSnbprYSi2aoGH3kysx3x8y7gccI3S2rekdQPIH7elXOeA8zsnfgN3QgsIMfxk1RFKEgPmtnv4+xkxq5QvpTGr4mZvQ88A3wNOC72nIAEfn4z2S6MpzzMzOqBRbRi7CqhaK4BvhivwHUHRgE1OWc6QNLRkno1TQPfBTa1vFYuaoCmPlE/ApbnmOUQTQUpGk5O4ydJwK+AV83sjsxbSYxdsXwJjV8fScfF6aOACwjnXZ8BRsTFchm/Itley/wyFOFca+mxy/NqWxlXvoYQrhS+AczIO0+zbCcRruhvADankA94iHCYto9wDunHwKeAVcDrwNPA8QllWwpsBOoIBapfTtnOIRx61wHr48eQhMauWL5Uxu804JWYYxPwszj/JOAlYDvwKNAjoWx/jmO3CfgN8Qp7Sx9+R5BzzpWhEg7PnXMuGV40nXOuDF40nXOuDF40nXOuDF40nXOuDF40XbJi956pBeYPkzSwDV+vv6TRmddXSprX3pwFtrNaUrLPxXHt40XTtUvmTo/ONIzQ3ecwJfL0B0a38L5zJXnRdEVJmhn7mD4v6aGmvb64J3Vn7B06WdL5kl5R6Cm6sKnpgUKf0d5xulrS6jh9c1xutaQ3JU3KbHOGpG2SngdOLpDp68AlwM9j/8MBBfIsljQis05Td5vZwLlxvSlx3qclrVTolTmnwPYulPRo5vU3JdXG6fskrc32Zyyw/p7M9AhJi+N0H0nLJK2JH2e3/L/hUpHL0yhd+iSdAVwGfIXQRmsd8HJmke5mVi2pJ+FOmfPNbJukXwPjgTtLbOIU4FuEvpBbJd1HuGtjFKHxxJEFtomZvSCpBqg1s8di1gN54uvFRbY5jdB38qK43JVxW6cD9THHPWaW7ar1NDBf0tFm9iHwfUJ7Qgh3f70nqRuwStJpZlZX4t/d5C5grpk9L+lzhEdan9rKdV2OfE/TFXM2sNzM9lro3bii2fuPxM8nA2+Z2bb4egmh0XApfzCzejN7l9AAoy9wLvC4mf3HQveecnoMPFJ6kYJWmdkHZrYX2AJ8PvumhZZmK4GL46H/UA7eOz1S0jrC7XlfosgpgyK+A8yLrcpqgGNj9yKXON/TdG31YSuW2c/BX8w9m71Xn5luoP3fi9k8B7Yr6QhCx/9iWpPjYWAioXnyWjPbLelEYCpwhpn9K+7dNv83wqFt5LLvHwGcFYu1qyC+p+mK+Qth76pn3AO6qMhyW4H+kr4QX/8QeDZOvw18NU5f1optPgcMk3RU7Bx1cZHldhMO64vJbvcSwumF1qxXzLOER3T8hIOH5scSCvUHkvoSeqkW8o6kU2PxHp6Z/yRwbdMLtebZNC4JXjRdQWa2hnDYWAf8kdAJ5oMCy+0FxgCPStoINBKeAwNwC3BXvEDT0IptriMcZm+I21xTZNGHgevjxacBBd5fAHxD0gZCP8emvdA6oEHh4VpTCqxXLFcDUEsojLVx3gbCYflrwG8Jv2QKmRbXeYHQ3anJJKBaoWP4FmBca/O4fHmXI1eUpGPMbI+kTxD2AsfGwubcx5af03QtmR//iLwnsMQLpnO+p+mcc2Xxc5rOOVcGL5rOOVcGL5rOOVcGL5rOOVcGL5rOOVcGL5rOOVeG/wEXBGnsFCkaiAAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [], + "needs_background": "light" + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "aQikz3IPiyPf" + }, + "source": [ + "# **Testing**\n", + "The predictions of your model on testing set will be stored at `pred.csv`." + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "O8cTuQjQQOon", + "outputId": "6bc5de07-4c5a-4e87-9ae3-d09f539c5f2c" + }, + "source": [ + "def save_pred(preds, file):\n", + " ''' Save predictions to specified file '''\n", + " print('Saving results to {}'.format(file))\n", + " with open(file, 'w') as fp:\n", + " writer = csv.writer(fp)\n", + " writer.writerow(['id', 'tested_positive'])\n", + " for i, p in enumerate(preds):\n", + " writer.writerow([i, p])\n", + "\n", + "preds = test(tt_set, model, device) # predict COVID-19 cases with your model\n", + "save_pred(preds, 'pred.csv') # save prediction file to pred.csv" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Saving results to pred.csv\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nfrVxqJanGpE" + }, + "source": [ + "# **Hints**\n", + "\n", + "## **Simple Baseline**\n", + "* Run sample code\n", + "\n", + "## **Medium Baseline**\n", + "* Feature selection: 40 states + 2 `tested_positive` (`TODO` in dataset)\n", + "\n", + "## **Strong Baseline**\n", + "* Feature selection (what other features are useful?)\n", + "* DNN architecture (layers? dimension? activation function?)\n", + "* Training (mini-batch? optimizer? learning rate?)\n", + "* L2 regularization\n", + "* There are some mistakes in the sample code, can you find them?" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9tmCwXgpot3t" + }, + "source": [ + "# **Reference**\n", + "This code is completely written by Heng-Jui Chang @ NTUEE. \n", + "Copying or reusing this code is required to specify the original author. \n", + "\n", + "E.g. \n", + "Source: Heng-Jui Chang @ NTUEE (https://github.com/ga642381/ML2021-Spring/blob/main/HW01/HW01.ipynb)\n" + ] + } + ] +} \ No newline at end of file diff --git a/01 Introduction/Pytorch_Tutorial.ipynb b/01 Introduction/Pytorch_Tutorial.ipynb new file mode 100644 index 0000000..80e1be8 --- /dev/null +++ b/01 Introduction/Pytorch_Tutorial.ipynb @@ -0,0 +1,614 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Pytorch Tutorial", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "tHILOGjOQbsQ" + }, + "source": [ + "# **Pytorch Tutorial**\r\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "C1zA7GupxdJv" + }, + "source": [ + "import torch" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6Eqj90EkWbWx" + }, + "source": [ + "**1. Pytorch Documentation Explanation with torch.max**\r\n", + "\r\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "JCXOg-iSQuk7" + }, + "source": [ + "x = torch.randn(4,5)\r\n", + "y = torch.randn(4,5)\r\n", + "z = torch.randn(4,5)\r\n", + "print(x)\r\n", + "print(y)\r\n", + "print(z)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "EEqa9GFoWF78" + }, + "source": [ + "# 1. max of entire tensor (torch.max(input) → Tensor)\r\n", + "m = torch.max(x)\r\n", + "print(m)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "wffThGDyWKxJ" + }, + "source": [ + "# 2. max along a dimension (torch.max(input, dim, keepdim=False, *, out=None) → (Tensor, LongTensor))\r\n", + "m, idx = torch.max(x,0)\r\n", + "print(m)\r\n", + "print(idx)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "oKDQW3tIXKg-" + }, + "source": [ + "# 2-2\r\n", + "m, idx = torch.max(input=x,dim=0)\r\n", + "print(m)\r\n", + "print(idx)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "6QZ6WRLyX3De" + }, + "source": [ + "# 2-3\r\n", + "m, idx = torch.max(x,0,False)\r\n", + "print(m)\r\n", + "print(idx)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "nqGuctkKbUEn" + }, + "source": [ + "# 2-4\r\n", + "m, idx = torch.max(x,dim=0,keepdim=True)\r\n", + "print(m)\r\n", + "print(idx)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "9OMzxuMlZPIu" + }, + "source": [ + "# 2-5\r\n", + "p = (m,idx)\r\n", + "torch.max(x,0,False,out=p)\r\n", + "print(p[0])\r\n", + "print(p[1])\r\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "uhd4TqGTbD2c" + }, + "source": [ + "# 2-6\r\n", + "p = (m,idx)\r\n", + "torch.max(x,0,False,p)\r\n", + "print(p[0])\r\n", + "print(p[1])" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "wbxjUSOXxN0n" + }, + "source": [ + "# 2-7\r\n", + "m, idx = torch.max(x,True)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "iMwhGLlGWYaR" + }, + "source": [ + "# 3. max(choose max) operators on two tensors (torch.max(input, other, *, out=None) → Tensor)\r\n", + "t = torch.max(x,y)\r\n", + "print(t)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nFxRKu2Dedwb" + }, + "source": [ + "**2. Common errors**\r\n", + "\r\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KMcRyMxGwhul" + }, + "source": [ + "The following code blocks show some common errors while using the torch library. First, execute the code with error, and then execute the next code block to fix the error. You need to change the runtime to GPU.\r\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "eX-kKdi6ynFf" + }, + "source": [ + "import torch" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "-muJ4KKreoP2", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 363 + }, + "outputId": "c1d5c3a5-9540-4145-d80c-3cbca18a1deb" + }, + "source": [ + "# 1. different device error\r\n", + "model = torch.nn.Linear(5,1).to(\"cuda:0\")\r\n", + "x = torch.Tensor([1,2,3,4,5]).to(\"cpu\")\r\n", + "y = model(x)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "error", + "ename": "RuntimeError", + "evalue": "ignored", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"cuda:0\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"cpu\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/linear.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 91\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 93\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mlinear\u001b[0;34m(input, weight, bias)\u001b[0m\n\u001b[1;32m 1690\u001b[0m \u001b[0mret\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maddmm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1691\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1692\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1693\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mbias\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1694\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: Tensor for 'out' is on CPU, Tensor for argument #1 'self' is on CPU, but expected them to be on GPU (while checking arguments for addmm)" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "a54PqxJLe9-c", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "909d3693-236f-4419-f269-8fb443ef7534" + }, + "source": [ + "# 1. different device error (fixed)\r\n", + "x = torch.Tensor([1,2,3,4,5]).to(\"cuda:0\")\r\n", + "y = model(x)\r\n", + "print(y.shape)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "torch.Size([1])\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "n7OHtZwbi7Qw", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 201 + }, + "outputId": "2a7d2dd0-6498-4da0-9591-3554c1739046" + }, + "source": [ + "# 2. mismatched dimensions error\r\n", + "x = torch.randn(4,5)\r\n", + "y= torch.randn(5,4)\r\n", + "z = x + y" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "error", + "ename": "RuntimeError", + "evalue": "ignored", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mz\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (5) must match the size of tensor b (4) at non-singleton dimension 1" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "qVynzvrskFCD", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "926dc01c-be6f-48e1-ad39-a5bcecebc513" + }, + "source": [ + "# 2. mismatched dimensions error (fixed)\r\n", + "y= y.transpose(0,1)\r\n", + "z = x + y\r\n", + "print(z.shape)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "torch.Size([4, 5])\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Hgzgb9gJANod", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 398 + }, + "outputId": "21b58850-b3f1-4f2a-db5d-cc45e47ccbea" + }, + "source": [ + "# 3. cuda out of memory error\n", + "import torch\n", + "import torchvision.models as models\n", + "resnet18 = models.resnet18().to(\"cuda:0\") # Neural Networks for Image Recognition\n", + "data = torch.randn(2048,3,244,244) # Create fake data (512 images)\n", + "out = resnet18(data.to(\"cuda:0\")) # Use Data as Input and Feed to Model\n", + "print(out.shape)\n" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "error", + "ename": "RuntimeError", + "evalue": "ignored", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mresnet18\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresnet18\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"cuda:0\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Neural Networks for Image Recognition\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2048\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m244\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m244\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Create fake data (512 images)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresnet18\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"cuda:0\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Use Data as Input and Feed to Model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torchvision/models/resnet.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 218\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 219\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 220\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 221\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 222\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torchvision/models/resnet.py\u001b[0m in \u001b[0;36m_forward_impl\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 202\u001b[0m \u001b[0;31m# See note [TorchScript super()]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 204\u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbn1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 205\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 206\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmaxpool\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/batchnorm.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrunning_mean\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrack_running_stats\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 135\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrunning_var\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrack_running_stats\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 136\u001b[0;31m self.weight, self.bias, bn_training, exponential_average_factor, self.eps)\n\u001b[0m\u001b[1;32m 137\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mbatch_norm\u001b[0;34m(input, running_mean, running_var, weight, bias, training, momentum, eps)\u001b[0m\n\u001b[1;32m 2056\u001b[0m return torch.batch_norm(\n\u001b[1;32m 2057\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrunning_mean\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrunning_var\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2058\u001b[0;31m \u001b[0mtraining\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmomentum\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackends\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcudnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menabled\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2059\u001b[0m )\n\u001b[1;32m 2060\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 7.27 GiB (GPU 0; 14.76 GiB total capacity; 8.74 GiB already allocated; 4.42 GiB free; 9.42 GiB reserved in total by PyTorch)" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "VPksKnB_w343", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "fbee46ad-e63e-4bfc-8971-452895dd7a15" + }, + "source": [ + "# 3. cuda out of memory error (fixed)\n", + "for d in data:\n", + " out = resnet18(d.to(\"cuda:0\").unsqueeze(0))\n", + "print(out.shape)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "torch.Size([1, 1000])\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "vqszlxEE0Bk0", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 346 + }, + "outputId": "a698b34d-00a8-4067-ddc5-180cb4c8eeaa" + }, + "source": [ + "# 4. mismatched tensor type\n", + "import torch.nn as nn\n", + "L = nn.CrossEntropyLoss()\n", + "outs = torch.randn(5,5)\n", + "labels = torch.Tensor([1,2,3,4,0])\n", + "lossval = L(outs,labels) # Calculate CrossEntropyLoss between outs and labels" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "error", + "ename": "RuntimeError", + "evalue": "ignored", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mlossval\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mL\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mouts\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Calculate CrossEntropyLoss between outs and labels\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/modules/loss.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input, target)\u001b[0m\n\u001b[1;32m 960\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 961\u001b[0m return F.cross_entropy(input, target, weight=self.weight,\n\u001b[0;32m--> 962\u001b[0;31m ignore_index=self.ignore_index, reduction=self.reduction)\n\u001b[0m\u001b[1;32m 963\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 964\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mcross_entropy\u001b[0;34m(input, target, weight, size_average, ignore_index, reduce, reduction)\u001b[0m\n\u001b[1;32m 2466\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0msize_average\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mreduce\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2467\u001b[0m \u001b[0mreduction\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_Reduction\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlegacy_get_string\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msize_average\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2468\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mnll_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlog_softmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mignore_index\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2469\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2470\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mnll_loss\u001b[0;34m(input, target, weight, size_average, ignore_index, reduce, reduction)\u001b[0m\n\u001b[1;32m 2262\u001b[0m .format(input.size(0), target.size(0)))\n\u001b[1;32m 2263\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdim\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2264\u001b[0;31m \u001b[0mret\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnll_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_Reduction\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_enum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreduction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mignore_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2265\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mdim\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2266\u001b[0m \u001b[0mret\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnll_loss2d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_Reduction\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_enum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreduction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mignore_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: expected scalar type Long but found Float" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "CZwgwup_1dgS", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "aaf1de76-7ef2-4ca4-b87d-8482a3117249" + }, + "source": [ + "# 4. mismatched tensor type (fixed)\n", + "labels = labels.long()\n", + "lossval = L(outs,labels)\n", + "print(lossval)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "tensor(2.6215)\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dSuNdA8F06dK" + }, + "source": [ + "**3. More on dataset and dataloader**\r\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "in84z_xu1rE6" + }, + "source": [ + "A dataset is a cluster of data in a organized way. A dataloader is a loader which can iterate through the data set." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "34zfh-c22Qqs" + }, + "source": [ + "Let a dataset be the English alphabets \"abcdefghijklmnopqrstuvwxyz\"" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "TaiHofty1qKA" + }, + "source": [ + "dataset = \"abcdefghijklmnopqrstuvwxyz\"" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "h0jwhVa12h3a" + }, + "source": [ + "A simple dataloader could be implemented with the python code \"for\"" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "bWC5Wwbv2egy" + }, + "source": [ + "for datapoint in dataset:\r\n", + " print(datapoint)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "n33VKzkG2y2U" + }, + "source": [ + "When using the dataloader, we often like to shuffle the data. This is where torch.utils.data.DataLoader comes in handy. If each data is an index (0,1,2...) from the view of torch.utils.data.DataLoader, shuffling can simply be done by shuffling an index array. \r\n", + "\r\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9MXUUKQ65APf" + }, + "source": [ + "torch.utils.data.DataLoader will need two imformation to fulfill its role. First, it needs to know the length of the data. Second, once torch.utils.data.DataLoader outputs the index of the shuffling results, the dataset needs to return the corresponding data." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BV5txsjK5j4j" + }, + "source": [ + "Therefore, torch.utils.data.Dataset provides the imformation by two functions, `__len__()` and `__getitem__()` to support torch.utils.data.Dataloader" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "A0IEkemJ5ajD" + }, + "source": [ + "import torch\r\n", + "import torch.utils.data \r\n", + "class ExampleDataset(torch.utils.data.Dataset):\r\n", + " def __init__(self):\r\n", + " self.data = \"abcdefghijklmnopqrstuvwxyz\"\r\n", + " \r\n", + " def __getitem__(self,idx): # if the index is idx, what will be the data?\r\n", + " return self.data[idx]\r\n", + " \r\n", + " def __len__(self): # What is the length of the dataset\r\n", + " return len(self.data)\r\n", + "\r\n", + "dataset1 = ExampleDataset() # create the dataset\r\n", + "dataloader = torch.utils.data.DataLoader(dataset = dataset1,shuffle = True,batch_size = 1)\r\n", + "for datapoint in dataloader:\r\n", + " print(datapoint)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nTt-ZTid9S2n" + }, + "source": [ + "A simple data augmentation technique can be done by changing the code in `__len__()` and `__getitem__()`. Suppose we want to double the length of the dataset by adding in the uppercase letters, using only the lowercase dataset, you can change the dataset to the following." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "7Wn3BA2j-NXl" + }, + "source": [ + "import torch.utils.data \r\n", + "class ExampleDataset(torch.utils.data.Dataset):\r\n", + " def __init__(self):\r\n", + " self.data = \"abcdefghijklmnopqrstuvwxyz\"\r\n", + " \r\n", + " def __getitem__(self,idx): # if the index is idx, what will be the data?\r\n", + " if idx >= len(self.data): # if the index >= 26, return upper case letter\r\n", + " return self.data[idx%26].upper()\r\n", + " else: # if the index < 26, return lower case, return lower case letter\r\n", + " return self.data[idx]\r\n", + " \r\n", + " def __len__(self): # What is the length of the dataset\r\n", + " return 2 * len(self.data) # The length is now twice as large\r\n", + "\r\n", + "dataset1 = ExampleDataset() # create the dataset\r\n", + "dataloader = torch.utils.data.DataLoader(dataset = dataset1,shuffle = True,batch_size = 1)\r\n", + "for datapoint in dataloader:\r\n", + " print(datapoint)" + ], + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file