|
- import torch.nn.functional as F
- import torch.nn as nn
-
-
class MLP(nn.Module):
    """Three-layer fully connected classifier with ReLU activations.

    The default sizes (3072 -> 256 -> 256 -> 10) reproduce the original
    hard-coded architecture. 3072 presumably corresponds to a flattened
    3x32x32 image (e.g. CIFAR-10) — NOTE(review): confirm against the data
    pipeline.

    Args:
        in_features: Size of the last dimension of the input tensor.
        hidden_features: Width of both hidden layers.
        num_classes: Number of output logits produced by ``fc3``.
    """

    def __init__(self, in_features=3072, hidden_features=256, num_classes=10):
        super().__init__()
        # Attribute names are kept as fc1/fc2/fc3 so existing state_dict
        # checkpoints still load.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, hidden_features)
        self.fc3 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Return unnormalized class logits for ``x``.

        No reshaping is done here. The last dimension of ``x`` must already
        equal ``in_features``, so image batches must be flattened by the
        caller. No softmax is applied to the output.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
-
-
def mlp():
    """Construct a fresh ``MLP`` using its default layer sizes.

    Returns:
        A newly created, untrained ``MLP`` instance.
    """
    model = MLP()
    return model
|