|
-
- import torch
- import unittest
-
- from fastNLP.modules.encoder.variational_rnn import VarMaskedFastLSTM
-
class TestMaskedRnn(unittest.TestCase):
    """Smoke tests for ``VarMaskedFastLSTM``.

    Each test only checks that the forward pass runs without raising, with
    and without a ``mask`` argument; no assertions are made about the
    returned values.
    """

    @staticmethod
    def _make_rnn(bidirectional):
        """Build a minimal batch-first ``VarMaskedFastLSTM`` (1 input, 1 hidden unit)."""
        return VarMaskedFastLSTM(
            input_size=1,
            hidden_size=1,
            bidirectional=bidirectional,
            batch_first=True,
        )

    def test_case_1(self):
        """Bidirectional forward: no mask, an all-ones mask, and a partial mask."""
        masked_rnn = self._make_rnn(bidirectional=True)
        # Literal of shape (1, 2, 1); with batch_first=True this is presumably
        # (batch, seq, feature) -- one sequence of two steps, one feature.
        x = torch.tensor([[[1.0], [2.0]]])
        masked_rnn(x)
        masked_rnn(x, mask=torch.tensor([[[1], [1]]]))
        # Mask value 0 on the second step (presumably excluding that step).
        masked_rnn(x, mask=torch.tensor([[[1], [0]]]))

    def test_case_2(self):
        """Unidirectional forward: no mask and an all-ones mask."""
        masked_rnn = self._make_rnn(bidirectional=False)
        x = torch.tensor([[[1.0], [2.0]]])
        masked_rnn(x)
        masked_rnn(x, mask=torch.tensor([[[1], [1]]]))
        # TODO: add coverage for VarMaskedFastLSTM.step() (single-step input of
        # shape (1, 1, 1), with and without a mask) once its known bug is fixed.
|