diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index f1f3f2a8..ad338649 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -30,8 +30,18 @@ class DataSet(list): return self def index_field(self, field_name, vocab): - for ins in self: - ins.index_field(field_name, vocab) + if isinstance(field_name, str) and isinstance(vocab, Vocabulary): + field_list = [field_name] + vocab_list = [vocab] + else: + classes = (list, tuple) + assert isinstance(field_name, classes) and isinstance(vocab, classes) and len(field_name) == len(vocab) + field_list = field_name + vocab_list = vocab + + for name, vocabs in zip(field_list, vocab_list): + for ins in self: + ins.index_field(name, vocabs) return self def to_tensor(self, idx: int, padding_length: dict): diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 6eedd214..d4bf475a 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -57,6 +57,20 @@ class SeqLabelEvaluator(Evaluator): return {"accuracy": float(accuracy)} +class SNLIEvaluator(Evaluator): + def __init__(self): + super(SNLIEvaluator, self).__init__() + + def __call__(self, predict, truth): + y_prob = [torch.nn.functional.softmax(y_logit, dim=-1) for y_logit in predict] + y_prob = torch.cat(y_prob, dim=0) + y_pred = torch.argmax(y_prob, dim=-1) + truth = [t['truth'] for t in truth] + y_true = torch.cat(truth, dim=0).view(-1) + acc = float(torch.sum(y_pred == y_true)) / y_true.size(0) + return {"accuracy": acc} + + def _conver_numpy(x): """convert input data to numpy array diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index 0e74145b..24aac951 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -83,6 +83,7 @@ class Tester(object): truth_list.append(batch_y) eval_results = self.evaluate(output_list, truth_list) print("[tester] {}".format(self.print_eval_results(eval_results))) + logger.info("[tester] {}".format(self.print_eval_results(eval_results))) def mode(self, model, is_test=False): """Train mode or Test mode. This is for PyTorch currently. @@ -131,3 +132,10 @@ class ClassificationTester(Tester): print( "[FastNLP Warning] ClassificationTester will be deprecated. Please use Tester directly.") super(ClassificationTester, self).__init__(**test_args) + + +class SNLITester(Tester): + def __init__(self, **test_args): + print( + "[FastNLP Warning] SNLITester will be deprecated. Please use Tester directly.") + super(SNLITester, self).__init__(**test_args) diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index 77b27b92..d884d4c7 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -18,6 +18,7 @@ def isiterable(p_object): return False return True + def check_build_vocab(func): def _wrapper(self, *args, **kwargs): if self.word2idx is None: @@ -28,6 +29,7 @@ def check_build_vocab(func): return func(self, *args, **kwargs) return _wrapper + class Vocabulary(object): """Use for word and index one to one mapping @@ -52,7 +54,6 @@ class Vocabulary(object): self.word2idx = None self.idx2word = None - def update(self, word): """add word or list of words into Vocabulary @@ -70,7 +71,6 @@ class Vocabulary(object): self.word_count[word] += 1 self.word2idx = None - def build_vocab(self): """build 'word to index' dict, and filter the word using `max_size` and `min_freq` """ @@ -163,3 +163,11 @@ class Vocabulary(object): """ self.__dict__.update(state) self.idx2word = None + + def __contains__(self, item): + """Check if a word in vocabulary. + + :param item: the word + :return: True or False + """ + return self.has_word(item) diff --git a/fastNLP/loader/dataset_loader.py b/fastNLP/loader/dataset_loader.py index 5feb62a6..b58448fc 100644 --- a/fastNLP/loader/dataset_loader.py +++ b/fastNLP/loader/dataset_loader.py @@ -5,6 +5,7 @@ from fastNLP.core.dataset import DataSet from fastNLP.core.instance import Instance from fastNLP.core.field import * + def convert_seq_dataset(data): """Create an DataSet instance that contains no labels. @@ -23,6 +24,7 @@ def convert_seq_dataset(data): dataset.append(Instance(word_seq=x)) return dataset + def convert_seq2tag_dataset(data): """Convert list of data into DataSet @@ -45,6 +47,7 @@ def convert_seq2tag_dataset(data): dataset.append(ins) return dataset + def convert_seq2seq_dataset(data): """Convert list of data into DataSet @@ -84,6 +87,7 @@ class DataSetLoader(BaseLoader): """ raise NotImplementedError + class RawDataSetLoader(DataSetLoader): def __init__(self): super(RawDataSetLoader, self).__init__() @@ -98,6 +102,7 @@ class RawDataSetLoader(DataSetLoader): def convert(self, data): return convert_seq_dataset(data) + class POSDataSetLoader(DataSetLoader): """Dataset Loader for POS Tag datasets. @@ -166,6 +171,7 @@ class POSDataSetLoader(DataSetLoader): """ return convert_seq2seq_dataset(data) + class TokenizeDataSetLoader(DataSetLoader): """ Data set loader for tokenization data sets @@ -339,6 +345,7 @@ class LMDataSetLoader(DataSetLoader): def convert(self, data): pass + class PeopleDailyCorpusLoader(DataSetLoader): """ People Daily Corpus: Chinese word segmentation, POS tag, NER @@ -390,3 +397,72 @@ class PeopleDailyCorpusLoader(DataSetLoader): def convert(self, data): pass + + +class SNLIDataSetLoader(DataSetLoader): + """A data set loader for SNLI data set. + + """ + + def __init__(self): + super(SNLIDataSetLoader, self).__init__() + + def load(self, path_list): + """ + + :param path_list: A list of file name, in the order of premise file, hypothesis file, and label file. + :return: data_set: A DataSet object. + """ + assert len(path_list) == 3 + line_set = [] + for file in path_list: + if not os.path.exists(file): + raise FileNotFoundError("file {} NOT found".format(file)) + + with open(file, 'r', encoding='utf-8') as f: + lines = f.readlines() + line_set.append(lines) + + premise_lines, hypothesis_lines, label_lines = line_set + assert len(premise_lines) == len(hypothesis_lines) and len(premise_lines) == len(label_lines) + + data_set = [] + for premise, hypothesis, label in zip(premise_lines, hypothesis_lines, label_lines): + p = premise.strip().split() + h = hypothesis.strip().split() + l = label.strip() + data_set.append([p, h, l]) + + return self.convert(data_set) + + def convert(self, data): + """Convert a 3D list to a DataSet object. + + :param data: A 3D tensor. + [ + [ [premise_word_11, premise_word_12, ...], [hypothesis_word_11, hypothesis_word_12, ...], [label_1] ], + [ [premise_word_21, premise_word_22, ...], [hypothesis_word_21, hypothesis_word_22, ...], [label_2] ], + ... + ] + :return: data_set: A DataSet object. + """ + + data_set = DataSet() + + for example in data: + p, h, l = example + # list, list, str + x1 = TextField(p, is_target=False) + x2 = TextField(h, is_target=False) + x1_len = TextField([1] * len(p), is_target=False) + x2_len = TextField([1] * len(h), is_target=False) + y = LabelField(l, is_target=True) + instance = Instance() + instance.add_field("premise", x1) + instance.add_field("hypothesis", x2) + instance.add_field("premise_len", x1_len) + instance.add_field("hypothesis_len", x2_len) + instance.add_field("truth", y) + data_set.append(instance) + + return data_set diff --git a/fastNLP/loader/embed_loader.py b/fastNLP/loader/embed_loader.py index b44c9851..2f61830f 100644 --- a/fastNLP/loader/embed_loader.py +++ b/fastNLP/loader/embed_loader.py @@ -6,11 +6,12 @@ import torch from fastNLP.loader.base_loader import BaseLoader from fastNLP.core.vocabulary import Vocabulary + class EmbedLoader(BaseLoader): """docstring for EmbedLoader""" - def __init__(self, data_path): - super(EmbedLoader, self).__init__(data_path) + def __init__(self): + super(EmbedLoader, self).__init__() @staticmethod def _load_glove(emb_file): @@ -55,15 +56,15 @@ class EmbedLoader(BaseLoader): :param emb_type: str, the pre-trained embedding format, support glove now :param vocab: Vocabulary, a mapping from word to index, can be provided by user or built from pre-trained embedding :param emb_pkl: str, the embedding pickle file. - :return embedding_np: numpy array of shape (len(word_dict), emb_dim) + :return embedding_tensor: Tensor of shape (len(word_dict), emb_dim) vocab: input vocab or vocab built by pre-train TODO: fragile code """ # If the embedding pickle exists, load it and return. if os.path.exists(emb_pkl): with open(emb_pkl, "rb") as f: - embedding_np, vocab = _pickle.load(f) - return embedding_np, vocab + embedding_tensor, vocab = _pickle.load(f) + return embedding_tensor, vocab # Otherwise, load the pre-trained embedding. pretrain = EmbedLoader._load_pretrain(emb_file, emb_type) if vocab is None: @@ -71,14 +72,14 @@ class EmbedLoader(BaseLoader): vocab = Vocabulary() for w in pretrain.keys(): vocab.update(w) - embedding_np = torch.randn(len(vocab), emb_dim) + embedding_tensor = torch.randn(len(vocab), emb_dim) for w, v in pretrain.items(): if len(v.shape) > 1 or emb_dim != v.shape[0]: raise ValueError('pretrian embedding dim is {}, dismatching required {}'.format(v.shape, (emb_dim,))) if vocab.has_word(w): - embedding_np[vocab[w]] = v + embedding_tensor[vocab[w]] = v # save and return the result with open(emb_pkl, "wb") as f: - _pickle.dump((embedding_np, vocab), f) - return embedding_np, vocab + _pickle.dump((embedding_tensor, vocab), f) + return embedding_tensor, vocab diff --git a/fastNLP/modules/decoder/MLP.py b/fastNLP/modules/decoder/MLP.py index 2a4193b1..766dc225 100644 --- a/fastNLP/modules/decoder/MLP.py +++ b/fastNLP/modules/decoder/MLP.py @@ -1,12 +1,15 @@ import torch import torch.nn as nn from fastNLP.modules.utils import initial_parameter + + class MLP(nn.Module): - def __init__(self, size_layer, activation='relu' , initial_method = None): + def __init__(self, size_layer, activation='relu', initial_method=None): """Multilayer Perceptrons as a decoder - :param size_layer: list of int, define the size of MLP layers - :param activation: str or function, the activation function for hidden layers + :param size_layer: list of int, define the size of MLP layers. + :param activation: str or function, the activation function for hidden layers. + :param initial_method: str, the name of init method. .. note:: There is no activation function applying on output layer. @@ -23,7 +26,7 @@ class MLP(nn.Module): actives = { 'relu': nn.ReLU(), - 'tanh': nn.Tanh() + 'tanh': nn.Tanh(), } if activation in actives: self.hidden_active = actives[activation] @@ -31,7 +34,7 @@ class MLP(nn.Module): self.hidden_active = activation else: raise ValueError("should set activation correctly: {}".format(activation)) - initial_parameter(self, initial_method ) + initial_parameter(self, initial_method) def forward(self, x): for layer in self.hiddens: @@ -40,13 +43,11 @@ class MLP(nn.Module): return x - if __name__ == '__main__': - net1 = MLP([5,10,5]) - net2 = MLP([5,10,5], 'tanh') + net1 = MLP([5, 10, 5]) + net2 = MLP([5, 10, 5], 'tanh') for net in [net1, net2]: x = torch.randn(5, 5) y = net(x) print(x) print(y) - \ No newline at end of file diff --git a/fastNLP/modules/encoder/linear.py b/fastNLP/modules/encoder/linear.py index a7c5f6c3..399e15d3 100644 --- a/fastNLP/modules/encoder/linear.py +++ b/fastNLP/modules/encoder/linear.py @@ -1,6 +1,8 @@ import torch.nn as nn from fastNLP.modules.utils import initial_parameter + + class Linear(nn.Module): """ Linear module @@ -12,10 +14,11 @@ class Linear(nn.Module): bidirectional : If True, becomes a bidirectional RNN """ - def __init__(self, input_size, output_size, bias=True,initial_method = None ): + def __init__(self, input_size, output_size, bias=True, initial_method=None): super(Linear, self).__init__() self.linear = nn.Linear(input_size, output_size, bias) initial_parameter(self, initial_method) + def forward(self, x): x = self.linear(x) return x diff --git a/fastNLP/modules/encoder/lstm.py b/fastNLP/modules/encoder/lstm.py index e48960a8..a0b42442 100644 --- a/fastNLP/modules/encoder/lstm.py +++ b/fastNLP/modules/encoder/lstm.py @@ -14,16 +14,23 @@ class LSTM(nn.Module): bidirectional : If True, becomes a bidirectional RNN. Default: False. """ - def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, bidirectional=False, - initial_method=None): + def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True, + bidirectional=False, bias=True, initial_method=None, get_hidden=False): super(LSTM, self).__init__() - self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True, + self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional) + self.get_hidden = get_hidden initial_parameter(self, initial_method) - def forward(self, x): - x, _ = self.lstm(x) - return x + def forward(self, x, h0=None, c0=None): + if h0 is not None and c0 is not None: + x, (ht, ct) = self.lstm(x, (h0, c0)) + else: + x, (ht, ct) = self.lstm(x) + if self.get_hidden: + return x, (ht, ct) + else: + return x if __name__ == "__main__": diff --git a/test/data_for_tests/config b/test/data_for_tests/config index 3f4ff7af..1180c97a 100644 --- a/test/data_for_tests/config +++ b/test/data_for_tests/config @@ -45,3 +45,28 @@ use_cuda = true learn_rate = 1e-3 momentum = 0.9 model_name = "class_model.pkl" + +[snli_trainer] +epochs = 5 +batch_size = 32 +validate = true +save_best_dev = true +use_cuda = true +learn_rate = 1e-4 +loss = "cross_entropy" +print_every_step = 1000 + +[snli_tester] +batch_size = 512 +use_cuda = true + +[snli_model] +model_name = "snli_model.pkl" +embed_dim = 300 +hidden_size = 300 +batch_first = true +dropout = 0.5 +gpu = true +embed_file = "./../data_for_tests/glove.840B.300d.txt" +embed_pkl = "./snli/embed.pkl" +examples = 0