Browse Source

fix test

tags/v0.2.0
yunfan 6 years ago
parent
commit
8ea529404e
2 changed files with 7 additions and 18 deletions
  1. +3
    -2
      fastNLP/loader/config_loader.py
  2. +4
    -16
      test/modules/test_variational_rnn.py

+ 3
- 2
fastNLP/loader/config_loader.py View File

@@ -8,9 +8,10 @@ from fastNLP.loader.base_loader import BaseLoader
class ConfigLoader(BaseLoader):
"""loader for configuration files"""

def __init__(self, data_path):
def __init__(self, data_path=None):
super(ConfigLoader, self).__init__()
self.config = self.parse(super(ConfigLoader, self).load(data_path))
if data_path is not None:
self.config = self.parse(super(ConfigLoader, self).load(data_path))

@staticmethod
def parse(string):


+ 4
- 16
test/modules/test_variational_rnn.py View File

@@ -3,35 +3,23 @@ import unittest
import numpy as np
import torch

from fastNLP.modules.encoder.variational_rnn import VarMaskedFastLSTM
from fastNLP.modules.encoder.variational_rnn import VarLSTM


class TestMaskedRnn(unittest.TestCase):
def test_case_1(self):
masked_rnn = VarMaskedFastLSTM(input_size=1, hidden_size=1, bidirectional=True, batch_first=True)
masked_rnn = VarLSTM(input_size=1, hidden_size=1, bidirectional=True, batch_first=True)
x = torch.tensor([[[1.0], [2.0]]])
print(x.size())
y = masked_rnn(x)
mask = torch.tensor([[[1], [1]]])
y = masked_rnn(x, mask=mask)
mask = torch.tensor([[[1], [0]]])
y = masked_rnn(x, mask=mask)


def test_case_2(self):
input_size = 12
batch = 16
hidden = 10
masked_rnn = VarMaskedFastLSTM(input_size=input_size, hidden_size=hidden, bidirectional=False, batch_first=True)

x = torch.randn((batch, input_size))
output, _ = masked_rnn.step(x)
self.assertEqual(tuple(output.shape), (batch, hidden))
masked_rnn = VarLSTM(input_size=input_size, hidden_size=hidden, bidirectional=False, batch_first=True)

xx = torch.randn((batch, 32, input_size))
y, _ = masked_rnn(xx)
self.assertEqual(tuple(y.shape), (batch, 32, hidden))

xx = torch.randn((batch, 32, input_size))
mask = torch.from_numpy(np.random.randint(0, 2, size=(batch, 32))).to(xx)
y, _ = masked_rnn(xx, mask=mask)
self.assertEqual(tuple(y.shape), (batch, 32, hidden))

Loading…
Cancel
Save