import torch
import torch.nn as nn
import torch.nn.functional as F


class RNN(nn.Module):
    """Minimal recurrent network unrolled over a flattened input sequence.

    Each input row is read as ``nsteps`` consecutive chunks of ``diminput``
    features. At every step the chunk is projected by ``fc1``, concatenated
    with the previous hidden state, mixed by ``fc2`` and passed through ReLU.
    The hidden state after the last step is mapped to the output by ``fc3``.
    """

    def __init__(self, diminput: int, dimoutput: int, dimhidden: int, nsteps: int):
        """Create the three linear layers.

        Args:
            diminput: Number of input features consumed per time step.
            dimoutput: Size of the output produced from the final hidden state.
            dimhidden: Size of the hidden state.
            nsteps: Number of time steps unrolled in ``forward``.
        """
        super().__init__()
        self.diminput = diminput
        self.dimoutput = dimoutput
        self.dimhidden = dimhidden
        self.nsteps = nsteps
        # Projects one step's input chunk into the hidden space.
        self.fc1 = nn.Linear(diminput, dimhidden)
        # Mixes the projected input with the previous hidden state; the two are
        # concatenated, hence 2 * dimhidden input features.
        self.fc2 = nn.Linear(dimhidden * 2, dimhidden)
        # Maps the final hidden state to the output.
        self.fc3 = nn.Linear(dimhidden, dimoutput)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the recurrence over ``nsteps`` chunks of ``x``.

        Args:
            x: Tensor whose dimension 1 holds the flattened sequence; step ``t``
                reads columns ``[t * diminput, (t + 1) * diminput)``. Columns
                beyond ``nsteps * diminput`` are ignored.

        Returns:
            ``fc3`` applied to the hidden state after the last step (with
            ``nsteps == 0`` this is ``fc3`` of an all-zero state).

        Raises:
            ValueError: If dimension 1 of ``x`` is shorter than
                ``nsteps * diminput``.
        """
        needed = self.nsteps * self.diminput
        if x.shape[1] < needed:
            # Fail early with a clear message instead of an out-of-range
            # index error (a device-side assert on CUDA).
            raise ValueError(
                f"expected x.shape[1] >= nsteps * diminput = {needed}, "
                f"got {x.shape[1]}"
            )

        # Same device *and* dtype as x, so the concatenation below never mixes
        # dtypes (e.g. for float64 or half-precision models).
        state = x.new_zeros((x.shape[0], self.dimhidden))
        for t in range(self.nsteps):
            start = t * self.diminput
            # A plain slice selects exactly the same columns as the original
            # index_select, but as a view: no per-step index tensor, and no
            # float32 round-trip that would corrupt indices above 2**24.
            cur_x = x[:, start:start + self.diminput]
            h = self.fc1(cur_x)
            s = torch.cat([h, state], dim=1)
            state = F.relu(self.fc2(s))
        return self.fc3(state)


def rnn(diminput: int, dimoutput: int, dimhidden: int, nsteps: int) -> RNN:
    """Return a new :class:`RNN` built with the given sizes."""
    return RNN(diminput, dimoutput, dimhidden, nsteps)