
model.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class Linear(nn.Module):
    # Simple three-layer MLP classifier with dropout applied to both hidden layers.
    def __init__(self, input_feature=256, num_classes=10):
        super().__init__()
        self.linear_1 = nn.Linear(input_feature, 128)
        self.dropout_1 = nn.Dropout(p=0.5)
        self.linear_2 = nn.Linear(128, 128)
        self.dropout_2 = nn.Dropout(p=0.5)
        self.linear_3 = nn.Linear(128, num_classes)

    def forward(self, x):
        out1 = F.relu(self.dropout_1(self.linear_1(x)))
        out2 = F.relu(self.dropout_2(self.linear_2(out1)))
        out = self.linear_3(out2)
        return out


class OriginModel(nn.Module):
    # Plain three-layer MLP (no dropout) with a fixed 10-way output.
    def __init__(self, last_layer_feature=256):
        super().__init__()
        self.linear_1 = nn.Linear(last_layer_feature, 128)
        self.linear_2 = nn.Linear(128, 128)
        self.linear_3 = nn.Linear(128, 10)

    def forward(self, x):
        out = F.relu(self.linear_1(x))
        out = F.relu(self.linear_2(out))
        out = self.linear_3(out)
        return out


class ConvModel(nn.Module):
    # ConvNet backbone built from (conv -> norm -> activation -> pooling) blocks,
    # followed by a Gaussian-initialized linear head with n_random_features outputs.
    def __init__(
        self,
        channel,
        n_random_features,
        net_width=64,
        net_depth=3,
        net_act="relu",
        net_norm="batchnorm",
        net_pooling="avgpooling",
        im_size=(32, 32),
    ):
        super().__init__()
        # print('Building Conv Model')
        self.features, shape_feat = self._make_layers(
            channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size
        )
        num_feat = shape_feat[0] * shape_feat[1] * shape_feat[2]
        self.classifier = GaussianLinear(num_feat, n_random_features)

    def forward(self, x):
        out = self.features(x)
        out = out.reshape(out.size(0), -1)  # flatten to (batch, num_feat)
        out = self.classifier(out)
        return out

    def _get_activation(self, net_act):
        if net_act == "sigmoid":
            return nn.Sigmoid()
        elif net_act == "relu":
            return nn.ReLU(inplace=True)
        elif net_act == "leakyrelu":
            return nn.LeakyReLU(negative_slope=0.01)
        elif net_act == "gelu":
            return nn.GELU()
        else:
            exit("unknown activation function: %s" % net_act)

    def _get_pooling(self, net_pooling):
        if net_pooling == "maxpooling":
            return nn.MaxPool2d(kernel_size=2, stride=2)
        elif net_pooling == "avgpooling":
            return nn.AvgPool2d(kernel_size=2, stride=2)
        elif net_pooling == "none":
            return None
        else:
            exit("unknown net_pooling: %s" % net_pooling)

    def _get_normlayer(self, net_norm, shape_feat):
        # shape_feat = (c, h, w)
        if net_norm == "batchnorm":
            return nn.BatchNorm2d(shape_feat[0], affine=True)
        elif net_norm == "layernorm":
            return nn.LayerNorm(shape_feat, elementwise_affine=True)
        elif net_norm == "instancenorm":
            return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True)
        elif net_norm == "groupnorm":
            return nn.GroupNorm(4, shape_feat[0], affine=True)
        elif net_norm == "none":
            return None
        else:
            exit("unknown net_norm: %s" % net_norm)

    def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size):
        layers = []
        in_channels = channel
        # if im_size[0] == 28:
        #     im_size = (32, 32)
        shape_feat = [in_channels, im_size[0], im_size[1]]
        for d in range(net_depth):
            # print(shape_feat)
            layers += [Conv2d_gaussian(in_channels, net_width, kernel_size=3, padding=1)]
            # layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding='same')]
            shape_feat[0] = net_width
            if net_norm != "none":
                layers += [self._get_normlayer(net_norm, shape_feat)]
            layers += [self._get_activation(net_act)]
            in_channels = net_width
            if net_pooling != "none":
                layers += [self._get_pooling(net_pooling)]
                shape_feat[1] //= 2
                shape_feat[2] //= 2
        return nn.Sequential(*layers), shape_feat
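    # Shape note (follows from _make_layers above): each block halves the spatial size
    # when pooling is enabled, so with the defaults net_width=64, net_depth=3,
    # net_pooling="avgpooling" and im_size=(32, 32): 32 -> 16 -> 8 -> 4, giving
    # shape_feat = [64, 4, 4] and num_feat = 64 * 4 * 4 = 1024 classifier input features.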


class Conv2d_gaussian(torch.nn.Conv2d):
    # nn.Conv2d whose weights are re-initialized from a Gaussian with
    # std = sqrt(2 / fan_in) (He/Kaiming-style normal init) instead of the default.
    def reset_parameters(self) -> None:
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(k), 1/sqrt(k)), where k = weight.size(1) * prod(*kernel_size)
        # For more details see: https://github.com/pytorch/pytorch/issues/15314#issuecomment-477448573
        # torch.nn.init.kaiming_normal_(self.weight, a=math.sqrt(5))
        # Weight has shape (out_channels, in_channels, kernel_h, kernel_w), so the
        # denominator below is sqrt(fan_in).
        torch.nn.init.normal_(
            self.weight, 0, np.sqrt(2) / np.sqrt(self.weight.shape[1] * self.weight.shape[2] * self.weight.shape[3])
        )
        if self.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
            # print(fan_in)
            if fan_in != 0:
                # bound = 0 * 1 / math.sqrt(fan_in)
                # torch.nn.init.uniform_(self.bias, -bound, bound)
                torch.nn.init.normal_(self.bias, 0, 0.1)


class GaussianLinear(torch.nn.Module):
    # Drop-in replacement for nn.Linear whose weights are drawn from a Gaussian with
    # std = sqrt(2 / in_features) and whose bias is drawn from N(0, 0.1).
    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None, funny=False
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(GaussianLinear, self).__init__()
        self.funny = funny  # stored but not used elsewhere in this file
        self.in_features = in_features
        self.out_features = out_features
        self.weight = torch.nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
        if bias:
            self.bias = torch.nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
        # torch.nn.init.kaiming_normal_(self.weight, a=1 * np.sqrt(5))
        torch.nn.init.normal_(self.weight, 0, np.sqrt(2) / np.sqrt(self.in_features))
        # torch.nn.init.normal_(self.weight, 0, 3 / np.sqrt(self.in_features))
        if self.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / np.sqrt(fan_in) if fan_in > 0 else 0
            # torch.nn.init.uniform_(self.bias, -bound, bound)
            torch.nn.init.normal_(self.bias, 0, 0.1)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.linear(input, self.weight, self.bias)

    def extra_repr(self) -> str:
        return "in_features={}, out_features={}, bias={}".format(
            self.in_features, self.out_features, self.bias is not None
        )
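

# Minimal usage sketch (illustrative, not part of the original module): runs the models
# above on random tensors to show the expected input shapes. The batch size, channel
# count, and image size chosen here are assumptions for the example only.
if __name__ == "__main__":
    # ConvModel expects NCHW image batches; here, 8 RGB images of size 32x32.
    conv = ConvModel(channel=3, n_random_features=10)
    images = torch.randn(8, 3, 32, 32)
    print(conv(images).shape)  # torch.Size([8, 10])

    # Linear and OriginModel expect flat feature vectors of length input_feature.
    mlp = Linear(input_feature=256, num_classes=10)
    print(mlp(torch.randn(8, 256)).shape)  # torch.Size([8, 10])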