diff --git a/fastNLP/action/inference.py b/fastNLP/action/inference.py
index 94e5be19..1a1c4d2c 100644
--- a/fastNLP/action/inference.py
+++ b/fastNLP/action/inference.py
@@ -3,9 +3,6 @@ class Inference(object):
     This is an interface focusing on predicting output based on trained models.
     It does not care about evaluations of the model.
 
-    Possible improvements:
-        - use batch to make use of GPU
-
     """
 
     def __init__(self):
diff --git a/fastNLP/action/metrics.py b/fastNLP/action/metrics.py
new file mode 100644
index 00000000..7c8a6bec
--- /dev/null
+++ b/fastNLP/action/metrics.py
@@ -0,0 +1,8 @@
+"""
+To do:
+    Design the various metrics for evaluating results. Use numpy whenever vectors are involved.
+    Refer to http://scikit-learn.org/stable/modules/classes.html#classification-metrics
+    It is suggested that each metric be written as a separate function (called by Tester's evaluate method).
+    Only the basic parameters need to be considered; there is no need for as many configuration options as scikit-learn provides.
+
+"""
diff --git a/fastNLP/action/optimizor.py b/fastNLP/action/optimizor.py
index 89a6c7eb..becdc499 100644
--- a/fastNLP/action/optimizor.py
+++ b/fastNLP/action/optimizor.py
@@ -1,15 +1,16 @@
 from torch import optim
 
-def get_torch_optimizor(params, alg_name='sgd', **args):
-    '''
-    construct pytorch optimizor by algorithm's name
-    optimizor's argurments can be splicified, for different optimizor's argurments, please see pytorch doc
+
+def get_torch_optimizer(params, alg_name='sgd', **args):
+    """
+    construct PyTorch optimizer by algorithm's name
+    optimizer's arguments can be specified, for different optimizer's arguments, please see PyTorch doc
 
     usage:
-        optimizor = get_torch_optimizor(model.parameters(), 'SGD', lr=0.01)
+        optimizer = get_torch_optimizer(model.parameters(), 'SGD', lr=0.01)
+
+    """
 
-    '''
-
     name = alg_name.lower()
     if name == 'adadelta':
         return optim.Adadelta(params, **args)
@@ -28,22 +29,22 @@ def get_torch_optimizor(params, alg_name='sgd', **args):
     elif name == 'rprop':
         return optim.Rprop(params, **args)
     elif name == 'sgd':
-        #SGD's parameter lr is required
+        # SGD's parameter lr is required
        if 'lr' not in args:
            args['lr'] = 0.01
        return optim.SGD(params, **args)
     elif name == 'sparseadam':
         return optim.SparseAdam(params, **args)
     else:
-        raise TypeError('no such optimizor named {}'.format(alg_name))
+        raise TypeError('no such optimizer named {}'.format(alg_name))
 
-# example usage
 if __name__ == '__main__':
     from torch.nn.modules import Linear
+
     net = Linear(2, 5)
 
-    test1 = get_torch_optimizor(net.parameters(),'adam', lr=1e-2, weight_decay=1e-3)
+    test1 = get_torch_optimizer(net.parameters(), 'adam', lr=1e-2, weight_decay=1e-3)
     print(test1)
 
-    test2 = get_torch_optimizor(net.parameters(), 'SGD')
-    print(test2)
\ No newline at end of file
+    test2 = get_torch_optimizer(net.parameters(), 'SGD')
+    print(test2)
diff --git a/fastNLP/action/tester.py b/fastNLP/action/tester.py
index e4cca9e5..6ba8c9fa 100644
--- a/fastNLP/action/tester.py
+++ b/fastNLP/action/tester.py
@@ -1,8 +1,8 @@
 import _pickle
+import os
 
 import numpy as np
 import torch
-import os
 
 from fastNLP.action.action import Action
 from fastNLP.action.action import RandomSampler, Batchifier
@@ -108,7 +108,7 @@ class BaseTester(Action):
         raise NotImplementedError
 
     @property
-    def matrices(self):
+    def metrics(self):
         raise NotImplementedError
 
     def mode(self, model, test=True):
@@ -163,7 +163,7 @@ class POSTester(BaseTester):
         accuracy = float(torch.sum(results == truth.view((-1,)))) / results.shape[0]
         return [loss.data, accuracy]
 
-    def matrices(self):
+    def metrics(self):
         batch_loss = np.mean([x[0] for x in self.eval_history])
         batch_accuracy = np.mean([x[1] for x in self.eval_history])
         return batch_loss, batch_accuracy
@@ -173,7 +173,7 @@ class
POSTester(BaseTester): This is called by Trainer to print evaluation on dev set. :return print_str: str """ - loss, accuracy = self.matrices() + loss, accuracy = self.metrics() return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy) @@ -309,7 +309,7 @@ class ClassTester(BaseTester): y_prob = torch.nn.functional.softmax(y_logit, dim=-1) return [y_prob, y_true] - def matrices(self): + def metrics(self): """Compute accuracy.""" y_prob, y_true = zip(*self.eval_history) y_prob = torch.cat(y_prob, dim=0) diff --git a/fastNLP/action/trainer.py b/fastNLP/action/trainer.py index 45049e9d..419fceca 100644 --- a/fastNLP/action/trainer.py +++ b/fastNLP/action/trainer.py @@ -181,7 +181,7 @@ class BaseTrainer(Action): """ raise NotImplementedError - def batchify(self, data): + def batchify(self, data, output_length=True): """ 1. Perform batching from data and produce a batch of training data. 2. Add padding. @@ -194,13 +194,18 @@ class BaseTrainer(Action): ] :return batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len] batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels] + seq_len: list. The length of the pre-padded sequence, if output_length is True. """ indices = next(self.iterator) batch = [data[idx] for idx in indices] batch_x = [sample[0] for sample in batch] batch_y = [sample[1] for sample in batch] - batch_x = self.pad(batch_x) - return batch_x, batch_y + batch_x_pad = self.pad(batch_x) + if output_length: + seq_len = [len(x) for x in batch_x] + return batch_x_pad, batch_y, seq_len + else: + return batch_x_pad, batch_y @staticmethod def pad(batch, fill=0): @@ -245,7 +250,10 @@ class ToyTrainer(BaseTrainer): return data_train, data_dev, 0, 1 def mode(self, test=False): - self.model.mode(test) + if test: + self.model.eval() + else: + self.model.train() def data_forward(self, network, x): return network(x) @@ -333,7 +341,7 @@ class POSTrainer(BaseTrainer): return loss def best_eval_result(self, validator): - loss, accuracy = validator.matrices() + loss, accuracy = validator.metrics() if accuracy > self.best_accuracy: self.best_accuracy = accuracy return True diff --git a/fastNLP/loader/dataset_loader.py b/fastNLP/loader/dataset_loader.py index 65daafed..dc5640f1 100644 --- a/fastNLP/loader/dataset_loader.py +++ b/fastNLP/loader/dataset_loader.py @@ -11,8 +11,24 @@ class DatasetLoader(BaseLoader): class POSDatasetLoader(DatasetLoader): - """loader for pos data sets""" - + """Dataset Loader for POS Tag datasets. + + In these datasets, each line are divided by '\t' + while the first Col is the vocabulary and the second + Col is the label. + Different sentence are divided by an empty line. + e.g: + Tom label1 + and label2 + Jerry label1 + . label3 + Hello label4 + world label5 + ! label3 + In this file, there are two sentence "Tom and Jerry ." + and "Hello world !". Each word has its own label from label1 + to label5. + """ def __init__(self, data_name, data_path): super(POSDatasetLoader, self).__init__(data_name, data_path) @@ -23,10 +39,42 @@ class POSDatasetLoader(DatasetLoader): return line def load_lines(self): - assert (os.path.exists(self.data_path)) + """ + :return data: three-level list + [ + [ [word_11, word_12, ...], [label_1, label_1, ...] ], + [ [word_21, word_22, ...], [label_2, label_1, ...] ], + ... 
+ ] + """ with open(self.data_path, "r", encoding="utf-8") as f: lines = f.readlines() - return lines + return self.parse(lines) + + @staticmethod + def parse(lines): + data = [] + sentence = [] + for line in lines: + line = line.strip() + if len(line) > 1: + sentence.append(line.split('\t')) + else: + words = [] + labels = [] + for tokens in sentence: + words.append(tokens[0]) + labels.append(tokens[1]) + data.append([words, labels]) + sentence = [] + if len(sentence) != 0: + words = [] + labels = [] + for tokens in sentence: + words.append(tokens[0]) + labels.append(tokens[1]) + data.append([words, labels]) + return data class ClassDatasetLoader(DatasetLoader): @@ -112,3 +160,10 @@ class LMDatasetLoader(DatasetLoader): with open(self.data_path, "r", encoding="utf=8") as f: text = " ".join(f.readlines()) return text.strip().split() + + +if __name__ == "__main__": + data = POSDatasetLoader("xxx", "../../test/data_for_tests/people.txt").load_lines() + for example in data: + for w, l in zip(example[0], example[1]): + print(w, l) diff --git a/fastNLP/loader/preprocess.py b/fastNLP/loader/preprocess.py index ee5524d3..7cd91f9c 100644 --- a/fastNLP/loader/preprocess.py +++ b/fastNLP/loader/preprocess.py @@ -7,6 +7,10 @@ DEFAULT_RESERVED_LABEL = ['', '', ''] # dict index = 2~4 +DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1, + DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3, + DEFAULT_RESERVED_LABEL[2]: 4} + # the first vocab in dict with the index = 5 @@ -24,69 +28,86 @@ class BasePreprocess(object): class POSPreprocess(BasePreprocess): """ This class are used to preprocess the pos datasets. - In these datasets, each line are divided by '\t' - while the first Col is the vocabulary and the second - Col is the label. - Different sentence are divided by an empty line. - e.g: - Tom label1 - and label2 - Jerry label1 - . label3 - Hello label4 - world label5 - ! label3 - In this file, there are two sentence "Tom and Jerry ." - and "Hello world !". Each word has its own label from label1 - to label5. - """ + """ - def __init__(self, data, pickle_path): + def __init__(self, data, pickle_path="./", train_dev_split=0): + """ + Preprocess pipeline, including building mapping from words to index, from index to words, + from labels/classes to index, from index to labels/classes. + :param data: three-level list + [ + [ [word_11, word_12, ...], [label_1, label_1, ...] ], + [ [word_21, word_22, ...], [label_2, label_1, ...] ], + ... + ] + :param pickle_path: str, the directory to the pickle files. Default: "./" + :param train_dev_split: float in [0, 1]. The ratio of dev data split from training data. Default: 0. + + To do: + 1. 
simplify __init__
+        """
         super(POSPreprocess, self).__init__(data, pickle_path)
-        self.word_dict = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
-                          DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
-                          DEFAULT_RESERVED_LABEL[2]: 4}
-        self.label_dict = None
-        self.data = data
         self.pickle_path = pickle_path
-        self.build_dict(data)
-        if not self.pickle_exist("word2id.pkl"):
-            self.word_dict.update(self.word2id(data))
-            file_name = os.path.join(self.pickle_path, "word2id.pkl")
-            with open(file_name, "wb") as f:
-                _pickle.dump(self.word_dict, f)
+        if self.pickle_exist("word2id.pkl"):
+            # load word2index because the construction of the following objects needs it
+            with open(os.path.join(self.pickle_path, "word2id.pkl"), "rb") as f:
+                self.word2index = _pickle.load(f)
+        else:
+            self.word2index, self.label2index = self.build_dict(data)
+            with open(os.path.join(self.pickle_path, "word2id.pkl"), "wb") as f:
+                _pickle.dump(self.word2index, f)
 
-        self.vocab_size = self.id2word()
-        self.class2id()
-        self.num_classes = self.id2class()
-        self.embedding()
-        self.data_train()
-        self.data_dev()
-        self.data_test()
+        if self.pickle_exist("class2id.pkl"):
+            with open(os.path.join(self.pickle_path, "class2id.pkl"), "rb") as f:
+                self.label2index = _pickle.load(f)
+        else:
+            with open(os.path.join(self.pickle_path, "class2id.pkl"), "wb") as f:
+                _pickle.dump(self.label2index, f)
+            # NOTE: if word2id.pkl is found but class2id.pkl is not, self.label2index was never
+            # built above and this dump will fail; the two pickles are expected to exist together.
+
+        if not self.pickle_exist("id2word.pkl"):
+            index2word = self.build_reverse_dict(self.word2index)
+            with open(os.path.join(self.pickle_path, "id2word.pkl"), "wb") as f:
+                _pickle.dump(index2word, f)
+
+        if not self.pickle_exist("id2class.pkl"):
+            index2label = self.build_reverse_dict(self.label2index)
+            with open(os.path.join(self.pickle_path, "id2class.pkl"), "wb") as f:
+                _pickle.dump(index2label, f)
+
+        if not self.pickle_exist("data_train.pkl"):
+            data_train = self.to_index(data)
+            if train_dev_split > 0 and not self.pickle_exist("data_dev.pkl"):
+                data_dev = data_train[: int(len(data_train) * train_dev_split)]
+                with open(os.path.join(self.pickle_path, "data_dev.pkl"), "wb") as f:
+                    _pickle.dump(data_dev, f)
+            with open(os.path.join(self.pickle_path, "data_train.pkl"), "wb") as f:
+                _pickle.dump(data_train, f)
 
     def build_dict(self, data):
         """
         Add new words with indices into self.word_dict, new labels with indices into self.label_dict.
-        :param data: list of list [word, label]
+        :param data: three-level list
+        [
+            [ [word_11, word_12, ...], [label_1, label_1, ...] ],
+            [ [word_21, word_22, ...], [label_2, label_1, ...] ],
+            ...
+ ] + :return word2index: dict of {str, int} + label2index: dict of {str, int} """ - - self.label_dict = {} - for line in data: - line = line.strip() - if len(line) <= 1: - continue - tokens = line.split('\t') - - if tokens[0] not in self.word_dict: - # add (word, index) into the dict - self.word_dict[tokens[0]] = len(self.word_dict) - - # for label in tokens[1: ]: - if tokens[1] not in self.label_dict: - self.label_dict[tokens[1]] = len(self.label_dict) + label2index = {} + word2index = DEFAULT_WORD_TO_INDEX + for example in data: + for word, label in zip(example[0], example[1]): + if word not in word2index: + word2index[word] = len(word2index) + if label not in label2index: + label2index[label] = len(label2index) + return word2index, label2index def pickle_exist(self, pickle_name): """ @@ -101,90 +122,38 @@ class POSPreprocess(BasePreprocess): else: return False - def word2id(self): - if self.pickle_exist("word2id.pkl"): - return - # nothing will be done if word2id.pkl exists - - file_name = os.path.join(self.pickle_path, "word2id.pkl") - with open(file_name, "wb") as f: - _pickle.dump(self.word_dict, f) - - def id2word(self): - if self.pickle_exist("id2word.pkl"): - file_name = os.path.join(self.pickle_path, "id2word.pkl") - id2word_dict = _pickle.load(open(file_name, "rb")) - return len(id2word_dict) - # nothing will be done if id2word.pkl exists + def build_reverse_dict(self, word_dict): + id2word = {word_dict[w]: w for w in word_dict} + return id2word - id2word_dict = {} - for word in self.word_dict: - id2word_dict[self.word_dict[word]] = word - file_name = os.path.join(self.pickle_path, "id2word.pkl") - with open(file_name, "wb") as f: - _pickle.dump(id2word_dict, f) - return len(id2word_dict) - - def class2id(self): - if self.pickle_exist("class2id.pkl"): - return - # nothing will be done if class2id.pkl exists - - file_name = os.path.join(self.pickle_path, "class2id.pkl") - with open(file_name, "wb") as f: - _pickle.dump(self.label_dict, f) - - def id2class(self): - if self.pickle_exist("id2class.pkl"): - file_name = os.path.join(self.pickle_path, "id2class.pkl") - id2class_dict = _pickle.load(open(file_name, "rb")) - return len(id2class_dict) - # nothing will be done if id2class.pkl exists - - id2class_dict = {} - for label in self.label_dict: - id2class_dict[self.label_dict[label]] = label - file_name = os.path.join(self.pickle_path, "id2class.pkl") - with open(file_name, "wb") as f: - _pickle.dump(id2class_dict, f) - return len(id2class_dict) - - def embedding(self): - if self.pickle_exist("embedding.pkl"): - return - # nothing will be done if embedding.pkl exists - - def data_train(self): - if self.pickle_exist("data_train.pkl"): - return - # nothing will be done if data_train.pkl exists - - data_train = [] - sentence = [] - for w in self.data: - w = w.strip() - if len(w) <= 1: - wid = [] - lid = [] - for i in range(len(sentence)): - # if sentence[i][0]=="": - # print("") - wid.append(self.word_dict[sentence[i][0]]) - lid.append(self.label_dict[sentence[i][1]]) - data_train.append((wid, lid)) - sentence = [] - continue - sentence.append(w.split('\t')) - - file_name = os.path.join(self.pickle_path, "data_train.pkl") - with open(file_name, "wb") as f: - _pickle.dump(data_train, f) - - def data_dev(self): - pass - - def data_test(self): - pass + def to_index(self, data): + """ + Convert word strings and label strings into indices. + :param data: three-level list + [ + [ [word_11, word_12, ...], [label_1, label_1, ...] ], + [ [word_21, word_22, ...], [label_2, label_1, ...] 
], + ... + ] + :return data_index: the shape of data, but each string is replaced by its corresponding index + """ + data_index = [] + for example in data: + word_list = [] + label_list = [] + for word, label in zip(example[0], example[1]): + word_list.append(self.word2index[word]) + label_list.append(self.label2index[label]) + data_index.append([word_list, label_list]) + return data_index + + @property + def vocab_size(self): + return len(self.word2index) + + @property + def num_classes(self): + return len(self.label2index) class ClassPreprocess(BasePreprocess): diff --git a/fastNLP/modules/aggregation/kmax_pool.py b/fastNLP/modules/aggregation/kmax_pool.py index 17fa9248..4d71130e 100644 --- a/fastNLP/modules/aggregation/kmax_pool.py +++ b/fastNLP/modules/aggregation/kmax_pool.py @@ -9,7 +9,7 @@ import torch.nn as nn class KMaxPool(nn.Module): """K max-pooling module.""" - def __init__(self, k): + def __init__(self, k=1): super(KMaxPool, self).__init__() self.k = k diff --git a/fastNLP/modules/aggregation/linear_attention.py b/fastNLP/modules/aggregation/linear_attention.py deleted file mode 100644 index 8f761c7a..00000000 --- a/fastNLP/modules/aggregation/linear_attention.py +++ /dev/null @@ -1,9 +0,0 @@ -from fastNLP.modules.aggregation.attention import Attention - - -class LinearAttention(Attention): - def __init__(self, normalize=False): - super(LinearAttention, self).__init__(normalize) - - def _atten_forward(self, query, memory): - raise NotImplementedError diff --git a/fastNLP/modules/aggregation/self_attention.py b/fastNLP/modules/aggregation/self_attention.py index f3581b44..aeaef4db 100644 --- a/fastNLP/modules/aggregation/self_attention.py +++ b/fastNLP/modules/aggregation/self_attention.py @@ -8,14 +8,15 @@ class SelfAttention(nn.Module): Self Attention Module. Args: - input_size : the size for the input vector - d_a : the width of weight matrix - r : the number of encoded vectors + input_size: int, the size for the input vector + dim: int, the width of weight matrix. + num_vec: int, the number of encoded vectors """ - def __init__(self, input_size, d_a, r): + + def __init__(self, input_size, dim=10, num_vec=10): super(SelfAttention, self).__init__() - self.W_s1 = nn.Parameter(torch.randn(d_a, input_size), requires_grad=True) - self.W_s2 = nn.Parameter(torch.randn(r, d_a), requires_grad=True) + self.W_s1 = nn.Parameter(torch.randn(dim, input_size), requires_grad=True) + self.W_s2 = nn.Parameter(torch.randn(num_vec, dim), requires_grad=True) self.softmax = nn.Softmax(dim=2) self.tanh = nn.Tanh() diff --git a/fastNLP/modules/encoder/char_embedding.py b/fastNLP/modules/encoder/char_embedding.py index ba70445b..72680e5b 100644 --- a/fastNLP/modules/encoder/char_embedding.py +++ b/fastNLP/modules/encoder/char_embedding.py @@ -5,13 +5,15 @@ from torch import nn class ConvCharEmbedding(nn.Module): - def __init__(self, char_emb_size, feature_maps=(40, 30, 30), kernels=(3, 4, 5)): + def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5)): """ Character Level Word Embedding - :param char_emb_size: the size of character level embedding, + :param char_emb_size: the size of character level embedding. Default: 50 say 26 characters, each embedded to 50 dim vector, then the input_size is 50. - :param feature_maps: table of feature maps (for each kernel width) - :param kernels: table of kernel widths + :param feature_maps: tuple of int. The length of the tuple is the number of convolution operations + over characters. 
The i-th integer is the number of filters (dim of out channels) for the i-th + convolution. + :param kernels: tuple of int. The width of each kernel. """ super(ConvCharEmbedding, self).__init__() self.convs = nn.ModuleList([ @@ -23,29 +25,35 @@ class ConvCharEmbedding(nn.Module): :param x: [batch_size * sent_length, word_length, char_emb_size] :return: [batch_size * sent_length, sum(feature_maps), 1] """ - x = x.contiguous().view(x.size(0), 1, x.size(1), x.size(2)) # [batch_size*sent_length, channel, width, height] - x = x.transpose(2, 3) # [batch_size*sent_length, channel, height, width] + x = x.contiguous().view(x.size(0), 1, x.size(1), x.size(2)) + # [batch_size*sent_length, channel, width, height] + x = x.transpose(2, 3) + # [batch_size*sent_length, channel, height, width] return self.convolute(x).unsqueeze(2) def convolute(self, x): feats = [] for conv in self.convs: - y = conv(x) # [batch_size*sent_length, feature_maps[i], 1, width - kernels[i] + 1] - y = torch.squeeze(y, 2) # [batch_size*sent_length, feature_maps[i], width - kernels[i] + 1] + y = conv(x) + # [batch_size*sent_length, feature_maps[i], 1, width - kernels[i] + 1] + y = torch.squeeze(y, 2) + # [batch_size*sent_length, feature_maps[i], width - kernels[i] + 1] y = F.tanh(y) - y, __ = torch.max(y, 2) # [batch_size*sent_length, feature_maps[i]] + y, __ = torch.max(y, 2) + # [batch_size*sent_length, feature_maps[i]] feats.append(y) return torch.cat(feats, 1) # [batch_size*sent_length, sum(feature_maps)] class LSTMCharEmbedding(nn.Module): """ - Character Level Word Embedding with LSTM - :param char_emb_size: the size of character level embedding, + Character Level Word Embedding with LSTM with a single layer. + :param char_emb_size: int, the size of character level embedding. Default: 50 say 26 characters, each embedded to 50 dim vector, then the input_size is 50. + :param hidden_size: int, the number of hidden units. Default: equal to char_emb_size. """ - def __init__(self, char_emb_size, hidden_size=None): + def __init__(self, char_emb_size=50, hidden_size=None): super(LSTMCharEmbedding, self).__init__() self.hidden_size = char_emb_size if hidden_size is None else hidden_size diff --git a/fastNLP/modules/encoder/conv.py b/fastNLP/modules/encoder/conv.py index a3a572d9..1aeedbd5 100644 --- a/fastNLP/modules/encoder/conv.py +++ b/fastNLP/modules/encoder/conv.py @@ -2,12 +2,14 @@ # encoding: utf-8 import torch.nn as nn +from torch.nn.init import xavier_uniform # import torch.nn.functional as F class Conv(nn.Module): """ Basic 1-d convolution module. + initialize with xavier_uniform """ def __init__(self, in_channels, out_channels, kernel_size, @@ -23,6 +25,7 @@ class Conv(nn.Module): dilation=dilation, groups=groups, bias=bias) + xavier_uniform(self.conv.weight) def forward(self, x): return self.conv(x) # [N,C,L] diff --git a/fastNLP/modules/encoder/embedding.py b/fastNLP/modules/encoder/embedding.py index 17c8f20a..1e11adfe 100644 --- a/fastNLP/modules/encoder/embedding.py +++ b/fastNLP/modules/encoder/embedding.py @@ -7,12 +7,13 @@ class Lookuptable(nn.Module): Args: nums : the size of the lookup table - dims : the size of each vector + dims : the size of each vector. Default: 50. padding_idx : pads the tensor with zeros whenever it encounters this index sparse : If True, gradient matrix will be a sparse tensor. 
In this case, only optim.SGD(cuda and cpu) and optim.Adagrad(cpu) can be used
     """
-    def __init__(self, nums, dims, padding_idx=0, sparse=False):
+
+    def __init__(self, nums, dims=50, padding_idx=0, sparse=False):
         super(Lookuptable, self).__init__()
         self.embed = nn.Embedding(nums, dims, padding_idx, sparse=sparse)
diff --git a/fastNLP/modules/encoder/lstm.py b/fastNLP/modules/encoder/lstm.py
index 6d110fca..2d8c14f4 100644
--- a/fastNLP/modules/encoder/lstm.py
+++ b/fastNLP/modules/encoder/lstm.py
@@ -8,11 +8,12 @@ class Lstm(nn.Module):
     Args:
         input_size : input size
         hidden_size : hidden size
-        num_layers : number of hidden layers
-        dropout : dropout rate
-        bidirectional : If True, becomes a bidirectional RNN
+        num_layers : number of hidden layers. Default: 1
+        dropout : dropout rate. Default: 0.5
+        bidirectional : If True, becomes a bidirectional RNN. Default: False.
     """
-    def __init__(self, input_size, hidden_size, num_layers, dropout, bidirectional):
+
+    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.5, bidirectional=False):
         super(Lstm, self).__init__()
         self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True,
                             dropout=dropout, bidirectional=bidirectional)
diff --git a/test/test_loader.py b/test/test_loader.py
index 58e5dfe5..b18a2fcf 100644
--- a/test/test_loader.py
+++ b/test/test_loader.py
@@ -1,9 +1,23 @@
 import unittest
 
+from fastNLP.loader.dataset_loader import POSDatasetLoader
 
-class MyTestCase(unittest.TestCase):
-    def test_something(self):
-        self.assertEqual(True, False)
+
+class TestPreprocess(unittest.TestCase):
+    def test_case_1(self):
+        data = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]],
+                [["Hello", "world", "!"], ["T", "F", "F"]]]
+        pickle_path = "./data_for_tests/"
+        # POSPreprocess(data, pickle_path)
+
+
+class TestDatasetLoader(unittest.TestCase):
+    def test_case_1(self):
+        data = """Tom\tT\nand\tF\nJerry\tT\n.\tF\n\nHello\tT\nworld\tF\n!\tF"""
+        lines = data.split("\n")
+        answer = POSDatasetLoader.parse(lines)
+        truth = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]], [["Hello", "world", "!"], ["T", "F", "F"]]]
+        self.assertListEqual(answer, truth, "POS Dataset Loader")
 
 
 if __name__ == '__main__':
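# --- Usage sketch (illustrative, not part of the patch) ----------------------
# A minimal sketch of how the refactored pieces are intended to compose:
# POSDatasetLoader.load_lines() now returns the three-level list that
# POSPreprocess consumes, and get_torch_optimizer builds an optimizer by name.
# The data path and pickle directory reuse the ones from the patch's own
# __main__ block and tests; the nn.Linear model is only a stand-in assumption.
from torch import nn

from fastNLP.action.optimizor import get_torch_optimizer
from fastNLP.loader.dataset_loader import POSDatasetLoader
from fastNLP.loader.preprocess import POSPreprocess

# three-level list: [ [ [word, ...], [label, ...] ], ... ]
data = POSDatasetLoader("people", "../../test/data_for_tests/people.txt").load_lines()

# builds word2id.pkl / class2id.pkl / id2word.pkl / id2class.pkl / data_train.pkl,
# and carves 10% of the training data off into data_dev.pkl
preprocess = POSPreprocess(data, pickle_path="./data_for_tests/", train_dev_split=0.1)
print(preprocess.vocab_size, preprocess.num_classes)

# construct an optimizer by algorithm name for any model's parameters
model = nn.Linear(preprocess.vocab_size, preprocess.num_classes)
optimizer = get_torch_optimizer(model.parameters(), 'adam', lr=1e-3)
print(optimizer)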