
transformer_block.py · 2.0 kB · last commit 3 years ago
  1. """
  2. TransformerBlock class.
  3. """
  4. import torch
  5. import torch.nn as nn
  6. from maas_lib.models.nlp.space.modules.feedforward import FeedForward
  7. from maas_lib.models.nlp.space.modules.multihead_attention import \
  8. MultiheadAttention
  9. class TransformerBlock(nn.Module):
  10. """
  11. Transformer block module.
  12. """
  13. def __init__(self, hidden_dim, num_heads, dropout, attn_dropout,
  14. ff_dropout):
  15. super(TransformerBlock, self).__init__()
  16. self.attn = MultiheadAttention(
  17. hidden_dim=hidden_dim, num_heads=num_heads, dropout=attn_dropout)
  18. self.attn_norm = nn.LayerNorm(
  19. normalized_shape=hidden_dim, eps=1e-12, elementwise_affine=True)
  20. self.ff = FeedForward(
  21. hidden_dim=hidden_dim,
  22. inner_dim=4 * hidden_dim,
  23. dropout=ff_dropout)
  24. self.ff_norm = nn.LayerNorm(
  25. normalized_shape=hidden_dim, eps=1e-12, elementwise_affine=True)
  26. self.dropout_layer = nn.Dropout(p=dropout)
  27. return
  28. def forward(self, inp, mask=None, cache=None):
  29. """
  30. Forward process on one transformer layer.
  31. @param : x
  32. @type : Variable(shape: [batch_size, seq_len, hidden_size])
  33. @param : memory
  34. @type : Variable(shape: [batch_size, seq_len, hidden_size])
  35. @param : mask
  36. @param : cache
  37. """
  38. attn_out = self.attn(inp, mask, cache)
  39. attn_out = self.dropout_layer(attn_out)
  40. attn_out = self.attn_norm(attn_out + inp)
  41. ff_out = self.ff(attn_out)
  42. ff_out = self.dropout_layer(ff_out)
  43. ff_out = self.ff_norm(ff_out + attn_out)
  44. return ff_out
  45. def main():
  46. import numpy as np
  47. model = TransformerBlock(10, 2, 0.5, 0.5, 0.5)
  48. inp = np.random.rand(2, 3, 10).astype('float32')
  49. inp = torch.tensor(inp)
  50. mask = (np.random.rand(2, 3, 3) > 0.5).astype('float32')
  51. mask = torch.tensor(mask)
  52. out = model(inp, mask=mask, cache=None)
  53. print(out)
  54. if __name__ == '__main__':
  55. main()
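
The FeedForward and MultiheadAttention modules imported above live elsewhere in maas_lib and are not shown on this page. As a rough, self-contained sketch, stand-ins like the following let the main() smoke test run without maas_lib installed. The keyword signatures match the calls above, but the GELU activation, the mask convention (1 = attend, 0 = block), and the ignored cache argument are assumptions, not the library's actual implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F


class FeedForward(nn.Module):
    """Sketch: position-wise FFN, Linear -> GELU -> Dropout -> Linear."""

    def __init__(self, hidden_dim, inner_dim, dropout):
        super().__init__()
        self.linear_in = nn.Linear(hidden_dim, inner_dim)
        self.linear_out = nn.Linear(inner_dim, hidden_dim)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        # GELU is an assumption; the real module may use another activation.
        return self.linear_out(self.dropout(F.gelu(self.linear_in(x))))


class MultiheadAttention(nn.Module):
    """Sketch: scaled dot-product self-attention over num_heads heads."""

    def __init__(self, hidden_dim, num_heads, dropout):
        super().__init__()
        assert hidden_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.qkv_proj = nn.Linear(hidden_dim, 3 * hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask=None, cache=None):
        # cache is accepted but ignored: no incremental decoding here.
        batch_size, seq_len, _ = x.shape
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)

        def split_heads(t):
            # [batch, seq, hidden] -> [batch, heads, seq, head_dim]
            return t.reshape(batch_size, seq_len, self.num_heads,
                             self.head_dim).transpose(1, 2)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)
        scores = q @ k.transpose(-2, -1) / self.head_dim ** 0.5
        if mask is not None:
            # Assumed convention: mask[b, i, j] == 1 lets position i
            # attend to position j; 0 blocks it.
            scores = scores.masked_fill(mask.unsqueeze(1) == 0, -1e9)
        weights = self.dropout(torch.softmax(scores, dim=-1))
        context = (weights @ v).transpose(1, 2).reshape(
            batch_size, seq_len, -1)
        return self.out_proj(context)

With these two definitions substituted for the maas_lib imports, python transformer_block.py prints a [2, 3, 10] tensor of block outputs.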
