From beee8856895186c76071447d41b16d73162dd5d9 Mon Sep 17 00:00:00 2001 From: xuyige Date: Wed, 29 Aug 2018 15:28:22 +0800 Subject: [PATCH] add test code for testing variational rnn --- test/modules/test_variational_rnn.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 test/modules/test_variational_rnn.py diff --git a/test/modules/test_variational_rnn.py b/test/modules/test_variational_rnn.py new file mode 100644 index 00000000..cd265109 --- /dev/null +++ b/test/modules/test_variational_rnn.py @@ -0,0 +1,28 @@ + +import torch +import unittest + +from fastNLP.modules.encoder.variational_rnn import VarMaskedFastLSTM + +class TestMaskedRnn(unittest.TestCase): + def test_case_1(self): + masked_rnn = VarMaskedFastLSTM(input_size=1, hidden_size=1, bidirectional=True, batch_first=True) + x = torch.tensor([[[1.0], [2.0]]]) + print(x.size()) + y = masked_rnn(x) + mask = torch.tensor([[[1], [1]]]) + y = masked_rnn(x, mask=mask) + mask = torch.tensor([[[1], [0]]]) + y = masked_rnn(x, mask=mask) + + def test_case_2(self): + masked_rnn = VarMaskedFastLSTM(input_size=1, hidden_size=1, bidirectional=False, batch_first=True) + x = torch.tensor([[[1.0], [2.0]]]) + print(x.size()) + y = masked_rnn(x) + mask = torch.tensor([[[1], [1]]]) + y = masked_rnn(x, mask=mask) + xx = torch.tensor([[[1.0]]]) + #y, hidden = masked_rnn.step(xx) + #step() still has a bug + #y, hidden = masked_rnn.step(xx, mask=mask) \ No newline at end of file