
clean up some code and fix some bugs

tags/v0.2.0
xuyige 5 years ago
commit b43d333738
10 changed files with 182 additions and 29 deletions
  1. fastNLP/core/dataset.py (+12 -2)
  2. fastNLP/core/metrics.py (+14 -0)
  3. fastNLP/core/tester.py (+8 -0)
  4. fastNLP/core/vocabulary.py (+10 -2)
  5. fastNLP/loader/dataset_loader.py (+76 -0)
  6. fastNLP/loader/embed_loader.py (+10 -9)
  7. fastNLP/modules/decoder/MLP.py (+10 -9)
  8. fastNLP/modules/encoder/linear.py (+4 -1)
  9. fastNLP/modules/encoder/lstm.py (+13 -6)
  10. test/data_for_tests/config (+25 -0)

fastNLP/core/dataset.py (+12 -2)

@@ -30,8 +30,18 @@ class DataSet(list):
         return self

     def index_field(self, field_name, vocab):
-        for ins in self:
-            ins.index_field(field_name, vocab)
+        if isinstance(field_name, str) and isinstance(vocab, Vocabulary):
+            field_list = [field_name]
+            vocab_list = [vocab]
+        else:
+            classes = (list, tuple)
+            assert isinstance(field_name, classes) and isinstance(vocab, classes) and len(field_name) == len(vocab)
+            field_list = field_name
+            vocab_list = vocab
+
+        for name, vocabs in zip(field_list, vocab_list):
+            for ins in self:
+                ins.index_field(name, vocabs)
         return self

     def to_tensor(self, idx: int, padding_length: dict):
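With this change, index_field accepts either a single field/vocabulary pair or two parallel lists. A minimal usage sketch (the field and vocabulary names are hypothetical, not taken from the diff):

    ds = DataSet()   # assume instances with "word_seq" and "label_seq" fields
    ds.index_field("word_seq", word_vocab)                                  # old single-pair form
    ds.index_field(["word_seq", "label_seq"], [word_vocab, label_vocab])    # new list form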


fastNLP/core/metrics.py (+14 -0)

@@ -57,6 +57,20 @@ class SeqLabelEvaluator(Evaluator):
         return {"accuracy": float(accuracy)}


+class SNLIEvaluator(Evaluator):
+    def __init__(self):
+        super(SNLIEvaluator, self).__init__()
+
+    def __call__(self, predict, truth):
+        y_prob = [torch.nn.functional.softmax(y_logit, dim=-1) for y_logit in predict]
+        y_prob = torch.cat(y_prob, dim=0)
+        y_pred = torch.argmax(y_prob, dim=-1)
+        truth = [t['truth'] for t in truth]
+        y_true = torch.cat(truth, dim=0).view(-1)
+        acc = float(torch.sum(y_pred == y_true)) / y_true.size(0)
+        return {"accuracy": acc}
+
+
 def _conver_numpy(x):
     """convert input data to numpy array




fastNLP/core/tester.py (+8 -0)

@@ -83,6 +83,7 @@ class Tester(object):
             truth_list.append(batch_y)
         eval_results = self.evaluate(output_list, truth_list)
         print("[tester] {}".format(self.print_eval_results(eval_results)))
+        logger.info("[tester] {}".format(self.print_eval_results(eval_results)))

     def mode(self, model, is_test=False):
         """Train mode or Test mode. This is for PyTorch currently.
@@ -131,3 +132,10 @@ class ClassificationTester(Tester):
         print(
             "[FastNLP Warning] ClassificationTester will be deprecated. Please use Tester directly.")
         super(ClassificationTester, self).__init__(**test_args)
+
+
+class SNLITester(Tester):
+    def __init__(self, **test_args):
+        print(
+            "[FastNLP Warning] SNLITester will be deprecated. Please use Tester directly.")
+        super(SNLITester, self).__init__(**test_args)
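SNLITester is a thin alias for Tester, mirroring ClassificationTester. A hypothetical construction (the keyword names mirror the [snli_tester] config section added below; the exact set of keyword arguments Tester accepts is an assumption):

    tester = SNLITester(batch_size=512, use_cuda=True)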

fastNLP/core/vocabulary.py (+10 -2)

@@ -18,6 +18,7 @@ def isiterable(p_object):
         return False
     return True

+
 def check_build_vocab(func):
     def _wrapper(self, *args, **kwargs):
         if self.word2idx is None:
@@ -28,6 +29,7 @@ def check_build_vocab(func):
         return func(self, *args, **kwargs)
     return _wrapper

+
 class Vocabulary(object):
     """Use for word and index one to one mapping
@@ -52,7 +54,6 @@ class Vocabulary(object):
         self.word2idx = None
         self.idx2word = None

-
     def update(self, word):
         """add word or list of words into Vocabulary
@@ -70,7 +71,6 @@ class Vocabulary(object):
         self.word_count[word] += 1
         self.word2idx = None

-
     def build_vocab(self):
         """build 'word to index' dict, and filter the word using `max_size` and `min_freq`
         """
@@ -163,3 +163,11 @@ class Vocabulary(object):
         """
         self.__dict__.update(state)
         self.idx2word = None
+
+    def __contains__(self, item):
+        """Check if a word in vocabulary.
+
+        :param item: the word
+        :return: True or False
+        """
+        return self.has_word(item)
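With __contains__ delegating to has_word, membership tests read naturally. A quick sketch (the words are hypothetical; update accepts a word or a list of words per its docstring):

    vocab = Vocabulary()
    vocab.update(["premise", "hypothesis"])
    print("premise" in vocab)      # True
    print("entailment" in vocab)   # False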

fastNLP/loader/dataset_loader.py (+76 -0)

@@ -5,6 +5,7 @@ from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
 from fastNLP.core.field import *

+
 def convert_seq_dataset(data):
     """Create an DataSet instance that contains no labels.
@@ -23,6 +24,7 @@ def convert_seq_dataset(data):
         dataset.append(Instance(word_seq=x))
     return dataset

+
 def convert_seq2tag_dataset(data):
     """Convert list of data into DataSet
@@ -45,6 +47,7 @@ def convert_seq2tag_dataset(data):
         dataset.append(ins)
     return dataset

+
 def convert_seq2seq_dataset(data):
     """Convert list of data into DataSet
@@ -84,6 +87,7 @@ class DataSetLoader(BaseLoader):
     """
     raise NotImplementedError

+
 class RawDataSetLoader(DataSetLoader):
     def __init__(self):
         super(RawDataSetLoader, self).__init__()
@@ -98,6 +102,7 @@ class RawDataSetLoader(DataSetLoader):
     def convert(self, data):
         return convert_seq_dataset(data)

+
 class POSDataSetLoader(DataSetLoader):
     """Dataset Loader for POS Tag datasets.
@@ -166,6 +171,7 @@ class POSDataSetLoader(DataSetLoader):
     """
     return convert_seq2seq_dataset(data)

+
 class TokenizeDataSetLoader(DataSetLoader):
     """
     Data set loader for tokenization data sets
@@ -339,6 +345,7 @@ class LMDataSetLoader(DataSetLoader):
     def convert(self, data):
         pass

+
 class PeopleDailyCorpusLoader(DataSetLoader):
     """
     People Daily Corpus: Chinese word segmentation, POS tag, NER
@@ -390,3 +397,72 @@ class PeopleDailyCorpusLoader(DataSetLoader):

     def convert(self, data):
         pass
+
+
+class SNLIDataSetLoader(DataSetLoader):
+    """A data set loader for SNLI data set.
+
+    """
+
+    def __init__(self):
+        super(SNLIDataSetLoader, self).__init__()
+
+    def load(self, path_list):
+        """
+
+        :param path_list: A list of file name, in the order of premise file, hypothesis file, and label file.
+        :return: data_set: A DataSet object.
+        """
+        assert len(path_list) == 3
+        line_set = []
+        for file in path_list:
+            if not os.path.exists(file):
+                raise FileNotFoundError("file {} NOT found".format(file))
+
+            with open(file, 'r', encoding='utf-8') as f:
+                lines = f.readlines()
+                line_set.append(lines)
+
+        premise_lines, hypothesis_lines, label_lines = line_set
+        assert len(premise_lines) == len(hypothesis_lines) and len(premise_lines) == len(label_lines)
+
+        data_set = []
+        for premise, hypothesis, label in zip(premise_lines, hypothesis_lines, label_lines):
+            p = premise.strip().split()
+            h = hypothesis.strip().split()
+            l = label.strip()
+            data_set.append([p, h, l])
+
+        return self.convert(data_set)
+
+    def convert(self, data):
+        """Convert a 3D list to a DataSet object.
+
+        :param data: A 3D tensor.
+            [
+                [ [premise_word_11, premise_word_12, ...], [hypothesis_word_11, hypothesis_word_12, ...], [label_1] ],
+                [ [premise_word_21, premise_word_22, ...], [hypothesis_word_21, hypothesis_word_22, ...], [label_2] ],
+                ...
+            ]
+        :return: data_set: A DataSet object.
+        """
+        data_set = DataSet()
+
+        for example in data:
+            p, h, l = example
+            # list, list, str
+            x1 = TextField(p, is_target=False)
+            x2 = TextField(h, is_target=False)
+            x1_len = TextField([1] * len(p), is_target=False)
+            x2_len = TextField([1] * len(h), is_target=False)
+            y = LabelField(l, is_target=True)
+            instance = Instance()
+            instance.add_field("premise", x1)
+            instance.add_field("hypothesis", x2)
+            instance.add_field("premise_len", x1_len)
+            instance.add_field("hypothesis_len", x2_len)
+            instance.add_field("truth", y)
+            data_set.append(instance)
+
+        return data_set
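A usage sketch of the new loader (the file names are hypothetical; each file is assumed to hold one whitespace-tokenized sentence or one label per line, with the three files aligned line by line):

    loader = SNLIDataSetLoader()
    data_set = loader.load(["premise.txt", "hypothesis.txt", "labels.txt"])
    print(len(data_set))   # one Instance per premise/hypothesis/label triple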

fastNLP/loader/embed_loader.py (+10 -9)

@@ -6,11 +6,12 @@ import torch
 from fastNLP.loader.base_loader import BaseLoader
 from fastNLP.core.vocabulary import Vocabulary

+
 class EmbedLoader(BaseLoader):
     """docstring for EmbedLoader"""

-    def __init__(self, data_path):
-        super(EmbedLoader, self).__init__(data_path)
+    def __init__(self):
+        super(EmbedLoader, self).__init__()

     @staticmethod
     def _load_glove(emb_file):
@@ -55,15 +56,15 @@ class EmbedLoader(BaseLoader):
         :param emb_type: str, the pre-trained embedding format, support glove now
         :param vocab: Vocabulary, a mapping from word to index, can be provided by user or built from pre-trained embedding
         :param emb_pkl: str, the embedding pickle file.
-        :return embedding_np: numpy array of shape (len(word_dict), emb_dim)
+        :return embedding_tensor: Tensor of shape (len(word_dict), emb_dim)
             vocab: input vocab or vocab built by pre-train
         TODO: fragile code
         """
         # If the embedding pickle exists, load it and return.
         if os.path.exists(emb_pkl):
             with open(emb_pkl, "rb") as f:
-                embedding_np, vocab = _pickle.load(f)
-            return embedding_np, vocab
+                embedding_tensor, vocab = _pickle.load(f)
+            return embedding_tensor, vocab
         # Otherwise, load the pre-trained embedding.
         pretrain = EmbedLoader._load_pretrain(emb_file, emb_type)
         if vocab is None:
@@ -71,14 +72,14 @@ class EmbedLoader(BaseLoader):
             vocab = Vocabulary()
             for w in pretrain.keys():
                 vocab.update(w)
-        embedding_np = torch.randn(len(vocab), emb_dim)
+        embedding_tensor = torch.randn(len(vocab), emb_dim)
         for w, v in pretrain.items():
             if len(v.shape) > 1 or emb_dim != v.shape[0]:
                 raise ValueError('pretrian embedding dim is {}, dismatching required {}'.format(v.shape, (emb_dim,)))
             if vocab.has_word(w):
-                embedding_np[vocab[w]] = v
+                embedding_tensor[vocab[w]] = v

         # save and return the result
         with open(emb_pkl, "wb") as f:
-            _pickle.dump((embedding_np, vocab), f)
-        return embedding_np, vocab
+            _pickle.dump((embedding_tensor, vocab), f)
+        return embedding_tensor, vocab
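A hedged sketch of calling the loader. The full signature of the method is not shown in this diff, so the argument list below is an assumption pieced together from the docstring; the paths are placeholders:

    embedding_tensor, vocab = EmbedLoader.load_embedding(
        emb_dim=300,
        emb_file="glove.840B.300d.txt",   # pre-trained GloVe vectors
        emb_type="glove",                 # only glove is supported per the docstring
        vocab=None,                       # None: build the vocab from the embedding file
        emb_pkl="embed.pkl",              # pickle cache; reused on later calls
    )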

fastNLP/modules/decoder/MLP.py (+10 -9)

@@ -1,12 +1,15 @@
 import torch
 import torch.nn as nn
 from fastNLP.modules.utils import initial_parameter

+
 class MLP(nn.Module):
-    def __init__(self, size_layer, activation='relu' , initial_method = None):
+    def __init__(self, size_layer, activation='relu', initial_method=None):
         """Multilayer Perceptrons as a decoder

-        :param size_layer: list of int, define the size of MLP layers
-        :param activation: str or function, the activation function for hidden layers
+        :param size_layer: list of int, define the size of MLP layers.
+        :param activation: str or function, the activation function for hidden layers.
+        :param initial_method: str, the name of init method.

         .. note::
             There is no activation function applying on output layer.
@@ -23,7 +26,7 @@ class MLP(nn.Module):

         actives = {
             'relu': nn.ReLU(),
-            'tanh': nn.Tanh()
+            'tanh': nn.Tanh(),
         }
         if activation in actives:
             self.hidden_active = actives[activation]
@@ -31,7 +34,7 @@ class MLP(nn.Module):
             self.hidden_active = activation
         else:
             raise ValueError("should set activation correctly: {}".format(activation))
-        initial_parameter(self, initial_method )
+        initial_parameter(self, initial_method)

     def forward(self, x):
         for layer in self.hiddens:
@@ -40,13 +43,11 @@ class MLP(nn.Module):
         return x


-
-
 if __name__ == '__main__':
-    net1 = MLP([5,10,5])
-    net2 = MLP([5,10,5], 'tanh')
+    net1 = MLP([5, 10, 5])
+    net2 = MLP([5, 10, 5], 'tanh')
     for net in [net1, net2]:
         x = torch.randn(5, 5)
         y = net(x)
         print(x)
         print(y)

fastNLP/modules/encoder/linear.py (+4 -1)

@@ -1,6 +1,8 @@
 import torch.nn as nn
+
 from fastNLP.modules.utils import initial_parameter

+
 class Linear(nn.Module):
     """
     Linear module
@@ -12,10 +14,11 @@ class Linear(nn.Module):
         bidirectional : If True, becomes a bidirectional RNN
     """

-    def __init__(self, input_size, output_size, bias=True,initial_method = None ):
+
+    def __init__(self, input_size, output_size, bias=True, initial_method=None):
         super(Linear, self).__init__()
         self.linear = nn.Linear(input_size, output_size, bias)
         initial_parameter(self, initial_method)

     def forward(self, x):
         x = self.linear(x)
         return x
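Linear is a thin wrapper over nn.Linear that routes through fastNLP's initial_parameter. A minimal sketch with arbitrary sizes:

    import torch
    layer = Linear(input_size=10, output_size=3)
    y = layer(torch.randn(4, 10))   # shape: (4, 3)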

fastNLP/modules/encoder/lstm.py (+13 -6)

@@ -14,16 +14,23 @@ class LSTM(nn.Module):
         bidirectional : If True, becomes a bidirectional RNN. Default: False.
     """

-    def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, bidirectional=False,
-                 initial_method=None):
+    def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True,
+                 bidirectional=False, bias=True, initial_method=None, get_hidden=False):
         super(LSTM, self).__init__()
-        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True,
+        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first,
                             dropout=dropout, bidirectional=bidirectional)
+        self.get_hidden = get_hidden
         initial_parameter(self, initial_method)

-    def forward(self, x):
-        x, _ = self.lstm(x)
-        return x
+    def forward(self, x, h0=None, c0=None):
+        if h0 is not None and c0 is not None:
+            x, (ht, ct) = self.lstm(x, (h0, c0))
+        else:
+            x, (ht, ct) = self.lstm(x)
+        if self.get_hidden:
+            return x, (ht, ct)
+        else:
+            return x


 if __name__ == "__main__":
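A sketch of the new forward options (tensor shapes follow nn.LSTM with batch_first=True; the sizes are arbitrary):

    import torch
    lstm = LSTM(input_size=50, hidden_size=100, get_hidden=True)
    x = torch.randn(8, 20, 50)      # (batch, seq_len, input_size)
    h0 = torch.zeros(1, 8, 100)     # (num_layers, batch, hidden_size)
    c0 = torch.zeros(1, 8, 100)
    output, (ht, ct) = lstm(x, h0, c0)   # with get_hidden=False, only output is returned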


test/data_for_tests/config (+25 -0)

@@ -45,3 +45,28 @@ use_cuda = true
 learn_rate = 1e-3
 momentum = 0.9
 model_name = "class_model.pkl"
+
+[snli_trainer]
+epochs = 5
+batch_size = 32
+validate = true
+save_best_dev = true
+use_cuda = true
+learn_rate = 1e-4
+loss = "cross_entropy"
+print_every_step = 1000
+
+[snli_tester]
+batch_size = 512
+use_cuda = true
+
+[snli_model]
+model_name = "snli_model.pkl"
+embed_dim = 300
+hidden_size = 300
+batch_first = true
+dropout = 0.5
+gpu = true
+embed_file = "./../data_for_tests/glove.840B.300d.txt"
+embed_pkl = "./snli/embed.pkl"
+examples = 0
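The three new sections can be read with any INI-style parser. A quick sketch using the standard library (fastNLP ships its own config loader; configparser is used here purely for illustration, and note it keeps the surrounding quotes on string values):

    import configparser

    cfg = configparser.ConfigParser()
    cfg.read("test/data_for_tests/config")
    print(cfg.getint("snli_trainer", "epochs"))        # 5
    print(cfg.getfloat("snli_trainer", "learn_rate"))  # 0.0001
    print(cfg.get("snli_model", "embed_file"))         # "./../data_for_tests/glove.840B.300d.txt"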
