|
- from torch import nn
-
-
class LeNet5(nn.Module):
    """LeNet-5 style convolutional classifier.

    The network is two ``Conv2d(kernel=5) -> MaxPool2d(2) -> ReLU`` stages
    followed by a three-layer fully connected head (120 -> 84 -> classes).

    Args:
        num_classes: Number of output logits produced by the last layer.
        image_size: Input geometry as ``(height, width)`` or
            ``(height, width, channels)``. When the channel entry is
            omitted a single input channel is assumed.

    Raises:
        ValueError: If ``image_size`` is too small for any feature to be
            left after the two convolution/pooling stages.
    """

    def __init__(self, num_classes=10, image_size=(28, 28, 1)):
        super().__init__()

        height, width = image_size[0], image_size[1]
        # Honour the channel entry of image_size instead of hard-coding 1;
        # a two-element image_size keeps the historical single channel.
        in_channels = image_size[2] if len(image_size) > 2 else 1

        # Each stage removes 4 pixels (5x5 conv, no padding) and then halves
        # the extent with floor division (2x2 pool, stride 2). Applied twice
        # this collapses to (n // 2 - 6) // 2 per spatial axis.
        pooled_h = (height // 2 - 6) // 2
        pooled_w = (width // 2 - 6) // 2
        if pooled_h < 1 or pooled_w < 1:
            raise ValueError(
                f"image_size {tuple(image_size)} is too small for LeNet5; "
                "height and width must each be at least 16"
            )

        # Number of features fed to the classifier: 16 maps of pooled size.
        self.size = 16 * pooled_h * pooled_w

        # Shape comments below are for the default 28x28 input.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 6, 5),  # C 28 28 -> 6 24 24
            nn.MaxPool2d(2, 2),            # 6 24 24 -> 6 12 12
            nn.ReLU(True),
            nn.Conv2d(6, 16, 5),           # 6 12 12 -> 16 8 8
            nn.MaxPool2d(2, 2),            # 16 8 8 -> 16 4 4
            nn.ReLU(True),
        )
        self.classifier = nn.Sequential(
            nn.Linear(self.size, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, num_classes),
        )

    def forward(self, x):
        """Run the encoder, flatten to ``self.size`` features, and classify.

        Args:
            x: Input tensor laid out channels-first with the spatial size
                given at construction.

        Returns:
            Tensor of shape ``(rows, num_classes)`` holding the raw logits.
        """
        x = self.encoder(x)
        # -1 lets the leading dimension be inferred, which also covers an
        # unbatched (C, H, W) input by producing a single row.
        x = x.view(-1, self.size)
        x = self.classifier(x)
        return x
|