From 4a1cf67b6e6ef8f29c541cfc92b476973e5cc90c Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Tue, 26 Dec 2023 23:35:47 +0800 Subject: [PATCH] [MNT] change LeNet5, finetune MNISTAdd parameters --- examples/mnist_add/main.py | 8 +++--- examples/mnist_add/mnist_add.ipynb | 10 +++---- examples/mnist_add/models/nn.py | 42 +++++++++++++----------------- 3 files changed, 27 insertions(+), 33 deletions(-) diff --git a/examples/mnist_add/main.py b/examples/mnist_add/main.py index fbb1bfc..cc6af7b 100644 --- a/examples/mnist_add/main.py +++ b/examples/mnist_add/main.py @@ -42,14 +42,14 @@ def main(): help="number of epochs in each learning loop iteration (default : 1)", ) parser.add_argument( - "--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)" + "--lr", type=float, default=3e-4, help="base model learning rate (default : 0.0003)" ) parser.add_argument("--alpha", type=float, default=0.9, help="alpha in RMSprop (default : 0.9)") parser.add_argument( "--batch-size", type=int, default=32, help="base model batch size (default : 32)" ) parser.add_argument( - "--loops", type=int, default=1, help="number of loop iterations (default : 1)" + "--loops", type=int, default=2, help="number of loop iterations (default : 2)" ) parser.add_argument( "--segment_size", type=int or float, default=0.01, help="segment size (default : 0.01)" @@ -84,14 +84,14 @@ def main(): ### Building the Learning Part # Build necessary components for BasicNN cls = LeNet5(num_classes=10) - loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1) + loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2) optimizer = RMSprop(cls.parameters(), lr=args.lr, alpha=args.alpha) use_cuda = not args.no_cuda and torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") scheduler = lr_scheduler.OneCycleLR( optimizer, max_lr=args.lr, - pct_start=0.2, + pct_start=0.15, epochs=args.loops, steps_per_epoch=int(1 / args.segment_size), ) diff --git a/examples/mnist_add/mnist_add.ipynb b/examples/mnist_add/mnist_add.ipynb index b43a81b..5e27a61 100644 --- a/examples/mnist_add/mnist_add.ipynb +++ b/examples/mnist_add/mnist_add.ipynb @@ -178,10 +178,10 @@ "outputs": [], "source": [ "cls = LeNet5(num_classes=10)\n", - "loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)\n", - "optimizer = RMSprop(cls.parameters(), lr=0.001, alpha=0.9)\n", + "loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2)\n", + "optimizer = RMSprop(cls.parameters(), lr=0.0003, alpha=0.9)\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", - "scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, pct_start=0.1, total_steps=100)\n", + "scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=0.0003, pct_start=0.15, total_steps=200)\n", "\n", "base_model = BasicNN(\n", " cls,\n", @@ -434,7 +434,7 @@ "log_dir = ABLLogger.get_current_instance().log_dir\n", "weights_dir = osp.join(log_dir, \"weights\")\n", "\n", - "bridge.train(train_data, loops=1, segment_size=0.01, save_interval=1, save_dir=weights_dir)\n", + "bridge.train(train_data, loops=2, segment_size=0.01, save_interval=1, save_dir=weights_dir)\n", "bridge.test(test_data)" ] } @@ -455,7 +455,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.8.18" }, "orig_nbformat": 4, "vscode": { diff --git a/examples/mnist_add/models/nn.py b/examples/mnist_add/models/nn.py index 182eb51..93f5cda 100644 --- a/examples/mnist_add/models/nn.py +++ b/examples/mnist_add/models/nn.py @@ -1,34 +1,28 @@ -import numpy as np -import torch from torch import nn class LeNet5(nn.Module): - def __init__(self, num_classes=10, image_size=(28, 28)): + def __init__(self, num_classes=10, image_size=(28, 28, 1)): super(LeNet5, self).__init__() - self.conv1 = nn.Sequential( - nn.Conv2d(1, 6, 3, padding=1), - nn.ReLU(), - nn.MaxPool2d(kernel_size=2, stride=2), + self.size = 16 * ((image_size[0] // 2 - 6) // 2) * ((image_size[1] // 2 - 6) // 2) + self.encoder = nn.Sequential( + nn.Conv2d(1, 6, 5), + nn.MaxPool2d(2, 2), # 6 24 24 -> 6 12 12 + nn.ReLU(True), + nn.Conv2d(6, 16, 5), # 6 12 12 -> 16 8 8 + nn.MaxPool2d(2, 2), # 16 8 8 -> 16 4 4 + nn.ReLU(True), ) - self.conv2 = nn.Sequential( - nn.Conv2d(6, 16, 3), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2) + self.classifier = nn.Sequential( + nn.Linear(self.size, 120), + nn.ReLU(), + nn.Linear(120, 84), + nn.ReLU(), + nn.Linear(84, num_classes), ) - self.conv3 = nn.Sequential(nn.Conv2d(16, 16, 3), nn.ReLU()) - - feature_map_size = (np.array(image_size) // 2 - 2) // 2 - 2 - num_features = 16 * feature_map_size[0] * feature_map_size[1] - - self.fc1 = nn.Sequential(nn.Linear(num_features, 120), nn.ReLU()) - self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU()) - self.fc3 = nn.Linear(84, num_classes) def forward(self, x): - x = self.conv1(x) - x = self.conv2(x) - x = self.conv3(x) - x = torch.flatten(x, 1) - x = self.fc1(x) - x = self.fc2(x) - x = self.fc3(x) + x = self.encoder(x) + x = x.view(-1, self.size) + x = self.classifier(x) return x