
model.py 2.9 kB

from torch import nn


class ConvModel(nn.Module):
    """Simple ConvNet: a stack of Conv-Norm-Act-Pool blocks followed by a linear head."""

    def __init__(
        self,
        channel,
        n_random_features,
        net_width=64,
        net_depth=3,
        net_act="relu",
        net_norm="batchnorm",
        net_pooling="avgpooling",
        im_size=(32, 32),
    ):
        super().__init__()
        self.features, shape_feat = self._make_layers(
            channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size
        )
        num_feat = shape_feat[0] * shape_feat[1] * shape_feat[2]
        self.classifier = nn.Linear(num_feat, n_random_features)

    def forward(self, x):
        out = self.features(x)
        out = out.reshape(out.size(0), -1)  # flatten to (batch, num_feat)
        out = self.classifier(out)
        return out

    def _get_activation(self, net_act):
        if net_act == "sigmoid":
            return nn.Sigmoid()
        elif net_act == "relu":
            return nn.ReLU(inplace=True)
        elif net_act == "leakyrelu":
            return nn.LeakyReLU(negative_slope=0.01)
        elif net_act == "gelu":
            return nn.GELU()
        else:
            raise ValueError("unknown activation function: %s" % net_act)

    def _get_pooling(self, net_pooling):
        if net_pooling == "maxpooling":
            return nn.MaxPool2d(kernel_size=2, stride=2)
        elif net_pooling == "avgpooling":
            return nn.AvgPool2d(kernel_size=2, stride=2)
        elif net_pooling == "none":
            return None
        else:
            raise ValueError("unknown net_pooling: %s" % net_pooling)

    def _get_normlayer(self, net_norm, shape_feat):
        # shape_feat is [channels, height, width] of the current feature map.
        if net_norm == "batchnorm":
            return nn.BatchNorm2d(shape_feat[0], affine=True)
        elif net_norm == "layernorm":
            return nn.LayerNorm(shape_feat, elementwise_affine=True)
        elif net_norm == "instancenorm":
            # GroupNorm with one group per channel is equivalent to affine InstanceNorm.
            return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True)
        elif net_norm == "groupnorm":
            return nn.GroupNorm(4, shape_feat[0], affine=True)
        elif net_norm == "none":
            return None
        else:
            raise ValueError("unknown net_norm: %s" % net_norm)

    def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size):
        layers = []
        in_channels = channel
        shape_feat = [in_channels, im_size[0], im_size[1]]
        for d in range(net_depth):
            layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding="same")]
            shape_feat[0] = net_width
            if net_norm != "none":
                layers += [self._get_normlayer(net_norm, shape_feat)]
            layers += [self._get_activation(net_act)]
            in_channels = net_width
            if net_pooling != "none":
                layers += [self._get_pooling(net_pooling)]
                shape_feat[1] //= 2  # each pooling layer halves the spatial size
                shape_feat[2] //= 2
        return nn.Sequential(*layers), shape_feat
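
As a quick sanity check, the model can be instantiated with its defaults and run on a random batch. This is a minimal usage sketch, not part of model.py: the 3-channel 32x32 input and the 10 output features are assumptions chosen for illustration (e.g. CIFAR-sized images), and the import path assumes the file above is importable as model.

import torch

from model import ConvModel  # assumes model.py is on the import path

# Assumed settings: 3 input channels, 32x32 images, 10 random features.
net = ConvModel(channel=3, n_random_features=10)
x = torch.randn(8, 3, 32, 32)  # random batch of 8 images
out = net(x)
print(out.shape)  # torch.Size([8, 10])

With the defaults (net_depth=3, net_width=64, avgpooling), the 32x32 input is halved three times to 4x4, so the flattened feature size is 64 * 4 * 4 = 1024 before the final linear layer.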