
[MNT] change LeNet5, finetune MNISTAdd parameters

pull/1/head
Gao Enhao committed 1 year ago · commit 4a1cf67b6e
3 changed files with 27 additions and 33 deletions:

  1. examples/mnist_add/main.py (+4, -4)
  2. examples/mnist_add/mnist_add.ipynb (+5, -5)
  3. examples/mnist_add/models/nn.py (+18, -24)

examples/mnist_add/main.py (+4, -4)

```diff
@@ -42,14 +42,14 @@ def main():
         help="number of epochs in each learning loop iteration (default : 1)",
     )
     parser.add_argument(
-        "--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)"
+        "--lr", type=float, default=3e-4, help="base model learning rate (default : 0.0003)"
     )
     parser.add_argument("--alpha", type=float, default=0.9, help="alpha in RMSprop (default : 0.9)")
     parser.add_argument(
         "--batch-size", type=int, default=32, help="base model batch size (default : 32)"
     )
     parser.add_argument(
-        "--loops", type=int, default=1, help="number of loop iterations (default : 1)"
+        "--loops", type=int, default=2, help="number of loop iterations (default : 2)"
     )
     parser.add_argument(
         "--segment_size", type=int or float, default=0.01, help="segment size (default : 0.01)"
@@ -84,14 +84,14 @@ def main():
     ### Building the Learning Part
     # Build necessary components for BasicNN
     cls = LeNet5(num_classes=10)
-    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
+    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2)
    optimizer = RMSprop(cls.parameters(), lr=args.lr, alpha=args.alpha)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    scheduler = lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=args.lr,
-        pct_start=0.2,
+        pct_start=0.15,
        epochs=args.loops,
        steps_per_epoch=int(1 / args.segment_size),
    )
```
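For context, a minimal sketch of the retuned schedule (not part of the commit; the `nn.Linear` stand-in and the printed checkpoints are illustrative assumptions). With the new defaults, `--loops 2` and `--segment_size 0.01` give OneCycleLR 2 × 100 = 200 steps, the first 15% (30 steps) spent warming up to the 3e-4 peak:

```python
from torch import nn
from torch.optim import RMSprop, lr_scheduler

model = nn.Linear(10, 10)  # illustrative stand-in for LeNet5
optimizer = RMSprop(model.parameters(), lr=3e-4, alpha=0.9)
scheduler = lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=3e-4,                    # new --lr default
    pct_start=0.15,                 # new warmup fraction
    epochs=2,                       # new --loops default
    steps_per_epoch=int(1 / 0.01),  # int(1 / args.segment_size) = 100
)
for step in range(2 * 100):
    optimizer.step()
    scheduler.step()
    if step in (0, 29, 199):  # start, end of warmup, end of schedule
        print(step, scheduler.get_last_lr())
```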


examples/mnist_add/mnist_add.ipynb (+5, -5)

```diff
@@ -178,10 +178,10 @@
    "outputs": [],
    "source": [
     "cls = LeNet5(num_classes=10)\n",
-    "loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)\n",
-    "optimizer = RMSprop(cls.parameters(), lr=0.001, alpha=0.9)\n",
+    "loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2)\n",
+    "optimizer = RMSprop(cls.parameters(), lr=0.0003, alpha=0.9)\n",
     "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
-    "scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, pct_start=0.1, total_steps=100)\n",
+    "scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=0.0003, pct_start=0.15, total_steps=200)\n",
     "\n",
     "base_model = BasicNN(\n",
     "    cls,\n",
@@ -434,7 +434,7 @@
    "log_dir = ABLLogger.get_current_instance().log_dir\n",
    "weights_dir = osp.join(log_dir, \"weights\")\n",
    "\n",
-   "bridge.train(train_data, loops=1, segment_size=0.01, save_interval=1, save_dir=weights_dir)\n",
+   "bridge.train(train_data, loops=2, segment_size=0.01, save_interval=1, save_dir=weights_dir)\n",
    "bridge.test(test_data)"
   ]
  }
@@ -455,7 +455,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.8.13"
+  "version": "3.8.18"
  },
  "orig_nbformat": 4,
  "vscode": {
```


examples/mnist_add/models/nn.py (+18, -24)

```diff
@@ -1,34 +1,28 @@
-import numpy as np
-import torch
 from torch import nn


 class LeNet5(nn.Module):
-    def __init__(self, num_classes=10, image_size=(28, 28)):
+    def __init__(self, num_classes=10, image_size=(28, 28, 1)):
         super(LeNet5, self).__init__()
-        self.conv1 = nn.Sequential(
-            nn.Conv2d(1, 6, 3, padding=1),
-            nn.ReLU(),
-            nn.MaxPool2d(kernel_size=2, stride=2),
+        self.size = 16 * ((image_size[0] // 2 - 6) // 2) * ((image_size[1] // 2 - 6) // 2)
+        self.encoder = nn.Sequential(
+            nn.Conv2d(1, 6, 5),
+            nn.MaxPool2d(2, 2),  # 6 24 24 -> 6 12 12
+            nn.ReLU(True),
+            nn.Conv2d(6, 16, 5),  # 6 12 12 -> 16 8 8
+            nn.MaxPool2d(2, 2),  # 16 8 8 -> 16 4 4
+            nn.ReLU(True),
         )
-        self.conv2 = nn.Sequential(
-            nn.Conv2d(6, 16, 3), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2)
+        self.classifier = nn.Sequential(
+            nn.Linear(self.size, 120),
+            nn.ReLU(),
+            nn.Linear(120, 84),
+            nn.ReLU(),
+            nn.Linear(84, num_classes),
         )
-        self.conv3 = nn.Sequential(nn.Conv2d(16, 16, 3), nn.ReLU())
-
-        feature_map_size = (np.array(image_size) // 2 - 2) // 2 - 2
-        num_features = 16 * feature_map_size[0] * feature_map_size[1]
-
-        self.fc1 = nn.Sequential(nn.Linear(num_features, 120), nn.ReLU())
-        self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU())
-        self.fc3 = nn.Linear(84, num_classes)

     def forward(self, x):
-        x = self.conv1(x)
-        x = self.conv2(x)
-        x = self.conv3(x)
-        x = torch.flatten(x, 1)
-        x = self.fc1(x)
-        x = self.fc2(x)
-        x = self.fc3(x)
+        x = self.encoder(x)
+        x = x.view(-1, self.size)
+        x = self.classifier(x)
         return x
```
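A quick shape check of the rewritten model (illustrative, not part of the commit; assumes the `LeNet5` above is in scope and inputs are 1×28×28 MNIST digits):

```python
import torch

# 28 -> conv(5) -> 24 -> pool(2) -> 12 -> conv(5) -> 8 -> pool(2) -> 4
model = LeNet5(num_classes=10)
assert model.size == 16 * 4 * 4  # 256, matching the closed form in __init__
out = model(torch.randn(32, 1, 28, 28))  # a batch of 32 single-channel digits
print(out.shape)  # torch.Size([32, 10])
```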
