
Merge branch 'master' into dev

tags/v0.2.0
Coet GitHub 6 years ago
parent commit b80e5e8b29
12 changed files with 360 additions and 31 deletions
  1. fastNLP/core/dataset.py  +12  -2
  2. fastNLP/core/metrics.py  +14  -0
  3. fastNLP/core/tester.py  +8  -0
  4. fastNLP/core/trainer.py  +14  -2
  5. fastNLP/core/vocabulary.py  +10  -2
  6. fastNLP/loader/dataset_loader.py  +79  -0
  7. fastNLP/loader/embed_loader.py  +10  -9
  8. fastNLP/models/snli.py  +161  -0
  9. fastNLP/modules/decoder/MLP.py  +10  -9
  10. fastNLP/modules/encoder/linear.py  +4  -1
  11. fastNLP/modules/encoder/lstm.py  +13  -6
  12. test/data_for_tests/config  +25  -0

+12  -2   fastNLP/core/dataset.py

@@ -32,8 +32,18 @@ class DataSet(list):
return self

def index_field(self, field_name, vocab):
-    for ins in self:
-        ins.index_field(field_name, vocab)
+    if isinstance(field_name, str):
+        field_list = [field_name]
+        vocab_list = [vocab]
+    else:
+        classes = (list, tuple)
+        assert isinstance(field_name, classes) and isinstance(vocab, classes) and len(field_name) == len(vocab)
+        field_list = field_name
+        vocab_list = vocab
+
+    for name, vocabs in zip(field_list, vocab_list):
+        for ins in self:
+            ins.index_field(name, vocabs)
return self

def to_tensor(self, idx: int, padding_length: dict):
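
Usage note (not part of the diff): the reworked index_field keeps the old single-field call and adds a paired-list form. A minimal sketch, assuming the vocabularies have already been filled and the dataset holds instances with "premise" and "truth" fields:

from fastNLP.core.dataset import DataSet
from fastNLP.core.vocabulary import Vocabulary

word_vocab = Vocabulary()
word_vocab.update(["a", "man", "plays", "guitar"])
label_vocab = Vocabulary()
label_vocab.update(["entailment", "neutral", "contradiction"])

data_set = DataSet()  # assume Instances with "premise" and "truth" fields were appended earlier

# old behaviour: one field, one vocabulary
data_set.index_field("premise", word_vocab)

# new behaviour: several fields at once, zipped pairwise with their vocabularies
data_set.index_field(["premise", "truth"], [word_vocab, label_vocab])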


+14  -0   fastNLP/core/metrics.py

@@ -57,6 +57,20 @@ class SeqLabelEvaluator(Evaluator):
return {"accuracy": float(accuracy)}


class SNLIEvaluator(Evaluator):
def __init__(self):
super(SNLIEvaluator, self).__init__()

def __call__(self, predict, truth):
y_prob = [torch.nn.functional.softmax(y_logit, dim=-1) for y_logit in predict]
y_prob = torch.cat(y_prob, dim=0)
y_pred = torch.argmax(y_prob, dim=-1)
truth = [t['truth'] for t in truth]
y_true = torch.cat(truth, dim=0).view(-1)
acc = float(torch.sum(y_pred == y_true)) / y_true.size(0)
return {"accuracy": acc}


def _conver_numpy(x):
"""convert input data to numpy array



+8  -0   fastNLP/core/tester.py

@@ -83,6 +83,7 @@ class Tester(object):
truth_list.append(batch_y)
eval_results = self.evaluate(output_list, truth_list)
print("[tester] {}".format(self.print_eval_results(eval_results)))
logger.info("[tester] {}".format(self.print_eval_results(eval_results)))

def mode(self, model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently.
@@ -131,3 +132,10 @@ class ClassificationTester(Tester):
print(
"[FastNLP Warning] ClassificationTester will be deprecated. Please use Tester directly.")
super(ClassificationTester, self).__init__(**test_args)


class SNLITester(Tester):
def __init__(self, **test_args):
print(
"[FastNLP Warning] SNLITester will be deprecated. Please use Tester directly.")
super(SNLITester, self).__init__(**test_args)

+14  -2   fastNLP/core/trainer.py

@@ -10,7 +10,7 @@ from fastNLP.core.loss import Loss
from fastNLP.core.metrics import Evaluator
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.sampler import RandomSampler
- from fastNLP.core.tester import SeqLabelTester, ClassificationTester
+ from fastNLP.core.tester import SeqLabelTester, ClassificationTester, SNLITester
from fastNLP.saver.logger import create_logger
from fastNLP.saver.model_saver import ModelSaver

@@ -162,7 +162,7 @@ class Trainer(object):
if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0:
end = time.time()
diff = timedelta(seconds=round(end - kwargs["start"]))
- print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.2} time: {}".format(
+ print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
kwargs["epoch"], step, loss.data, diff)
print(print_output)
logger.info(print_output)
@@ -292,3 +292,15 @@ class ClassificationTrainer(Trainer):

def _create_validator(self, valid_args):
return ClassificationTester(**valid_args)


class SNLITrainer(Trainer):
"""Trainer for text SNLI."""

def __init__(self, **train_args):
print(
"[FastNLP Warning] SNLITrainer will be deprecated. Please use Trainer directly.")
super(SNLITrainer, self).__init__(**train_args)

def _create_validator(self, valid_args):
return SNLITester(**valid_args)
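
Usage sketch (not part of the diff): SNLITrainer is a thin wrapper whose only real addition is wiring SNLITester in as the validator. The keyword arguments below mirror the [snli_trainer] section added to test/data_for_tests/config; the exact set of keys Trainer accepts is assumed here:

from fastNLP.core.trainer import SNLITrainer

train_args = {
    "epochs": 5,
    "batch_size": 32,
    "validate": True,
    "save_best_dev": True,
    "use_cuda": True,
    "print_every_step": 1000,   # key name taken from the config section; assumed to be understood by Trainer
}
trainer = SNLITrainer(**train_args)
# trainer.train(network=model, train_data=train_set, dev_data=dev_set)  # call signature assumed, shown for orientation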

+10  -2   fastNLP/core/vocabulary.py

@@ -18,6 +18,7 @@ def isiterable(p_object):
return False
return True


def check_build_vocab(func):
def _wrapper(self, *args, **kwargs):
if self.word2idx is None:
@@ -28,6 +29,7 @@ def check_build_vocab(func):
return func(self, *args, **kwargs)
return _wrapper


class Vocabulary(object):
"""Use for word and index one to one mapping

@@ -52,7 +54,6 @@ class Vocabulary(object):
self.word2idx = None
self.idx2word = None


def update(self, word):
"""add word or list of words into Vocabulary

@@ -71,7 +72,6 @@ class Vocabulary(object):
self.word2idx = None
return self


def build_vocab(self):
"""build 'word to index' dict, and filter the word using `max_size` and `min_freq`
"""
@@ -164,3 +164,11 @@ class Vocabulary(object):
"""
self.__dict__.update(state)
self.idx2word = None

def __contains__(self, item):
"""Check if a word in vocabulary.

:param item: the word
:return: True or False
"""
return self.has_word(item)
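
With __contains__ delegating to has_word, the ordinary Python membership test now works on a Vocabulary. A small sketch (the vocabulary contents are made up):

from fastNLP.core.vocabulary import Vocabulary

vocab = Vocabulary()
vocab.update(["premise", "hypothesis"])

print("premise" in vocab)      # True
print("entailment" in vocab)   # False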

+79  -0   fastNLP/loader/dataset_loader.py

@@ -5,6 +5,7 @@ from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.core.field import *


def convert_seq_dataset(data):
"""Create an DataSet instance that contains no labels.

@@ -23,6 +24,7 @@ def convert_seq_dataset(data):
dataset.append(Instance(word_seq=x))
return dataset


def convert_seq2tag_dataset(data):
"""Convert list of data into DataSet

@@ -45,6 +47,7 @@ def convert_seq2tag_dataset(data):
dataset.append(ins)
return dataset


def convert_seq2seq_dataset(data):
"""Convert list of data into DataSet

@@ -84,6 +87,7 @@ class DataSetLoader(BaseLoader):
"""
raise NotImplementedError


@DataSet.set_reader('read_raw')
class RawDataSetLoader(DataSetLoader):
def __init__(self):
@@ -99,6 +103,7 @@ class RawDataSetLoader(DataSetLoader):
def convert(self, data):
return convert_seq_dataset(data)


@DataSet.set_reader('read_pos')
class POSDataSetLoader(DataSetLoader):
"""Dataset Loader for POS Tag datasets.
@@ -168,6 +173,7 @@ class POSDataSetLoader(DataSetLoader):
"""
return convert_seq2seq_dataset(data)


@DataSet.set_reader('read_tokenize')
class TokenizeDataSetLoader(DataSetLoader):
"""
@@ -227,6 +233,7 @@ class TokenizeDataSetLoader(DataSetLoader):
def convert(self, data):
return convert_seq2seq_dataset(data)


@DataSet.set_reader('read_class')
class ClassDataSetLoader(DataSetLoader):
"""Loader for classification data sets"""
@@ -265,6 +272,7 @@ class ClassDataSetLoader(DataSetLoader):
def convert(self, data):
return convert_seq2tag_dataset(data)


@DataSet.set_reader('read_conll')
class ConllLoader(DataSetLoader):
"""loader for conll format files"""
@@ -306,6 +314,7 @@ class ConllLoader(DataSetLoader):
def convert(self, data):
pass


@DataSet.set_reader('read_lm')
class LMDataSetLoader(DataSetLoader):
"""Language Model Dataset Loader
@@ -342,6 +351,7 @@ class LMDataSetLoader(DataSetLoader):
def convert(self, data):
pass


@DataSet.set_reader('read_people_daily')
class PeopleDailyCorpusLoader(DataSetLoader):
"""
@@ -394,3 +404,72 @@ class PeopleDailyCorpusLoader(DataSetLoader):

def convert(self, data):
pass


class SNLIDataSetLoader(DataSetLoader):
"""A data set loader for SNLI data set.

"""

def __init__(self):
super(SNLIDataSetLoader, self).__init__()

def load(self, path_list):
"""

:param path_list: A list of file names, in the order: premise file, hypothesis file, label file.
:return: data_set: A DataSet object.
"""
assert len(path_list) == 3
line_set = []
for file in path_list:
if not os.path.exists(file):
raise FileNotFoundError("file {} NOT found".format(file))

with open(file, 'r', encoding='utf-8') as f:
lines = f.readlines()
line_set.append(lines)

premise_lines, hypothesis_lines, label_lines = line_set
assert len(premise_lines) == len(hypothesis_lines) and len(premise_lines) == len(label_lines)

data_set = []
for premise, hypothesis, label in zip(premise_lines, hypothesis_lines, label_lines):
p = premise.strip().split()
h = hypothesis.strip().split()
l = label.strip()
data_set.append([p, h, l])

return self.convert(data_set)

def convert(self, data):
"""Convert a 3D list to a DataSet object.

:param data: A 3D list of the form:
[
[ [premise_word_11, premise_word_12, ...], [hypothesis_word_11, hypothesis_word_12, ...], [label_1] ],
[ [premise_word_21, premise_word_22, ...], [hypothesis_word_21, hypothesis_word_22, ...], [label_2] ],
...
]
:return: data_set: A DataSet object.
"""

data_set = DataSet()

for example in data:
p, h, l = example
# list, list, str
x1 = TextField(p, is_target=False)
x2 = TextField(h, is_target=False)
x1_len = TextField([1] * len(p), is_target=False)
x2_len = TextField([1] * len(h), is_target=False)
y = LabelField(l, is_target=True)
instance = Instance()
instance.add_field("premise", x1)
instance.add_field("hypothesis", x2)
instance.add_field("premise_len", x1_len)
instance.add_field("hypothesis_len", x2_len)
instance.add_field("truth", y)
data_set.append(instance)

return data_set
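
Usage sketch (not part of the diff): SNLIDataSetLoader.load expects three aligned, whitespace-tokenized files, in the order premise, hypothesis, label. The file names below are hypothetical:

from fastNLP.loader.dataset_loader import SNLIDataSetLoader

loader = SNLIDataSetLoader()
paths = ["snli/premise_train.txt", "snli/hypothesis_train.txt", "snli/label_train.txt"]
train_set = loader.load(paths)

# each Instance now carries "premise", "hypothesis", the all-ones length fields,
# and a "truth" LabelField marked as the training target
print(len(train_set))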

+10  -9   fastNLP/loader/embed_loader.py

@@ -6,11 +6,12 @@ import torch
from fastNLP.loader.base_loader import BaseLoader
from fastNLP.core.vocabulary import Vocabulary


class EmbedLoader(BaseLoader):
"""docstring for EmbedLoader"""

- def __init__(self, data_path):
-     super(EmbedLoader, self).__init__(data_path)
+ def __init__(self):
+     super(EmbedLoader, self).__init__()

@staticmethod
def _load_glove(emb_file):
@@ -55,15 +56,15 @@ class EmbedLoader(BaseLoader):
:param emb_type: str, the pre-trained embedding format, support glove now
:param vocab: Vocabulary, a mapping from word to index, can be provided by user or built from pre-trained embedding
:param emb_pkl: str, the embedding pickle file.
- :return embedding_np: numpy array of shape (len(word_dict), emb_dim)
+ :return embedding_tensor: Tensor of shape (len(word_dict), emb_dim)
vocab: input vocab or vocab built by pre-train
TODO: fragile code
"""
# If the embedding pickle exists, load it and return.
if os.path.exists(emb_pkl):
with open(emb_pkl, "rb") as f:
- embedding_np, vocab = _pickle.load(f)
- return embedding_np, vocab
+ embedding_tensor, vocab = _pickle.load(f)
+ return embedding_tensor, vocab
# Otherwise, load the pre-trained embedding.
pretrain = EmbedLoader._load_pretrain(emb_file, emb_type)
if vocab is None:
@@ -71,14 +72,14 @@ class EmbedLoader(BaseLoader):
vocab = Vocabulary()
for w in pretrain.keys():
vocab.update(w)
- embedding_np = torch.randn(len(vocab), emb_dim)
+ embedding_tensor = torch.randn(len(vocab), emb_dim)
for w, v in pretrain.items():
if len(v.shape) > 1 or emb_dim != v.shape[0]:
raise ValueError('pre-trained embedding dim is {}, not matching the required {}'.format(v.shape, (emb_dim,)))
if vocab.has_word(w):
- embedding_np[vocab[w]] = v
+ embedding_tensor[vocab[w]] = v

# save and return the result
with open(emb_pkl, "wb") as f:
- _pickle.dump((embedding_np, vocab), f)
- return embedding_np, vocab
+ _pickle.dump((embedding_tensor, vocab), f)
+ return embedding_tensor, vocab
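
For orientation (not in the diff): the renamed return value is built by load_embedding, whose parameters are inferred here from the docstring and the code in this hunk (emb_dim, emb_file, emb_type, vocab, emb_pkl). The call below is a sketch under that assumption, with a hypothetical GloVe path:

from fastNLP.core.vocabulary import Vocabulary
from fastNLP.loader.embed_loader import EmbedLoader

vocab = Vocabulary()
vocab.update(["man", "guitar", "playing"])

embedding_tensor, vocab = EmbedLoader.load_embedding(
    emb_dim=300,
    emb_file="glove.840B.300d.txt",   # hypothetical path to a GloVe text file
    emb_type="glove",
    vocab=vocab,
    emb_pkl="./snli/embed.pkl",
)
print(embedding_tensor.size())  # torch.Size([len(vocab), 300])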

+161  -0   fastNLP/models/snli.py

@@ -0,0 +1,161 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from fastNLP.models.base_model import BaseModel
from fastNLP.modules import decoder as Decoder, encoder as Encoder


my_inf = 10e12


class SNLI(BaseModel):
"""
PyTorch Network for SNLI.
"""

def __init__(self, args, init_embedding=None):
super(SNLI, self).__init__()
self.vocab_size = args["vocab_size"]
self.embed_dim = args["embed_dim"]
self.hidden_size = args["hidden_size"]
self.batch_first = args["batch_first"]
self.dropout = args["dropout"]
self.n_labels = args["num_classes"]
self.gpu = args["gpu"] and torch.cuda.is_available()

self.embedding = Encoder.embedding.Embedding(self.vocab_size, self.embed_dim, init_emb=init_embedding,
dropout=self.dropout)

self.embedding_layer = Encoder.Linear(self.embed_dim, self.hidden_size)

self.encoder = Encoder.LSTM(
input_size=self.embed_dim, hidden_size=self.hidden_size, num_layers=1, bias=True,
batch_first=self.batch_first, bidirectional=True
)

self.inference_layer = Encoder.Linear(self.hidden_size * 4, self.hidden_size)

self.decoder = Encoder.LSTM(
input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=1, bias=True,
batch_first=self.batch_first, bidirectional=True
)

self.output = Decoder.MLP([4 * self.hidden_size, self.hidden_size, self.n_labels], 'tanh')

def forward(self, premise, hypothesis, premise_len, hypothesis_len):
""" Forward function

:param premise: A Tensor represents premise: [batch size(B), premise seq len(PL), hidden size(H)].
:param hypothesis: A Tensor represents hypothesis: [B, hypothesis seq len(HL), H].
:param premise_len: A Tensor recording which positions in the premise are real words and which are padding: [B, PL].
:param hypothesis_len: A Tensor recording which positions in the hypothesis are real words and which are padding: [B, HL].
:return: prediction: A Tensor of classification scores: [B, n_labels(N)].
"""

premise0 = self.embedding_layer(self.embedding(premise))
hypothesis0 = self.embedding_layer(self.embedding(hypothesis))

_BP, _PSL, _HP = premise0.size()
_BH, _HSL, _HH = hypothesis0.size()
_BPL, _PLL = premise_len.size()
_HPL, _HLL = hypothesis_len.size()

assert _BP == _BH and _BPL == _HPL and _BP == _BPL
assert _HP == _HH
assert _PSL == _PLL and _HSL == _HLL

B, PL, H = premise0.size()
B, HL, H = hypothesis0.size()

# a0, (ah0, ac0) = self.encoder(premise) # a0: [B, PL, H * 2], ah0: [2, B, H]
# b0, (bh0, bc0) = self.encoder(hypothesis) # b0: [B, HL, H * 2]

a0 = self.encoder(premise0) # a0: [B, PL, H * 2]
b0 = self.encoder(hypothesis0) # b0: [B, HL, H * 2]

a = torch.mean(a0.view(B, PL, -1, H), dim=2) # a: [B, PL, H]
b = torch.mean(b0.view(B, HL, -1, H), dim=2) # b: [B, HL, H]

ai, bi = self.calc_bi_attention(a, b, premise_len, hypothesis_len)

ma = torch.cat((a, ai, a - ai, a * ai), dim=2) # ma: [B, PL, 4 * H]
mb = torch.cat((b, bi, b - bi, b * bi), dim=2) # mb: [B, HL, 4 * H]

f_ma = self.inference_layer(ma)
f_mb = self.inference_layer(mb)

vat = self.decoder(f_ma)
vbt = self.decoder(f_mb)

va = torch.mean(vat.view(B, PL, -1, H), dim=2) # va: [B, PL, H]
vb = torch.mean(vbt.view(B, HL, -1, H), dim=2) # vb: [B, HL, H]

# va_ave = torch.mean(va, dim=1) # va_ave: [B, H]
# va_max, va_arg_max = torch.max(va, dim=1) # va_max: [B, H]
# vb_ave = torch.mean(vb, dim=1) # vb_ave: [B, H]
# vb_max, vb_arg_max = torch.max(vb, dim=1) # vb_max: [B, H]

va_ave = self.mean_pooling(va, premise_len, dim=1) # va_ave: [B, H]
va_max, va_arg_max = self.max_pooling(va, premise_len, dim=1) # va_max: [B, H]
vb_ave = self.mean_pooling(vb, hypothesis_len, dim=1) # vb_ave: [B, H]
vb_max, vb_arg_max = self.max_pooling(vb, hypothesis_len, dim=1) # vb_max: [B, H]

v = torch.cat((va_ave, va_max, vb_ave, vb_max), dim=1) # v: [B, 4 * H]

# v_mlp = F.tanh(self.mlp_layer1(v)) # v_mlp: [B, H]
# prediction = self.mlp_layer2(v_mlp) # prediction: [B, N]

prediction = F.tanh(self.output(v)) # prediction: [B, N]

return prediction

@staticmethod
def calc_bi_attention(in_x1, in_x2, x1_len, x2_len):

# in_x1: [batch_size, x1_seq_len, hidden_size]
# in_x2: [batch_size, x2_seq_len, hidden_size]
# x1_len: [batch_size, x1_seq_len]
# x2_len: [batch_size, x2_seq_len]

assert in_x1.size()[0] == in_x2.size()[0]
assert in_x1.size()[2] == in_x2.size()[2]
# The batch size and hidden size must be equal.
assert in_x1.size()[1] == x1_len.size()[1] and in_x2.size()[1] == x2_len.size()[1]
# The seq len in in_x and x_len must be equal.
assert in_x1.size()[0] == x1_len.size()[0] and x1_len.size()[0] == x2_len.size()[0]

batch_size = in_x1.size()[0]
x1_max_len = in_x1.size()[1]
x2_max_len = in_x2.size()[1]

in_x2_t = torch.transpose(in_x2, 1, 2) # [batch_size, hidden_size, x2_seq_len]

attention_matrix = torch.bmm(in_x1, in_x2_t) # [batch_size, x1_seq_len, x2_seq_len]

a_mask = x1_len.le(0.5).float() * -my_inf # [batch_size, x1_seq_len]
a_mask = a_mask.view(batch_size, x1_max_len, -1)
a_mask = a_mask.expand(-1, -1, x2_max_len) # [batch_size, x1_seq_len, x2_seq_len]
b_mask = x2_len.le(0.5).float() * -my_inf
b_mask = b_mask.view(batch_size, -1, x2_max_len)
b_mask = b_mask.expand(-1, x1_max_len, -1) # [batch_size, x1_seq_len, x2_seq_len]

attention_a = F.softmax(attention_matrix + a_mask, dim=2) # [batch_size, x1_seq_len, x2_seq_len]
attention_b = F.softmax(attention_matrix + b_mask, dim=1) # [batch_size, x1_seq_len, x2_seq_len]

out_x1 = torch.bmm(attention_a, in_x2) # [batch_size, x1_seq_len, hidden_size]
attention_b_t = torch.transpose(attention_b, 1, 2)
out_x2 = torch.bmm(attention_b_t, in_x1) # [batch_size, x2_seq_len, hidden_size]

return out_x1, out_x2

@staticmethod
def mean_pooling(tensor, mask, dim=0):
masks = mask.view(mask.size(0), mask.size(1), -1).float()
return torch.sum(tensor * masks, dim=dim) / torch.sum(masks, dim=1)

@staticmethod
def max_pooling(tensor, mask, dim=0):
masks = mask.view(mask.size(0), mask.size(1), -1)
masks = masks.expand(-1, -1, tensor.size(2)).float()
return torch.max(tensor + masks.le(0.5).float() * -my_inf, dim=dim)
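
A shape sketch (not part of the diff) for the new model. The hyper-parameters mirror the [snli_model] section of the test config; vocab_size, num_classes, and the index-tensor inputs are assumptions based on the embedding layer:

import torch
from fastNLP.models.snli import SNLI

args = {
    "vocab_size": 1000,     # assumed
    "embed_dim": 300,
    "hidden_size": 300,
    "batch_first": True,
    "dropout": 0.5,
    "num_classes": 3,       # entailment / neutral / contradiction
    "gpu": False,
}
model = SNLI(args)

B, PL, HL = 4, 7, 5
premise = torch.LongTensor(B, PL).random_(0, 1000)      # word indices
hypothesis = torch.LongTensor(B, HL).random_(0, 1000)
premise_len = torch.ones(B, PL)                         # 1 = real token, 0 = padding
hypothesis_len = torch.ones(B, HL)

prediction = model(premise, hypothesis, premise_len, hypothesis_len)
print(prediction.size())  # torch.Size([4, 3])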

+10  -9   fastNLP/modules/decoder/MLP.py

@@ -1,12 +1,15 @@
import torch
import torch.nn as nn
from fastNLP.modules.utils import initial_parameter


class MLP(nn.Module):
- def __init__(self, size_layer, activation='relu' , initial_method = None):
+ def __init__(self, size_layer, activation='relu', initial_method=None):
"""Multilayer Perceptrons as a decoder

- :param size_layer: list of int, define the size of MLP layers
- :param activation: str or function, the activation function for hidden layers
+ :param size_layer: list of int, define the size of MLP layers.
+ :param activation: str or function, the activation function for hidden layers.
+ :param initial_method: str, the name of init method.

.. note::
There is no activation function applying on output layer.
@@ -23,7 +26,7 @@ class MLP(nn.Module):

actives = {
'relu': nn.ReLU(),
- 'tanh': nn.Tanh()
+ 'tanh': nn.Tanh(),
}
if activation in actives:
self.hidden_active = actives[activation]
@@ -31,7 +34,7 @@ class MLP(nn.Module):
self.hidden_active = activation
else:
raise ValueError("should set activation correctly: {}".format(activation))
- initial_parameter(self, initial_method )
+ initial_parameter(self, initial_method)

def forward(self, x):
for layer in self.hiddens:
@@ -40,13 +43,11 @@ class MLP(nn.Module):
return x



if __name__ == '__main__':
- net1 = MLP([5,10,5])
- net2 = MLP([5,10,5], 'tanh')
+ net1 = MLP([5, 10, 5])
+ net2 = MLP([5, 10, 5], 'tanh')
for net in [net1, net2]:
x = torch.randn(5, 5)
y = net(x)
print(x)
print(y)

+4  -1   fastNLP/modules/encoder/linear.py

@@ -1,6 +1,8 @@
import torch.nn as nn

from fastNLP.modules.utils import initial_parameter


class Linear(nn.Module):
"""
Linear module
@@ -12,10 +14,11 @@ class Linear(nn.Module):
bidirectional : If True, becomes a bidirectional RNN
"""

- def __init__(self, input_size, output_size, bias=True,initial_method = None ):
+ def __init__(self, input_size, output_size, bias=True, initial_method=None):
super(Linear, self).__init__()
self.linear = nn.Linear(input_size, output_size, bias)
initial_parameter(self, initial_method)

def forward(self, x):
x = self.linear(x)
return x

+13  -6   fastNLP/modules/encoder/lstm.py

@@ -14,16 +14,23 @@ class LSTM(nn.Module):
bidirectional : If True, becomes a bidirectional RNN. Default: False.
"""

- def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, bidirectional=False,
-              initial_method=None):
+ def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True,
+              bidirectional=False, bias=True, initial_method=None, get_hidden=False):
super(LSTM, self).__init__()
- self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True,
+ self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first,
dropout=dropout, bidirectional=bidirectional)
+ self.get_hidden = get_hidden
initial_parameter(self, initial_method)

- def forward(self, x):
-     x, _ = self.lstm(x)
-     return x
+ def forward(self, x, h0=None, c0=None):
+     if h0 is not None and c0 is not None:
+         x, (ht, ct) = self.lstm(x, (h0, c0))
+     else:
+         x, (ht, ct) = self.lstm(x)
+     if self.get_hidden:
+         return x, (ht, ct)
+     else:
+         return x


if __name__ == "__main__":
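
A quick sketch (not part of the diff) of the old and new return conventions; shapes assume the defaults batch_first=True and num_layers=1:

import torch
from fastNLP.modules.encoder.lstm import LSTM

x = torch.randn(4, 10, 300)                       # [batch, seq_len, input_size]

encoder = LSTM(input_size=300, hidden_size=100, bidirectional=True)
out = encoder(x)
print(out.size())                                 # torch.Size([4, 10, 200]); outputs only, as before

encoder_h = LSTM(input_size=300, hidden_size=100, get_hidden=True)
out, (ht, ct) = encoder_h(x)
print(out.size(), ht.size())                      # torch.Size([4, 10, 100]) torch.Size([1, 4, 100])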


+25  -0   test/data_for_tests/config

@@ -45,3 +45,28 @@ use_cuda = true
learn_rate = 1e-3
momentum = 0.9
model_name = "class_model.pkl"

[snli_trainer]
epochs = 5
batch_size = 32
validate = true
save_best_dev = true
use_cuda = true
learn_rate = 1e-4
loss = "cross_entropy"
print_every_step = 1000

[snli_tester]
batch_size = 512
use_cuda = true

[snli_model]
model_name = "snli_model.pkl"
embed_dim = 300
hidden_size = 300
batch_first = true
dropout = 0.5
gpu = true
embed_file = "./../data_for_tests/glove.840B.300d.txt"
embed_pkl = "./snli/embed.pkl"
examples = 0
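
The new sections are plain key/value blocks. For a quick inspection they can be read with the standard library; the snippet below is a sketch only (the tests presumably go through fastNLP's own config loading), and note that quoted values keep their quotes under configparser:

import configparser

parser = configparser.ConfigParser()
parser.read("test/data_for_tests/config")

snli_model = parser["snli_model"]
print(snli_model.getint("embed_dim"))        # 300
print(snli_model.getfloat("dropout"))        # 0.5
print(snli_model.getboolean("batch_first"))  # True
print(snli_model.get("embed_pkl"))           # '"./snli/embed.pkl"' (quotes are part of the value)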
