|
|
@@ -30,8 +30,11 @@ |
|
|
|
"id": "eNSV4QGHS1I1" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"# **Homework 2-2 Hessian Matrix**\r\n", |
|
|
|
"\r\n" |
|
|
|
"# **Homework 2-2 Hessian Matrix**\n", |
|
|
|
"\n", |
|
|
|
"* Slides: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/hw/HW02/HW02.pdf\n", |
|
|
|
"* Video (Chinese): https://youtu.be/PdjXnQbu2zo\n", |
|
|
|
"* Video (English): https://youtu.be/ESRr-VCykBs\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
@@ -171,7 +174,7 @@ |
|
|
|
"id": "ZFGBCIFmVLS_" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"### Import Libraries\r\n" |
|
|
|
"### Import Libraries\n" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
@@ -180,16 +183,16 @@ |
|
|
|
"id": "_-vjBvH0uqA-" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"import numpy as np\r\n", |
|
|
|
"from math import pi\r\n", |
|
|
|
"from collections import defaultdict\r\n", |
|
|
|
"from autograd_lib import autograd_lib\r\n", |
|
|
|
"\r\n", |
|
|
|
"import torch\r\n", |
|
|
|
"import torch.nn as nn\r\n", |
|
|
|
"from torch.utils.data import DataLoader, Dataset\r\n", |
|
|
|
"\r\n", |
|
|
|
"import warnings\r\n", |
|
|
|
"import numpy as np\n", |
|
|
|
"from math import pi\n", |
|
|
|
"from collections import defaultdict\n", |
|
|
|
"from autograd_lib import autograd_lib\n", |
|
|
|
"\n", |
|
|
|
"import torch\n", |
|
|
|
"import torch.nn as nn\n", |
|
|
|
"from torch.utils.data import DataLoader, Dataset\n", |
|
|
|
"\n", |
|
|
|
"import warnings\n", |
|
|
|
"warnings.filterwarnings(\"ignore\")" |
|
|
|
], |
|
|
|
"execution_count": null, |
|
|
@@ -212,17 +215,17 @@ |
|
|
|
"id": "uvdOpR9lVaJQ" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"class MathRegressor(nn.Module):\r\n", |
|
|
|
" def __init__(self, num_hidden=128):\r\n", |
|
|
|
" super().__init__()\r\n", |
|
|
|
" self.regressor = nn.Sequential(\r\n", |
|
|
|
" nn.Linear(1, num_hidden),\r\n", |
|
|
|
" nn.ReLU(),\r\n", |
|
|
|
" nn.Linear(num_hidden, 1)\r\n", |
|
|
|
" )\r\n", |
|
|
|
"\r\n", |
|
|
|
" def forward(self, x):\r\n", |
|
|
|
" x = self.regressor(x)\r\n", |
|
|
|
"class MathRegressor(nn.Module):\n", |
|
|
|
" def __init__(self, num_hidden=128):\n", |
|
|
|
" super().__init__()\n", |
|
|
|
" self.regressor = nn.Sequential(\n", |
|
|
|
" nn.Linear(1, num_hidden),\n", |
|
|
|
" nn.ReLU(),\n", |
|
|
|
" nn.Linear(num_hidden, 1)\n", |
|
|
|
" )\n", |
|
|
|
"\n", |
|
|
|
" def forward(self, x):\n", |
|
|
|
" x = self.regressor(x)\n", |
|
|
|
" return x" |
|
|
|
], |
|
|
|
"execution_count": null, |
|
|
@@ -297,12 +300,12 @@ |
|
|
|
"id": "OSU8vnXEbY6q" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"# load checkpoint and data corresponding to the key\r\n", |
|
|
|
"model = MathRegressor()\r\n", |
|
|
|
"autograd_lib.register(model)\r\n", |
|
|
|
"\r\n", |
|
|
|
"data = torch.load('data.pth')[key]\r\n", |
|
|
|
"model.load_state_dict(data['model'])\r\n", |
|
|
|
"# load checkpoint and data corresponding to the key\n", |
|
|
|
"model = MathRegressor()\n", |
|
|
|
"autograd_lib.register(model)\n", |
|
|
|
"\n", |
|
|
|
"data = torch.load('data.pth')[key]\n", |
|
|
|
"model.load_state_dict(data['model'])\n", |
|
|
|
"train, target = data['data']" |
|
|
|
], |
|
|
|
"execution_count": null, |
|
|
@@ -490,13 +493,13 @@ |
|
|
|
"id": "1X-2uxwTcB9u" |
|
|
|
}, |
|
|
|
"source": [ |
|
|
|
"# the main function to compute gradient norm and minimum ratio\r\n", |
|
|
|
"def main(model, train, target):\r\n", |
|
|
|
" criterion = nn.MSELoss()\r\n", |
|
|
|
"\r\n", |
|
|
|
" gradient_norm = compute_gradient_norm(model, criterion, train, target)\r\n", |
|
|
|
" minimum_ratio = compute_minimum_ratio(model, criterion, train, target)\r\n", |
|
|
|
"\r\n", |
|
|
|
"# the main function to compute gradient norm and minimum ratio\n", |
|
|
|
"def main(model, train, target):\n", |
|
|
|
" criterion = nn.MSELoss()\n", |
|
|
|
"\n", |
|
|
|
" gradient_norm = compute_gradient_norm(model, criterion, train, target)\n", |
|
|
|
" minimum_ratio = compute_minimum_ratio(model, criterion, train, target)\n", |
|
|
|
"\n", |
|
|
|
" print('gradient norm: {}, minimum ratio: {}'.format(gradient_norm, minimum_ratio))" |
|
|
|
], |
|
|
|
"execution_count": null, |
|
|
|