diff --git a/test/modules/test_variational_rnn.py b/test/modules/test_variational_rnn.py new file mode 100644 index 00000000..cd265109 --- /dev/null +++ b/test/modules/test_variational_rnn.py @@ -0,0 +1,28 @@ + +import torch +import unittest + +from fastNLP.modules.encoder.variational_rnn import VarMaskedFastLSTM + +class TestMaskedRnn(unittest.TestCase): + def test_case_1(self): + masked_rnn = VarMaskedFastLSTM(input_size=1, hidden_size=1, bidirectional=True, batch_first=True) + x = torch.tensor([[[1.0], [2.0]]]) + print(x.size()) + y = masked_rnn(x) + mask = torch.tensor([[[1], [1]]]) + y = masked_rnn(x, mask=mask) + mask = torch.tensor([[[1], [0]]]) + y = masked_rnn(x, mask=mask) + + def test_case_2(self): + masked_rnn = VarMaskedFastLSTM(input_size=1, hidden_size=1, bidirectional=False, batch_first=True) + x = torch.tensor([[[1.0], [2.0]]]) + print(x.size()) + y = masked_rnn(x) + mask = torch.tensor([[[1], [1]]]) + y = masked_rnn(x, mask=mask) + xx = torch.tensor([[[1.0]]]) + #y, hidden = masked_rnn.step(xx) + #step() still has a bug + #y, hidden = masked_rnn.step(xx, mask=mask) \ No newline at end of file