You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_variational_rnn.py 792 B

123456789101112131415161718192021222324252627
  1. import pytest
  2. import numpy as np
  3. from fastNLP.envs.imports import _NEED_IMPORT_TORCH
  4. if _NEED_IMPORT_TORCH:
  5. import torch
  6. from fastNLP.modules.torch.encoder.variational_rnn import VarLSTM
  7. @pytest.mark.torch
  8. class TestMaskedRnn:
  9. def test_case_1(self):
  10. masked_rnn = VarLSTM(input_size=1, hidden_size=1, bidirectional=True, batch_first=True)
  11. x = torch.tensor([[[1.0], [2.0]]])
  12. print(x.size())
  13. y = masked_rnn(x)
  14. def test_case_2(self):
  15. input_size = 12
  16. batch = 16
  17. hidden = 10
  18. masked_rnn = VarLSTM(input_size=input_size, hidden_size=hidden, bidirectional=False, batch_first=True)
  19. xx = torch.randn((batch, 32, input_size))
  20. y, _ = masked_rnn(xx)
  21. assert(tuple(y.shape) == (batch, 32, hidden))