You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

feedforward.py 933 B

3 years ago
Lines 1–43
  1. """
  2. FeedForward class.
  3. """
  4. import torch
  5. import torch.nn as nn
  6. class FeedForward(nn.Module):
  7. """
  8. Positional feed forward layer.
  9. """
  10. def __init__(self, hidden_dim, inner_dim, dropout):
  11. super(FeedForward, self).__init__()
  12. self.hidden_dim = hidden_dim
  13. self.inner_dim = inner_dim
  14. self.linear_hidden = nn.Sequential(
  15. nn.Linear(hidden_dim, inner_dim), nn.GELU())
  16. self.linear_out = nn.Linear(inner_dim, hidden_dim)
  17. self.dropout_layer = nn.Dropout(p=dropout)
  18. return
  19. def forward(self, x):
  20. out = self.linear_hidden(x)
  21. out = self.dropout_layer(out)
  22. out = self.linear_out(out)
  23. return out
  24. def main():
  25. import numpy as np
  26. model = FeedForward(10, 20, 0.5)
  27. inp = np.random.rand(2, 3, 10).astype('float32')
  28. inp = torch.tensor(inp)
  29. out = model(inp)
  30. print(out)
  31. if __name__ == '__main__':
  32. main()

致力于通过开放的社区合作，开源AI模型以及相关创新技术，推动基于模型即服务的生态繁荣发展