Browse Source

fix test

tags/v0.2.0
yunfan 6 years ago
parent
commit
8ea529404e
2 changed files with 7 additions and 18 deletions
  1. +3
    -2
      fastNLP/loader/config_loader.py
  2. +4
    -16
      test/modules/test_variational_rnn.py

+ 3
- 2
fastNLP/loader/config_loader.py View File

@@ -8,9 +8,10 @@ from fastNLP.loader.base_loader import BaseLoader
class ConfigLoader(BaseLoader): class ConfigLoader(BaseLoader):
"""loader for configuration files""" """loader for configuration files"""


def __init__(self, data_path):
def __init__(self, data_path=None):
super(ConfigLoader, self).__init__() super(ConfigLoader, self).__init__()
self.config = self.parse(super(ConfigLoader, self).load(data_path))
if data_path is not None:
self.config = self.parse(super(ConfigLoader, self).load(data_path))


@staticmethod @staticmethod
def parse(string): def parse(string):


+ 4
- 16
test/modules/test_variational_rnn.py View File

@@ -3,35 +3,23 @@ import unittest
import numpy as np import numpy as np
import torch import torch


from fastNLP.modules.encoder.variational_rnn import VarMaskedFastLSTM
from fastNLP.modules.encoder.variational_rnn import VarLSTM




class TestMaskedRnn(unittest.TestCase): class TestMaskedRnn(unittest.TestCase):
def test_case_1(self): def test_case_1(self):
masked_rnn = VarMaskedFastLSTM(input_size=1, hidden_size=1, bidirectional=True, batch_first=True)
masked_rnn = VarLSTM(input_size=1, hidden_size=1, bidirectional=True, batch_first=True)
x = torch.tensor([[[1.0], [2.0]]]) x = torch.tensor([[[1.0], [2.0]]])
print(x.size()) print(x.size())
y = masked_rnn(x) y = masked_rnn(x)
mask = torch.tensor([[[1], [1]]])
y = masked_rnn(x, mask=mask)
mask = torch.tensor([[[1], [0]]])
y = masked_rnn(x, mask=mask)



def test_case_2(self): def test_case_2(self):
input_size = 12 input_size = 12
batch = 16 batch = 16
hidden = 10 hidden = 10
masked_rnn = VarMaskedFastLSTM(input_size=input_size, hidden_size=hidden, bidirectional=False, batch_first=True)

x = torch.randn((batch, input_size))
output, _ = masked_rnn.step(x)
self.assertEqual(tuple(output.shape), (batch, hidden))
masked_rnn = VarLSTM(input_size=input_size, hidden_size=hidden, bidirectional=False, batch_first=True)


xx = torch.randn((batch, 32, input_size)) xx = torch.randn((batch, 32, input_size))
y, _ = masked_rnn(xx) y, _ = masked_rnn(xx)
self.assertEqual(tuple(y.shape), (batch, 32, hidden)) self.assertEqual(tuple(y.shape), (batch, 32, hidden))

xx = torch.randn((batch, 32, input_size))
mask = torch.from_numpy(np.random.randint(0, 2, size=(batch, 32))).to(xx)
y, _ = masked_rnn(xx, mask=mask)
self.assertEqual(tuple(y.shape), (batch, 32, hidden))

Loading…
Cancel
Save