@@ -42,14 +42,14 @@ def main(): | |||||
help="number of epochs in each learning loop iteration (default : 1)", | help="number of epochs in each learning loop iteration (default : 1)", | ||||
) | ) | ||||
parser.add_argument( | parser.add_argument( | ||||
"--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)" | |||||
"--lr", type=float, default=3e-4, help="base model learning rate (default : 0.0003)" | |||||
) | ) | ||||
parser.add_argument("--alpha", type=float, default=0.9, help="alpha in RMSprop (default : 0.9)") | parser.add_argument("--alpha", type=float, default=0.9, help="alpha in RMSprop (default : 0.9)") | ||||
parser.add_argument( | parser.add_argument( | ||||
"--batch-size", type=int, default=32, help="base model batch size (default : 32)" | "--batch-size", type=int, default=32, help="base model batch size (default : 32)" | ||||
) | ) | ||||
parser.add_argument( | parser.add_argument( | ||||
"--loops", type=int, default=1, help="number of loop iterations (default : 1)" | |||||
"--loops", type=int, default=2, help="number of loop iterations (default : 2)" | |||||
) | ) | ||||
parser.add_argument( | parser.add_argument( | ||||
"--segment_size", type=int or float, default=0.01, help="segment size (default : 0.01)" | "--segment_size", type=int or float, default=0.01, help="segment size (default : 0.01)" | ||||
@@ -84,14 +84,14 @@ def main(): | |||||
### Building the Learning Part | ### Building the Learning Part | ||||
# Build necessary components for BasicNN | # Build necessary components for BasicNN | ||||
cls = LeNet5(num_classes=10) | cls = LeNet5(num_classes=10) | ||||
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1) | |||||
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2) | |||||
optimizer = RMSprop(cls.parameters(), lr=args.lr, alpha=args.alpha) | optimizer = RMSprop(cls.parameters(), lr=args.lr, alpha=args.alpha) | ||||
use_cuda = not args.no_cuda and torch.cuda.is_available() | use_cuda = not args.no_cuda and torch.cuda.is_available() | ||||
device = torch.device("cuda" if use_cuda else "cpu") | device = torch.device("cuda" if use_cuda else "cpu") | ||||
scheduler = lr_scheduler.OneCycleLR( | scheduler = lr_scheduler.OneCycleLR( | ||||
optimizer, | optimizer, | ||||
max_lr=args.lr, | max_lr=args.lr, | ||||
pct_start=0.2, | |||||
pct_start=0.15, | |||||
epochs=args.loops, | epochs=args.loops, | ||||
steps_per_epoch=int(1 / args.segment_size), | steps_per_epoch=int(1 / args.segment_size), | ||||
) | ) | ||||
@@ -178,10 +178,10 @@ | |||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
"cls = LeNet5(num_classes=10)\n", | "cls = LeNet5(num_classes=10)\n", | ||||
"loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)\n", | |||||
"optimizer = RMSprop(cls.parameters(), lr=0.001, alpha=0.9)\n", | |||||
"loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2)\n", | |||||
"optimizer = RMSprop(cls.parameters(), lr=0.0003, alpha=0.9)\n", | |||||
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", | ||||
"scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, pct_start=0.1, total_steps=100)\n", | |||||
"scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=0.0003, pct_start=0.15, total_steps=200)\n", | |||||
"\n", | "\n", | ||||
"base_model = BasicNN(\n", | "base_model = BasicNN(\n", | ||||
" cls,\n", | " cls,\n", | ||||
@@ -434,7 +434,7 @@ | |||||
"log_dir = ABLLogger.get_current_instance().log_dir\n", | "log_dir = ABLLogger.get_current_instance().log_dir\n", | ||||
"weights_dir = osp.join(log_dir, \"weights\")\n", | "weights_dir = osp.join(log_dir, \"weights\")\n", | ||||
"\n", | "\n", | ||||
"bridge.train(train_data, loops=1, segment_size=0.01, save_interval=1, save_dir=weights_dir)\n", | |||||
"bridge.train(train_data, loops=2, segment_size=0.01, save_interval=1, save_dir=weights_dir)\n", | |||||
"bridge.test(test_data)" | "bridge.test(test_data)" | ||||
] | ] | ||||
} | } | ||||
@@ -455,7 +455,7 @@ | |||||
"name": "python", | "name": "python", | ||||
"nbconvert_exporter": "python", | "nbconvert_exporter": "python", | ||||
"pygments_lexer": "ipython3", | "pygments_lexer": "ipython3", | ||||
"version": "3.8.13" | |||||
"version": "3.8.18" | |||||
}, | }, | ||||
"orig_nbformat": 4, | "orig_nbformat": 4, | ||||
"vscode": { | "vscode": { | ||||
@@ -1,34 +1,28 @@ | |||||
import numpy as np | |||||
import torch | |||||
from torch import nn | from torch import nn | ||||
class LeNet5(nn.Module): | class LeNet5(nn.Module): | ||||
def __init__(self, num_classes=10, image_size=(28, 28)): | |||||
def __init__(self, num_classes=10, image_size=(28, 28, 1)): | |||||
super(LeNet5, self).__init__() | super(LeNet5, self).__init__() | ||||
self.conv1 = nn.Sequential( | |||||
nn.Conv2d(1, 6, 3, padding=1), | |||||
nn.ReLU(), | |||||
nn.MaxPool2d(kernel_size=2, stride=2), | |||||
self.size = 16 * ((image_size[0] // 2 - 6) // 2) * ((image_size[1] // 2 - 6) // 2) | |||||
self.encoder = nn.Sequential( | |||||
nn.Conv2d(1, 6, 5), | |||||
nn.MaxPool2d(2, 2), # 6 24 24 -> 6 12 12 | |||||
nn.ReLU(True), | |||||
nn.Conv2d(6, 16, 5), # 6 12 12 -> 16 8 8 | |||||
nn.MaxPool2d(2, 2), # 16 8 8 -> 16 4 4 | |||||
nn.ReLU(True), | |||||
) | ) | ||||
self.conv2 = nn.Sequential( | |||||
nn.Conv2d(6, 16, 3), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2) | |||||
self.classifier = nn.Sequential( | |||||
nn.Linear(self.size, 120), | |||||
nn.ReLU(), | |||||
nn.Linear(120, 84), | |||||
nn.ReLU(), | |||||
nn.Linear(84, num_classes), | |||||
) | ) | ||||
self.conv3 = nn.Sequential(nn.Conv2d(16, 16, 3), nn.ReLU()) | |||||
feature_map_size = (np.array(image_size) // 2 - 2) // 2 - 2 | |||||
num_features = 16 * feature_map_size[0] * feature_map_size[1] | |||||
self.fc1 = nn.Sequential(nn.Linear(num_features, 120), nn.ReLU()) | |||||
self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU()) | |||||
self.fc3 = nn.Linear(84, num_classes) | |||||
def forward(self, x): | def forward(self, x): | ||||
x = self.conv1(x) | |||||
x = self.conv2(x) | |||||
x = self.conv3(x) | |||||
x = torch.flatten(x, 1) | |||||
x = self.fc1(x) | |||||
x = self.fc2(x) | |||||
x = self.fc3(x) | |||||
x = self.encoder(x) | |||||
x = x.view(-1, self.size) | |||||
x = self.classifier(x) | |||||
return x | return x |