You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

embedder.py 2.4 kB

3 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. """
  2. Embedder class.
  3. """
  4. import torch
  5. import torch.nn as nn
  6. class Embedder(nn.Module):
  7. """
  8. Composite embedding layer.
  9. """
  10. def __init__(self,
  11. hidden_dim,
  12. num_token_embeddings,
  13. num_pos_embeddings,
  14. num_type_embeddings,
  15. num_turn_embeddings,
  16. padding_idx=None,
  17. dropout=0.1,
  18. pos_trainable=False):
  19. super(Embedder, self).__init__()
  20. self.token_embedding = nn.Embedding(num_token_embeddings, hidden_dim)
  21. self.pos_embedding = nn.Embedding(num_pos_embeddings, hidden_dim)
  22. self.pos_embedding.weight.requires_grad = pos_trainable
  23. self.type_embedding = nn.Embedding(num_type_embeddings, hidden_dim)
  24. self.turn_embedding = nn.Embedding(num_turn_embeddings, hidden_dim)
  25. self.dropout_layer = nn.Dropout(p=dropout)
  26. # follow the default xavier_uniform initializer in paddle version
  27. # otherwise, there are bugs for dec_probs computation in weight typing setting
  28. # default norm initializer in nn.Embedding in pytorch, which samples larger values
  29. nn.init.xavier_uniform_(self.token_embedding.weight)
  30. nn.init.xavier_uniform_(self.pos_embedding.weight)
  31. nn.init.xavier_uniform_(self.type_embedding.weight)
  32. nn.init.xavier_uniform_(self.turn_embedding.weight)
  33. return
  34. def forward(self, token_inp, pos_inp=None, type_inp=None, turn_inp=None):
  35. embed = self.token_embedding(token_inp)
  36. if pos_inp is not None:
  37. embed += self.pos_embedding(pos_inp)
  38. if type_inp is not None:
  39. embed += self.type_embedding(type_inp)
  40. if turn_inp is not None:
  41. embed += self.turn_embedding(turn_inp)
  42. embed = self.dropout_layer(embed)
  43. return embed
  44. def main():
  45. import numpy as np
  46. model = Embedder(10, 20, 20, 20, 20)
  47. token_inp = torch.tensor(
  48. np.random.randint(0, 19, [10, 10]).astype('int64'))
  49. pos_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64'))
  50. type_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64'))
  51. turn_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64'))
  52. out = model(token_inp, pos_inp, type_inp, turn_inp)
  53. print(out)
# Run the smoke test only when this file is executed directly as a script.
if __name__ == '__main__':
    main()

致力于通过开放的社区合作,开源AI模型以及相关创新技术,推动基于模型即服务的生态繁荣发展