@@ -1,13 +1,6 @@
 import _pickle
 import os
-import numpy as np
-
-from fastNLP.core.dataset import DataSet
-from fastNLP.core.field import TextField, LabelField
-from fastNLP.core.instance import Instance
-from fastNLP.core.vocabulary import Vocabulary

 # the first vocab in dict with the index = 5

@@ -53,258 +46,3 @@ def pickle_exist(pickle_path, pickle_name):
         return True
     else:
         return False
-
-
-class Preprocessor(object):
-    """Preprocessors are responsible for converting string data into index data.
-
-    During pre-processing, the following pickle files will be built:
-
-        - "word2id.pkl", a Vocabulary object, mapping words to indices.
-        - "class2id.pkl", a Vocabulary object, mapping labels to indices.
-        - "data_train.pkl", a DataSet object for training
-        - "data_dev.pkl", a DataSet object for validation, if train_dev_split > 0.
-        - "data_test.pkl", a DataSet object for testing, if test_data is not None.
-
-    These pickle files are expected to be saved in the given pickle directory once they are constructed.
-    Preprocessors will check if those files are already in the directory and will reuse them in future calls.
-    """
-
-    def __init__(self, label_is_seq=False, share_vocab=False, add_char_field=False):
-        """
-        :param label_is_seq: bool, whether the label is a sequence. If True, the label vocabulary will reserve
-            several special tokens for sequence processing.
-        :param share_vocab: bool, whether the word sequence and the label sequence share the same vocabulary.
-            Typically, this is only available when label_is_seq is True. Default: False.
-        :param add_char_field: bool, whether to add character representations to all TextFields. Default: False.
-        """
-        print("Preprocessor is about to be deprecated. Please use the DataSet class.")
-        self.data_vocab = Vocabulary()
-        if label_is_seq is True:
-            if share_vocab is True:
-                self.label_vocab = self.data_vocab
-            else:
-                self.label_vocab = Vocabulary()
-        else:
-            self.label_vocab = Vocabulary(need_default=False)
-        self.character_vocab = Vocabulary(need_default=False)
-        self.add_char_field = add_char_field
-
-    @property
-    def vocab_size(self):
-        return len(self.data_vocab)
-
-    @property
-    def num_classes(self):
-        return len(self.label_vocab)
-
-    @property
-    def char_vocab_size(self):
-        if len(self.character_vocab) == 0:
-            # character_vocab is assigned in __init__ and is never None,
-            # so build it lazily when it is still empty
-            self.build_char_dict()
-        return len(self.character_vocab)
-
-    def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=0, cross_val=False, n_fold=10):
-        """Main pre-processing pipeline.
-
-        :param train_dev_data: three-level list, with either a single label or multiple labels in a sample.
-        :param test_data: three-level list, with either a single label or multiple labels in a sample. (optional)
-        :param pickle_path: str, the path to save the pickle files.
-        :param train_dev_split: float, in [0, 1]. The ratio of training data used as the validation set.
-        :param cross_val: bool, whether to do cross validation.
-        :param n_fold: int, the number of folds of cross validation. Only useful when cross_val is True.
-        :return results: multiple datasets after pre-processing. If test_data is provided, return one more dataset.
-            If train_dev_split > 0, return one more dataset - the dev set. If cross_val is True, each dataset
-            is a list of DataSet objects; otherwise, each dataset is a DataSet object.
-        """
-        if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"):
-            self.data_vocab = load_pickle(pickle_path, "word2id.pkl")
-            self.label_vocab = load_pickle(pickle_path, "class2id.pkl")
-        else:
-            self.data_vocab, self.label_vocab = self.build_dict(train_dev_data)
-            save_pickle(self.data_vocab, pickle_path, "word2id.pkl")
-            save_pickle(self.label_vocab, pickle_path, "class2id.pkl")
-        self.build_reverse_dict()
-
-        train_set = []
-        dev_set = []
-        if not cross_val:
-            if not pickle_exist(pickle_path, "data_train.pkl"):
-                if train_dev_split > 0 and not pickle_exist(pickle_path, "data_dev.pkl"):
-                    split = int(len(train_dev_data) * train_dev_split)
-                    data_dev = train_dev_data[:split]
-                    data_train = train_dev_data[split:]
-                    train_set = self.convert_to_dataset(data_train, self.data_vocab, self.label_vocab)
-                    dev_set = self.convert_to_dataset(data_dev, self.data_vocab, self.label_vocab)
-                    save_pickle(dev_set, pickle_path, "data_dev.pkl")
-                    print("{} of the training data is split for validation.".format(train_dev_split))
-                else:
-                    train_set = self.convert_to_dataset(train_dev_data, self.data_vocab, self.label_vocab)
-                save_pickle(train_set, pickle_path, "data_train.pkl")
-            else:
-                train_set = load_pickle(pickle_path, "data_train.pkl")
-                if pickle_exist(pickle_path, "data_dev.pkl"):
-                    dev_set = load_pickle(pickle_path, "data_dev.pkl")
-        else:
-            # cross_val is True
-            if not pickle_exist(pickle_path, "data_train_0.pkl"):
-                # cross validation
-                data_cv = self.cv_split(train_dev_data, n_fold)
-                for i, (data_train_cv, data_dev_cv) in enumerate(data_cv):
-                    data_train_cv = self.convert_to_dataset(data_train_cv, self.data_vocab, self.label_vocab)
-                    data_dev_cv = self.convert_to_dataset(data_dev_cv, self.data_vocab, self.label_vocab)
-                    save_pickle(data_train_cv, pickle_path, "data_train_{}.pkl".format(i))
-                    save_pickle(data_dev_cv, pickle_path, "data_dev_{}.pkl".format(i))
-                    train_set.append(data_train_cv)
-                    dev_set.append(data_dev_cv)
-                print("{}-fold cross validation.".format(n_fold))
-            else:
-                for i in range(n_fold):
-                    data_train_cv = load_pickle(pickle_path, "data_train_{}.pkl".format(i))
-                    data_dev_cv = load_pickle(pickle_path, "data_dev_{}.pkl".format(i))
-                    train_set.append(data_train_cv)
-                    dev_set.append(data_dev_cv)
-
-        # prepare test data if provided
-        test_set = []
-        if test_data is not None:
-            if not pickle_exist(pickle_path, "data_test.pkl"):
-                test_set = self.convert_to_dataset(test_data, self.data_vocab, self.label_vocab)
-                save_pickle(test_set, pickle_path, "data_test.pkl")
-
-        # return pre-processed results
-        results = [train_set]
-        if cross_val or train_dev_split > 0:
-            results.append(dev_set)
-        if test_data:
-            results.append(test_set)
-        if len(results) == 1:
-            return results[0]
-        else:
-            return tuple(results)
-
-    def build_dict(self, data):
-        for example in data:
-            word, label = example
-            self.data_vocab.update(word)
-            self.label_vocab.update(label)
-        return self.data_vocab, self.label_vocab
-
-    def build_char_dict(self):
-        char_collection = set()
-        for word in self.data_vocab.word2idx:
-            if len(word) == 0:
-                continue
-            for ch in word:
-                if ch not in char_collection:
-                    char_collection.add(ch)
-        self.character_vocab.update(list(char_collection))
-
-    def build_reverse_dict(self):
-        self.data_vocab.build_reverse_vocab()
-        self.label_vocab.build_reverse_vocab()
-
-    def data_split(self, data, train_dev_split):
-        """Split data into train and dev set."""
-        split = int(len(data) * train_dev_split)
-        data_dev = data[:split]
-        data_train = data[split:]
-        return data_train, data_dev
-
-    def cv_split(self, data, n_fold):
-        """Split data for cross validation.
-
-        :param data: list of examples to shuffle and split.
-        :param n_fold: int
-        :return data_cv:
-
-            ::
-
-                [
-                    (data_train, data_dev),  # 1st fold
-                    (data_train, data_dev),  # 2nd fold
-                    ...
-                ]
-
-        """
-        data_copy = data.copy()
-        np.random.shuffle(data_copy)
-        fold_size = round(len(data_copy) / n_fold)
-        data_cv = []
-        for i in range(n_fold - 1):
-            start = i * fold_size
-            end = (i + 1) * fold_size
-            data_dev = data_copy[start:end]
-            data_train = data_copy[:start] + data_copy[end:]
-            data_cv.append((data_train, data_dev))
-        start = (n_fold - 1) * fold_size
-        data_dev = data_copy[start:]
-        data_train = data_copy[:start]
-        data_cv.append((data_train, data_dev))
-        return data_cv
-
-    def convert_to_dataset(self, data, vocab, label_vocab):
-        """Convert a list of string examples into a DataSet object.
-
-        :param data: list of examples; each example is a (words, label) pair.
-        :param vocab: a Vocabulary object, mapping each token (str) to an index (int).
-        :param label_vocab: a Vocabulary object, mapping each label (str) to an index (int).
-        :return data_set: a DataSet object
-        """
-        use_word_seq = False
-        use_label_seq = False
-        use_label_str = False
-
-        # construct a DataSet object and fill it with Instances
-        data_set = DataSet()
-        for example in data:
-            words, label = example[0], example[1]
-            instance = Instance()
-            if isinstance(words, list):
-                x = TextField(words, is_target=False)
-                instance.add_field("word_seq", x)
-                use_word_seq = True
-            else:
-                raise NotImplementedError("words is a {}".format(type(words)))
-            if isinstance(label, list):
-                y = TextField(label, is_target=True)
-                instance.add_field("label_seq", y)
-                use_label_seq = True
-            elif isinstance(label, str):
-                y = LabelField(label, is_target=True)
-                instance.add_field("label", y)
-                use_label_str = True
-            else:
-                raise NotImplementedError("label is a {}".format(type(label)))
-            data_set.append(instance)
-
-        # convert strings to indices
-        if use_word_seq:
-            data_set.index_field("word_seq", vocab)
-        if use_label_seq:
-            data_set.index_field("label_seq", label_vocab)
-        if use_label_str:
-            data_set.index_field("label", label_vocab)
-        return data_set
-
-
-class SeqLabelPreprocess(Preprocessor):
-    def __init__(self):
-        print("[FastNLP warning] SeqLabelPreprocess is about to be deprecated. Please use Preprocessor directly.")
-        super(SeqLabelPreprocess, self).__init__()
-
-
-class ClassPreprocess(Preprocessor):
-    def __init__(self):
-        print("[FastNLP warning] ClassPreprocess is about to be deprecated. Please use Preprocessor directly.")
-        super(ClassPreprocess, self).__init__()
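
The run() method removed above is, at its core, a cache-or-compute pattern over pickle files: each artifact (vocabulary, train/dev/test DataSet) is rebuilt only when its pickle is missing. A minimal, self-contained sketch of that pattern, using only the _pickle and os imports kept at the top of this file (cached and build_vocab are illustrative names, not fastNLP API):

    def cached(pickle_path, pickle_name, build):
        # reuse the saved artifact if present; otherwise build it and save it
        file_name = os.path.join(pickle_path, pickle_name)
        if os.path.exists(file_name):
            with open(file_name, "rb") as f:
                return _pickle.load(f)
        obj = build()
        with open(file_name, "wb") as f:
            _pickle.dump(obj, f)
        return obj

    # hypothetical usage mirroring run():
    # vocab = cached("./save/", "word2id.pkl", lambda: build_vocab(train_dev_data))
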
@@ -13,69 +13,3 @@ class BaseModel(torch.nn.Module):
     def fit(self, train_data, dev_data=None, **train_args):
         trainer = Trainer(**train_args)
        trainer.train(self, train_data, dev_data)
-
-
-class Vocabulary(object):
-    """A look-up table that allows you to access `Lexeme` objects. The `Vocab`
-    instance also provides access to the `StringStore`, and owns underlying
-    data that is shared between `Doc` objects.
-    """
-
-    def __init__(self):
-        """Create the vocabulary.
-
-        RETURNS (Vocab): The newly constructed object.
-        """
-        self.data_frame = None
-
-
-class Document(object):
-    """A sequence of Token objects. Access sentences and named entities, export
-    annotations to numpy arrays, losslessly serialize to compressed binary
-    strings. The `Doc` object holds an array of `Token` objects. The
-    Python-level `Token` and `Span` objects are views of this array, i.e.
-    they don't own the data themselves. -- spacy
-    """
-
-    def __init__(self, vocab, words=None, spaces=None):
-        """Create a Doc object.
-
-        vocab (Vocab): A vocabulary object, which must match any models you
-            want to use (e.g. tokenizer, parser, entity recognizer).
-        words (list or None): A list of unicode strings to add to the document
-            as words. If `None`, defaults to an empty list.
-        spaces (list or None): A list of boolean values, of the same length as
-            words. True means that the word is followed by a space, False means
-            it is not. If `None`, defaults to `[True]*len(words)`.
-        RETURNS (Doc): The newly constructed object.
-        """
-        self.vocab = vocab
-        self.words = words if words is not None else []
-        self.spaces = spaces
-        if spaces is None:
-            self.spaces = [True] * len(self.words)
-        elif len(spaces) != len(self.words):
-            raise ValueError("mismatched spaces and words")
-
-    def get_chunker(self, vocab):
-        return None
-
-    def push_back(self, vocab):
-        pass
-
-
-class Token(object):
-    """An individual token – i.e. a word, punctuation symbol, whitespace,
-    etc.
-    """
-
-    def __init__(self, vocab, doc, offset):
-        """Construct a `Token` object.
-
-        vocab (Vocabulary): A storage container for lexical types.
-        doc (Document): The parent document.
-        offset (int): The index of the token within the document.
-        """
-        self.vocab = vocab
-        self.doc = doc
-        self.token = doc.words[offset]
-        self.i = offset
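
Document above pairs each word with a boolean "followed by a space" flag, the spaCy convention that lets the surface text be reconstructed losslessly. A small sketch of that invariant; detokenize is an illustrative helper, not a method of the removed class:

    def detokenize(words, spaces=None):
        # spaces defaults to "every word is followed by a space", as in Document.__init__
        if spaces is None:
            spaces = [True] * len(words)
        if len(spaces) != len(words):
            raise ValueError("mismatched spaces and words")
        return "".join(w + (" " if s else "") for w, s in zip(words, spaces))

    # detokenize(["Hello", "world", "!"], [True, False, False]) == "Hello world!"
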
@@ -103,7 +103,7 @@ class CharLM(nn.Module):
         x = x.contiguous().view(lstm_batch_size, lstm_seq_len, -1)
         # [num_seq, seq_len, total_num_filters]

-        x, hidden = self.lstm(x)
+        x = self.lstm(x)
         # [seq_len, num_seq, hidden_size]
         x = self.dropout(x)
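
This change only type-checks if self.lstm returns a bare tensor: a plain torch.nn.LSTM returns an (output, (h_n, c_n)) tuple, and feeding that tuple into self.dropout would fail. Presumably self.lstm is a wrapper that keeps only the output; a sketch of such a wrapper, assuming batch_first inputs (this is not the actual fastNLP module):

    import torch.nn as nn

    class LSTMOutputOnly(nn.Module):
        # wraps nn.LSTM so that forward() returns only the output tensor
        def __init__(self, input_size, hidden_size):
            super(LSTMOutputOnly, self).__init__()
            self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)

        def forward(self, x):
            output, _ = self.lstm(x)  # discard (h_n, c_n)
            return output
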
@@ -1,12 +1,14 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
-# from torch.nn.init import xavier_uniform
 from fastNLP.modules.utils import initial_parameter
+# from torch.nn.init import xavier_uniform


 class ConvCharEmbedding(nn.Module):
-    def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5),initial_method = None):
+    def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5), initial_method=None):
         """
         Character Level Word Embedding
         :param char_emb_size: the size of character level embedding. Default: 50
@@ -21,7 +23,7 @@ class ConvCharEmbedding(nn.Module):
             nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True, padding=(0, 4))
             for i in range(len(kernels))])
-        initial_parameter(self,initial_method)
+        initial_parameter(self, initial_method)

     def forward(self, x):
         """
@@ -56,7 +58,7 @@ class LSTMCharEmbedding(nn.Module):
     :param hidden_size: int, the number of hidden units. Default: equal to char_emb_size.
     """

-    def __init__(self, char_emb_size=50, hidden_size=None , initial_method= None):
+    def __init__(self, char_emb_size=50, hidden_size=None, initial_method=None):
         super(LSTMCharEmbedding, self).__init__()
         self.hidden_size = char_emb_size if hidden_size is None else hidden_size

@@ -66,6 +68,7 @@ class LSTMCharEmbedding(nn.Module):
                             bias=True,
                             batch_first=True)
         initial_parameter(self, initial_method)
+
     def forward(self, x):
         """
         :param x: [n_batch*n_word, word_length, char_emb_size]
@@ -79,20 +82,3 @@ class LSTMCharEmbedding(nn.Module):
         _, hidden = self.lstm(x, (h0, c0))
         return hidden[0].squeeze().unsqueeze(2)
-
-
-if __name__ == "__main__":
-    batch_size = 128
-    char_emb = 100
-    word_length = 1
-    x = torch.Tensor(batch_size, char_emb, word_length)
-    x = x.transpose(1, 2)
-    cce = ConvCharEmbedding(char_emb)
-    y = cce(x)
-    print("CNN Char Emb input: ", x.shape)
-    print("CNN Char Emb output: ", y.shape)  # [128, 100, 1]
-    lce = LSTMCharEmbedding(char_emb)
-    o = lce(x)
-    print("LSTM Char Emb input: ", x.shape)
-    print("LSTM Char Emb size: ", o.shape)
@@ -1,24 +1,8 @@
-from fastNLP.core.loss import Loss
-from fastNLP.core.preprocess import Preprocessor
-from fastNLP.core.trainer import Trainer
-from fastNLP.loader.dataset_loader import LMDataSetLoader
-from fastNLP.models.char_language_model import CharLM
-
 PICKLE = "./save/"


 def train():
-    loader = LMDataSetLoader()
-    train_data = loader.load()
-
-    pre = Preprocessor(label_is_seq=True, share_vocab=True)
-    train_set = pre.run(train_data, pickle_path=PICKLE)
-
-    model = CharLM(50, 50, pre.vocab_size, pre.char_vocab_size)
-
-    trainer = Trainer(task="language_model", loss=Loss("cross_entropy"))
-    trainer.train(model, train_set)
+    pass


 if __name__ == "__main__":
@@ -1,72 +0,0 @@
-import os
-import unittest
-
-from fastNLP.core.dataset import DataSet
-from fastNLP.core.preprocess import SeqLabelPreprocess
-
-data = [
-    [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
-    [['Hello', 'world', '!'], ['a', 'n', '.']],
-    [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
-    [['Hello', 'world', '!'], ['a', 'n', '.']],
-    [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
-    [['Hello', 'world', '!'], ['a', 'n', '.']],
-    [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
-    [['Hello', 'world', '!'], ['a', 'n', '.']],
-    [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
-    [['Hello', 'world', '!'], ['a', 'n', '.']],
-]
-
-
-class TestCase1(unittest.TestCase):
-    def test(self):
-        if os.path.exists("./save"):
-            for root, dirs, files in os.walk("./save", topdown=False):
-                for name in files:
-                    os.remove(os.path.join(root, name))
-                for name in dirs:
-                    os.rmdir(os.path.join(root, name))
-        result = SeqLabelPreprocess().run(train_dev_data=data, train_dev_split=0.4,
-                                          pickle_path="./save")
-        self.assertEqual(len(result), 2)
-        self.assertEqual(type(result[0]), DataSet)
-        self.assertEqual(type(result[1]), DataSet)
-        os.system("rm -rf save")
-        print("pickle path deleted")
-
-
-class TestCase2(unittest.TestCase):
-    def test(self):
-        if os.path.exists("./save"):
-            for root, dirs, files in os.walk("./save", topdown=False):
-                for name in files:
-                    os.remove(os.path.join(root, name))
-                for name in dirs:
-                    os.rmdir(os.path.join(root, name))
-        result = SeqLabelPreprocess().run(test_data=data, train_dev_data=data,
-                                          pickle_path="./save", train_dev_split=0.4,
-                                          cross_val=False)
-        self.assertEqual(len(result), 3)
-        self.assertEqual(type(result[0]), DataSet)
-        self.assertEqual(type(result[1]), DataSet)
-        self.assertEqual(type(result[2]), DataSet)
-        os.system("rm -rf save")
-        print("pickle path deleted")
-
-
-class TestCase3(unittest.TestCase):
-    def test(self):
-        num_folds = 2
-        result = SeqLabelPreprocess().run(test_data=None, train_dev_data=data,
-                                          pickle_path="./save", train_dev_split=0.4,
-                                          cross_val=True, n_fold=num_folds)
-        self.assertEqual(len(result), 2)
-        self.assertEqual(len(result[0]), num_folds)
-        self.assertEqual(len(result[1]), num_folds)
-        for data_set in result[0] + result[1]:
-            self.assertEqual(type(data_set), DataSet)
-        os.system("rm -rf save")
-        print("pickle path deleted")
@@ -0,0 +1,25 @@
+import unittest
+
+import numpy as np
+import torch
+
+from fastNLP.models.char_language_model import CharLM
+
+
+class TestCharLM(unittest.TestCase):
+    def test_case_1(self):
+        char_emb_dim = 50
+        word_emb_dim = 50
+        vocab_size = 1000
+        num_char = 24
+        max_word_len = 21
+        num_seq = 64
+        seq_len = 32
+
+        model = CharLM(char_emb_dim, word_emb_dim, vocab_size, num_char)
+
+        x = torch.from_numpy(np.random.randint(0, num_char, size=(num_seq, seq_len, max_word_len + 2)))
+        self.assertEqual(tuple(x.shape), (num_seq, seq_len, max_word_len + 2))
+
+        y = model(x)
+        self.assertEqual(tuple(y.shape), (num_seq * seq_len, vocab_size))
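
The asserted (num_seq * seq_len, vocab_size) output is the flattened layout cross-entropy expects when the targets are flattened the same way; the max_word_len + 2 on the input presumably leaves room for begin- and end-of-word characters. A sketch of how such logits would be consumed (the target tensor here is hypothetical):

    import torch
    import torch.nn.functional as F

    num_seq, seq_len, vocab_size = 64, 32, 1000
    logits = torch.randn(num_seq * seq_len, vocab_size)         # model output
    targets = torch.randint(0, vocab_size, (num_seq, seq_len))  # next-word ids
    loss = F.cross_entropy(logits, targets.view(-1))
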
@@ -0,0 +1,28 @@
+import unittest
+
+import torch
+
+from fastNLP.modules.encoder.char_embedding import ConvCharEmbedding, LSTMCharEmbedding
+
+
+class TestCharEmbed(unittest.TestCase):
+    def test_case_1(self):
+        batch_size = 128
+        char_emb = 100
+        word_length = 1
+
+        x = torch.Tensor(batch_size, char_emb, word_length)
+        x = x.transpose(1, 2)
+
+        cce = ConvCharEmbedding(char_emb)
+        y = cce(x)
+        self.assertEqual(tuple(x.shape), (batch_size, word_length, char_emb))
+        print("CNN Char Emb input: ", x.shape)
+        self.assertEqual(tuple(y.shape), (batch_size, char_emb, 1))
+        print("CNN Char Emb output: ", y.shape)  # [128, 100, 1]
+
+        lce = LSTMCharEmbedding(char_emb)
+        o = lce(x)
+        self.assertEqual(tuple(x.shape), (batch_size, word_length, char_emb))
+        print("LSTM Char Emb input: ", x.shape)
+        self.assertEqual(tuple(o.shape), (batch_size, char_emb, 1))
+        print("LSTM Char Emb size: ", o.shape)
@@ -1,9 +1,11 @@
+import unittest
+
+import numpy as np
 import torch
-import unittest

 from fastNLP.modules.encoder.variational_rnn import VarMaskedFastLSTM


 class TestMaskedRnn(unittest.TestCase):
     def test_case_1(self):
         masked_rnn = VarMaskedFastLSTM(input_size=1, hidden_size=1, bidirectional=True, batch_first=True)

@@ -16,13 +18,20 @@ class TestMaskedRnn(unittest.TestCase):
         y = masked_rnn(x, mask=mask)

     def test_case_2(self):
-        masked_rnn = VarMaskedFastLSTM(input_size=1, hidden_size=1, bidirectional=False, batch_first=True)
-        x = torch.tensor([[[1.0], [2.0]]])
-        print(x.size())
-        y = masked_rnn(x)
-        mask = torch.tensor([[[1], [1]]])
-        y = masked_rnn(x, mask=mask)
-        xx = torch.tensor([[[1.0]]])
-        # y, hidden = masked_rnn.step(xx)
-        # step() still has a bug
-        # y, hidden = masked_rnn.step(xx, mask=mask)
+        input_size = 12
+        batch = 16
+        hidden = 10
+        masked_rnn = VarMaskedFastLSTM(input_size=input_size, hidden_size=hidden, bidirectional=False, batch_first=True)
+        x = torch.randn((batch, input_size))
+        output, _ = masked_rnn.step(x)
+        self.assertEqual(tuple(output.shape), (batch, hidden))
+        xx = torch.randn((batch, 32, input_size))
+        y, _ = masked_rnn(xx)
+        self.assertEqual(tuple(y.shape), (batch, 32, hidden))
+        xx = torch.randn((batch, 32, input_size))
+        mask = torch.from_numpy(np.random.randint(0, 2, size=(batch, 32))).to(xx)
+        y, _ = masked_rnn(xx, mask=mask)
+        self.assertEqual(tuple(y.shape), (batch, 32, hidden))
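
One idiom worth noting above: calling tensor.to(other) with another tensor converts both dtype and device to match other, so the integer mask produced by np.random.randint ends up as a float tensor on the same device as xx. The explicit equivalent:

    import numpy as np
    import torch

    xx = torch.randn(16, 32, 12)
    mask = torch.from_numpy(np.random.randint(0, 2, size=(16, 32)))
    mask = mask.to(dtype=xx.dtype, device=xx.device)  # same as mask.to(xx)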