from torch import nn


class ConvModel(nn.Module):
    """ConvNet of `net_depth` conv blocks (conv -> norm -> act -> pool) plus a linear head."""

    def __init__(
        self,
        channel,
        n_random_features,
        net_width=64,
        net_depth=3,
        net_act="relu",
        net_norm="batchnorm",
        net_pooling="avgpooling",
        im_size=(32, 32),
    ):
        super().__init__()
        self.features, shape_feat = self._make_layers(
            channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size
        )
        num_feat = shape_feat[0] * shape_feat[1] * shape_feat[2]
        self.classifier = nn.Linear(num_feat, n_random_features)

    def forward(self, x):
        out = self.features(x)
        out = out.reshape(out.size(0), -1)  # flatten to (batch, features)
        out = self.classifier(out)
        return out

    def _get_activation(self, net_act):
        if net_act == "sigmoid":
            return nn.Sigmoid()
        elif net_act == "relu":
            return nn.ReLU(inplace=True)
        elif net_act == "leakyrelu":
            return nn.LeakyReLU(negative_slope=0.01)
        elif net_act == "gelu":
            return nn.GELU()  # was nn.SiLU(), which is a different activation
        else:
            raise ValueError("unknown activation function: %s" % net_act)

    def _get_pooling(self, net_pooling):
        if net_pooling == "maxpooling":
            return nn.MaxPool2d(kernel_size=2, stride=2)
        elif net_pooling == "avgpooling":
            return nn.AvgPool2d(kernel_size=2, stride=2)
        elif net_pooling == "none":
            return None
        else:
            raise ValueError("unknown net_pooling: %s" % net_pooling)

    def _get_normlayer(self, net_norm, shape_feat):
        # shape_feat is [C, H, W] of the incoming feature map
        if net_norm == "batchnorm":
            return nn.BatchNorm2d(shape_feat[0], affine=True)
        elif net_norm == "layernorm":
            return nn.LayerNorm(shape_feat, elementwise_affine=True)
        elif net_norm == "instancenorm":
            # GroupNorm with one group per channel is InstanceNorm with affine params
            return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True)
        elif net_norm == "groupnorm":
            return nn.GroupNorm(4, shape_feat[0], affine=True)
        elif net_norm == "none":
            return None
        else:
            raise ValueError("unknown net_norm: %s" % net_norm)

    def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size):
        layers = []
        in_channels = channel
        shape_feat = [in_channels, im_size[0], im_size[1]]  # tracks [C, H, W]
        for _ in range(net_depth):
            # padding="same" keeps H and W unchanged (needs stride=1; PyTorch >= 1.9)
            layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding="same")]
            shape_feat[0] = net_width
            if net_norm != "none":
                layers += [self._get_normlayer(net_norm, shape_feat)]
            layers += [self._get_activation(net_act)]
            in_channels = net_width
            if net_pooling != "none":
                layers += [self._get_pooling(net_pooling)]
                shape_feat[1] //= 2
                shape_feat[2] //= 2
        return nn.Sequential(*layers), shape_feat
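

# A minimal usage sketch, not part of the original module: instantiate the model
# on CIFAR-10-shaped input and run a forward pass. The batch size of 8 and the
# 128 random features are arbitrary values chosen for illustration.
if __name__ == "__main__":
    import torch

    model = ConvModel(channel=3, n_random_features=128, im_size=(32, 32))
    x = torch.randn(8, 3, 32, 32)  # (batch, channels, height, width)
    out = model(x)
    # With the default depth of 3 and 2x2 pooling, 32x32 shrinks to 4x4,
    # so the flattened feature size is 64 * 4 * 4 = 1024 before the head.
    print(out.shape)  # torch.Size([8, 128])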