import pytest import numpy as np from fastNLP.envs.imports import _NEED_IMPORT_TORCH if _NEED_IMPORT_TORCH: import torch from fastNLP.modules.torch.encoder.variational_rnn import VarLSTM @pytest.mark.torch class TestMaskedRnn: def test_case_1(self): masked_rnn = VarLSTM(input_size=1, hidden_size=1, bidirectional=True, batch_first=True) x = torch.tensor([[[1.0], [2.0]]]) print(x.size()) y = masked_rnn(x) def test_case_2(self): input_size = 12 batch = 16 hidden = 10 masked_rnn = VarLSTM(input_size=input_size, hidden_size=hidden, bidirectional=False, batch_first=True) xx = torch.randn((batch, 32, input_size)) y, _ = masked_rnn(xx) assert(tuple(y.shape) == (batch, 32, hidden))