diff --git a/fastNLP/loader/config_loader.py b/fastNLP/loader/config_loader.py index 6391ecac..cf3ac1a9 100644 --- a/fastNLP/loader/config_loader.py +++ b/fastNLP/loader/config_loader.py @@ -8,9 +8,10 @@ from fastNLP.loader.base_loader import BaseLoader class ConfigLoader(BaseLoader): """loader for configuration files""" - def __init__(self, data_path): + def __init__(self, data_path=None): super(ConfigLoader, self).__init__() - self.config = self.parse(super(ConfigLoader, self).load(data_path)) + if data_path is not None: + self.config = self.parse(super(ConfigLoader, self).load(data_path)) @staticmethod def parse(string): diff --git a/test/modules/test_variational_rnn.py b/test/modules/test_variational_rnn.py index b182fa1a..c3806f60 100644 --- a/test/modules/test_variational_rnn.py +++ b/test/modules/test_variational_rnn.py @@ -3,35 +3,23 @@ import unittest import numpy as np import torch -from fastNLP.modules.encoder.variational_rnn import VarMaskedFastLSTM +from fastNLP.modules.encoder.variational_rnn import VarLSTM class TestMaskedRnn(unittest.TestCase): def test_case_1(self): - masked_rnn = VarMaskedFastLSTM(input_size=1, hidden_size=1, bidirectional=True, batch_first=True) + masked_rnn = VarLSTM(input_size=1, hidden_size=1, bidirectional=True, batch_first=True) x = torch.tensor([[[1.0], [2.0]]]) print(x.size()) y = masked_rnn(x) - mask = torch.tensor([[[1], [1]]]) - y = masked_rnn(x, mask=mask) - mask = torch.tensor([[[1], [0]]]) - y = masked_rnn(x, mask=mask) + def test_case_2(self): input_size = 12 batch = 16 hidden = 10 - masked_rnn = VarMaskedFastLSTM(input_size=input_size, hidden_size=hidden, bidirectional=False, batch_first=True) - - x = torch.randn((batch, input_size)) - output, _ = masked_rnn.step(x) - self.assertEqual(tuple(output.shape), (batch, hidden)) + masked_rnn = VarLSTM(input_size=input_size, hidden_size=hidden, bidirectional=False, batch_first=True) xx = torch.randn((batch, 32, input_size)) y, _ = masked_rnn(xx) self.assertEqual(tuple(y.shape), (batch, 32, hidden)) - - xx = torch.randn((batch, 32, input_size)) - mask = torch.from_numpy(np.random.randint(0, 2, size=(batch, 32))).to(xx) - y, _ = masked_rnn(xx, mask=mask) - self.assertEqual(tuple(y.shape), (batch, 32, hidden))