Browse Source

add test code for testing masked rnn

tags/v0.1.0
xuyige 6 years ago
parent
commit
2bc54c6d17
1 changed files with 27 additions and 0 deletions
  1. +27
    -0
      test/modules/test_masked_rnn.py

+ 27
- 0
test/modules/test_masked_rnn.py View File

@@ -0,0 +1,27 @@

import torch
import unittest

from fastNLP.modules.encoder.masked_rnn import MaskedRNN

class TestMaskedRnn(unittest.TestCase):
def test_case_1(self):
masked_rnn = MaskedRNN(input_size=1, hidden_size=1, bidirectional=True, batch_first=True)
x = torch.tensor([[[1.0], [2.0]]])
print(x.size())
y = masked_rnn(x)
mask = torch.tensor([[[1], [1]]])
y = masked_rnn(x, mask=mask)
mask = torch.tensor([[[1], [0]]])
y = masked_rnn(x, mask=mask)

def test_case_2(self):
masked_rnn = MaskedRNN(input_size=1, hidden_size=1, bidirectional=False, batch_first=True)
x = torch.tensor([[[1.0], [2.0]]])
print(x.size())
y = masked_rnn(x)
mask = torch.tensor([[[1], [1]]])
y = masked_rnn(x, mask=mask)
xx = torch.tensor([[[1.0]]])
y = masked_rnn.step(xx)
y = masked_rnn.step(xx, mask=mask)

Loading…
Cancel
Save