You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

MLP.py 1.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. import torch.nn as nn
  2. from torch.nn import Mish
  3. class MLPLayer(nn.Module):
  4. def __init__(self, dim_in, dim_out, res_coef=0.0, dropout_p=0.1):
  5. super().__init__()
  6. self.linear = nn.Linear(dim_in, dim_out)
  7. self.res_coef = res_coef
  8. self.activation = Mish()
  9. self.dropout = nn.Dropout(dropout_p)
  10. self.ln = nn.LayerNorm(dim_out)
  11. def forward(self, x):
  12. y = self.linear(x)
  13. y = self.activation(y)
  14. y = self.dropout(y)
  15. if self.res_coef == 0:
  16. return self.ln(y)
  17. else:
  18. return self.ln(self.res_coef * x + y)
  19. class MLP(nn.Module):
  20. def __init__(self, dim_in, dim, res_coef=0.5, dropout_p=0.1, n_layers=10):
  21. super().__init__()
  22. self.mlp = nn.ModuleList()
  23. self.first_linear = MLPLayer(dim_in, dim)
  24. self.n_layers = n_layers
  25. for i in range(n_layers):
  26. self.mlp.append(MLPLayer(dim, dim, res_coef, dropout_p))
  27. self.final = nn.Linear(dim, 1)
  28. self.sigmoid = nn.Sigmoid()
  29. def forward(self, x):
  30. x = self.first_linear(x)
  31. for layer in self.mlp:
  32. x = layer(x)
  33. x = self.sigmoid(self.final(x))
  34. return x.squeeze()

基于 PyTorch Lightning 的机器学习模板，用于对机器学习算法进行训练、验证、测试等。目前实现了神经网络、深度学习、k 折交叉验证、自动保存训练信息等功能。

Contributors (1)