
model.py

import torch
import torch.nn as nn
from torch.autograd import Variable


def pack_sequence(tensor_seq, padding_value=0.0):
    if len(tensor_seq) <= 0:
        return
    length = [v.size(0) for v in tensor_seq]
    max_len = max(length)
    size = [len(tensor_seq), max_len]
    size.extend(list(tensor_seq[0].size()[1:]))
    ans = torch.Tensor(*size).fill_(padding_value)
    if tensor_seq[0].data.is_cuda:
        ans = ans.cuda()
    ans = Variable(ans)
    for i, v in enumerate(tensor_seq):
        ans[i, :length[i], :] = v
    return ans
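
# Note (added, illustrative): pack_sequence pads a list of variable-length tensors to a
# common length and stacks them along a new batch dimension, e.g.
#   seqs = [Variable(torch.randn(3, 100)), Variable(torch.randn(5, 100))]
#   batch = pack_sequence(seqs)   # batch.size() == (2, 5, 100)
# Positions beyond each sequence's own length keep padding_value (0.0 by default).
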
class HAN(nn.Module):
    def __init__(self, input_size, output_size,
                 word_hidden_size, word_num_layers, word_context_size,
                 sent_hidden_size, sent_num_layers, sent_context_size):
        super(HAN, self).__init__()
        self.word_layer = AttentionNet(input_size,
                                       word_hidden_size,
                                       word_num_layers,
                                       word_context_size)
        self.sent_layer = AttentionNet(2 * word_hidden_size,
                                       sent_hidden_size,
                                       sent_num_layers,
                                       sent_context_size)
        self.output_layer = nn.Linear(2 * sent_hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, batch_doc):
        # input is a sequence of matrices, one per document
        doc_vec_list = []
        for doc in batch_doc:
            sent_mat = self.word_layer(doc)  # doc's dim (num_sent, seq_len, word_dim)
            doc_vec_list.append(sent_mat)    # sent_mat's dim (num_sent, vec_dim)
        doc_vec = self.sent_layer(pack_sequence(doc_vec_list))
        output = self.softmax(self.output_layer(doc_vec))
        return output
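
# Note (added): shape walk-through for HAN.forward. Each doc in batch_doc has shape
# (num_sent, seq_len, input_size); word_layer returns (num_sent, 2*word_hidden_size);
# pack_sequence stacks the per-document matrices into
# (len(batch_doc), max_num_sent, 2*word_hidden_size); sent_layer reduces that to
# (len(batch_doc), 2*sent_hidden_size); the linear layer plus LogSoftmax yield
# (len(batch_doc), output_size) log-probabilities, matching nn.NLLLoss in the test below.
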
class AttentionNet(nn.Module):
    def __init__(self, input_size, gru_hidden_size, gru_num_layers, context_vec_size):
        super(AttentionNet, self).__init__()
        self.input_size = input_size
        self.gru_hidden_size = gru_hidden_size
        self.gru_num_layers = gru_num_layers
        self.context_vec_size = context_vec_size
        # Encoder
        self.gru = nn.GRU(input_size=input_size,
                          hidden_size=gru_hidden_size,
                          num_layers=gru_num_layers,
                          batch_first=True,
                          bidirectional=True)
        # Attention
        self.fc = nn.Linear(2 * gru_hidden_size, context_vec_size)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)
        # context vector
        self.context_vec = nn.Parameter(torch.Tensor(context_vec_size, 1))
        self.context_vec.data.uniform_(-0.1, 0.1)

    def forward(self, inputs):
        # GRU part
        h_t, hidden = self.gru(inputs)  # inputs's dim (batch_size, seq_len, word_dim)
        u = self.tanh(self.fc(h_t))
        # Attention part
        alpha = self.softmax(torch.matmul(u, self.context_vec))  # u's dim (batch_size, seq_len, context_vec_size)
        output = torch.bmm(torch.transpose(h_t, 1, 2), alpha)    # alpha's dim (batch_size, seq_len, 1)
        return torch.squeeze(output, dim=2)  # returned dim (batch_size, 2*gru_hidden_size)
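
# Note (added): AttentionNet.forward follows the standard hierarchical-attention
# formulation: u_t = tanh(W h_t + b), attention weights alpha_t = softmax(u_t^T u_c)
# taken over the time steps (dim=1), and the returned vector is the weighted sum
# sum_t alpha_t * h_t, computed above with the bmm over the transposed GRU outputs.
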
if __name__ == '__main__':
    '''
    Test the model's correctness
    '''
    import numpy as np

    use_cuda = True
    net = HAN(input_size=200, output_size=5,
              word_hidden_size=50, word_num_layers=1, word_context_size=100,
              sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    criterion = nn.NLLLoss()
    test_time = 10
    batch_size = 64
    if use_cuda:
        net.cuda()
    print('test training')
    for step in range(test_time):
        x_data = [torch.randn(np.random.randint(1, 10), 200, 200) for i in range(batch_size)]
        y_data = torch.LongTensor([np.random.randint(0, 5) for i in range(batch_size)])
        if use_cuda:
            x_data = [x_i.cuda() for x_i in x_data]
            y_data = y_data.cuda()
        x = [Variable(x_i) for x_i in x_data]
        y = Variable(y_data)
        predict = net(x)
        loss = criterion(predict, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(loss.data[0])  # on PyTorch >= 0.4 use loss.item() instead
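
    # Illustrative addition (not part of the original test): one forward pass on a single
    # randomly generated document to read off the predicted class. A document has shape
    # (num_sent, seq_len, input_size), matching x_data above.
    print('test inference')
    doc = torch.randn(3, 200, 200)
    if use_cuda:
        doc = doc.cuda()
    log_probs = net([Variable(doc)])  # log-probabilities, shape (1, 5)
    print(log_probs.max(dim=1)[1])    # index of the most likely class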