You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

rnn.py 1.2 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class RNN(nn.Module):
  5. def __init__(self, diminput, dimoutput, dimhidden, nsteps):
  6. super(RNN, self).__init__()
  7. self.diminput = diminput
  8. self.dimoutput = dimoutput
  9. self.dimhidden = dimhidden
  10. self.nsteps = nsteps
  11. self.fc1 = nn.Linear(diminput, dimhidden)
  12. self.fc2 = nn.Linear(dimhidden*2, dimhidden)
  13. self.fc3 = nn.Linear(dimhidden, dimoutput)
  14. def forward(self, x):
  15. last_state = torch.zeros((x.shape[0], self.dimhidden)).to(x.device)
  16. for i in range(self.nsteps):
  17. t = i % self.nsteps
  18. index = torch.Tensor([idx for idx in range(
  19. t*self.diminput, (t+1)*self.diminput)]).long().to(x.device)
  20. cur_x = torch.index_select(x, 1, index)
  21. h = self.fc1(cur_x)
  22. s = torch.cat([h, last_state], axis=1)
  23. s = self.fc2(s)
  24. last_state = F.relu(s)
  25. final_state = last_state
  26. y = self.fc3(final_state)
  27. return y
  28. def rnn(diminput, dimoutput, dimhidden, nsteps):
  29. return RNN(diminput, dimoutput, dimhidden, nsteps)

分布式深度学习系统

Contributors (1)