from torch import nn


class ConvModel(nn.Module):
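    """A small ConvNet: ``net_depth`` blocks of conv -> norm -> activation -> pooling,
    followed by a linear head mapping the flattened features to
    ``n_random_features`` outputs.
    """
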
    def __init__(
        self,
        channel,
        n_random_features,
        net_width=64,
        net_depth=3,
        net_act="relu",
        net_norm="batchnorm",
        net_pooling="avgpooling",
        im_size=(32, 32),
    ):
        super().__init__()
        self.features, shape_feat = self._make_layers(
            channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size
        )
        # shape_feat is [channels, height, width] after the conv stack.
        num_feat = shape_feat[0] * shape_feat[1] * shape_feat[2]
        self.classifier = nn.Linear(num_feat, n_random_features)

    def forward(self, x):
        out = self.features(x)
        out = out.reshape(out.size(0), -1)  # flatten to (batch, num_feat)
        out = self.classifier(out)
        return out

    def _get_activation(self, net_act):
        if net_act == "sigmoid":
            return nn.Sigmoid()
        elif net_act == "relu":
            return nn.ReLU(inplace=True)
        elif net_act == "leakyrelu":
            return nn.LeakyReLU(negative_slope=0.01)
        elif net_act == "gelu":
            return nn.GELU()
        else:
            raise ValueError("unknown activation function: %s" % net_act)

    def _get_pooling(self, net_pooling):
        if net_pooling == "maxpooling":
            return nn.MaxPool2d(kernel_size=2, stride=2)
        elif net_pooling == "avgpooling":
            return nn.AvgPool2d(kernel_size=2, stride=2)
        elif net_pooling == "none":
            return None
        else:
            raise ValueError("unknown net_pooling: %s" % net_pooling)

    def _get_normlayer(self, net_norm, shape_feat):
        if net_norm == "batchnorm":
            return nn.BatchNorm2d(shape_feat[0], affine=True)
        elif net_norm == "layernorm":
            return nn.LayerNorm(shape_feat, elementwise_affine=True)
        elif net_norm == "instancenorm":
            # Instance norm expressed as GroupNorm with one group per channel.
            return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True)
        elif net_norm == "groupnorm":
            return nn.GroupNorm(4, shape_feat[0], affine=True)
        elif net_norm == "none":
            return None
        else:
            raise ValueError("unknown net_norm: %s" % net_norm)

    def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size):
        layers = []
        in_channels = channel
        # Track the running feature shape as [channels, height, width].
        shape_feat = [in_channels, im_size[0], im_size[1]]
        for d in range(net_depth):
            layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding="same")]
            shape_feat[0] = net_width
            if net_norm != "none":
                layers += [self._get_normlayer(net_norm, shape_feat)]
            layers += [self._get_activation(net_act)]
            in_channels = net_width
            if net_pooling != "none":
                layers += [self._get_pooling(net_pooling)]
                # Each 2x2 pooling layer halves the spatial dimensions.
                shape_feat[1] //= 2
                shape_feat[2] //= 2

        return nn.Sequential(*layers), shape_feat
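

# Minimal usage sketch (not part of the original module): builds the model with
# its default settings for 3-channel 32x32 inputs and checks the output shape.
# The batch size and feature count below are illustrative assumptions.
if __name__ == "__main__":
    import torch

    model = ConvModel(channel=3, n_random_features=128)
    x = torch.randn(8, 3, 32, 32)  # dummy batch of 8 RGB images
    out = model(x)
    # With the defaults, three 2x2 poolings shrink 32x32 to 4x4, so the
    # classifier sees 64 * 4 * 4 = 1024 features per image.
    assert out.shape == (8, 128)
    print(out.shape)  # torch.Size([8, 128])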