|
|
@@ -0,0 +1,28 @@ |
|
|
|
|
|
|
|
import torch |
|
|
|
import unittest |
|
|
|
|
|
|
|
from fastNLP.modules.encoder.variational_rnn import VarMaskedFastLSTM |
|
|
|
|
|
|
|
class TestMaskedRnn(unittest.TestCase): |
|
|
|
def test_case_1(self): |
|
|
|
masked_rnn = VarMaskedFastLSTM(input_size=1, hidden_size=1, bidirectional=True, batch_first=True) |
|
|
|
x = torch.tensor([[[1.0], [2.0]]]) |
|
|
|
print(x.size()) |
|
|
|
y = masked_rnn(x) |
|
|
|
mask = torch.tensor([[[1], [1]]]) |
|
|
|
y = masked_rnn(x, mask=mask) |
|
|
|
mask = torch.tensor([[[1], [0]]]) |
|
|
|
y = masked_rnn(x, mask=mask) |
|
|
|
|
|
|
|
def test_case_2(self): |
|
|
|
masked_rnn = VarMaskedFastLSTM(input_size=1, hidden_size=1, bidirectional=False, batch_first=True) |
|
|
|
x = torch.tensor([[[1.0], [2.0]]]) |
|
|
|
print(x.size()) |
|
|
|
y = masked_rnn(x) |
|
|
|
mask = torch.tensor([[[1], [1]]]) |
|
|
|
y = masked_rnn(x, mask=mask) |
|
|
|
xx = torch.tensor([[[1.0]]]) |
|
|
|
#y, hidden = masked_rnn.step(xx) |
|
|
|
#step() still has a bug |
|
|
|
#y, hidden = masked_rnn.step(xx, mask=mask) |