@@ -32,8 +32,18 @@ class DataSet(list): | |||
return self | |||
def index_field(self, field_name, vocab): | |||
for ins in self: | |||
ins.index_field(field_name, vocab) | |||
if isinstance(field_name, str): | |||
field_list = [field_name] | |||
vocab_list = [vocab] | |||
else: | |||
classes = (list, tuple) | |||
assert isinstance(field_name, classes) and isinstance(vocab, classes) and len(field_name) == len(vocab) | |||
field_list = field_name | |||
vocab_list = vocab | |||
for name, vocabs in zip(field_list, vocab_list): | |||
for ins in self: | |||
ins.index_field(name, vocabs) | |||
return self | |||
def to_tensor(self, idx: int, padding_length: dict): | |||
@@ -57,6 +57,20 @@ class SeqLabelEvaluator(Evaluator): | |||
return {"accuracy": float(accuracy)} | |||
class SNLIEvaluator(Evaluator): | |||
def __init__(self): | |||
super(SNLIEvaluator, self).__init__() | |||
def __call__(self, predict, truth): | |||
y_prob = [torch.nn.functional.softmax(y_logit, dim=-1) for y_logit in predict] | |||
y_prob = torch.cat(y_prob, dim=0) | |||
y_pred = torch.argmax(y_prob, dim=-1) | |||
truth = [t['truth'] for t in truth] | |||
y_true = torch.cat(truth, dim=0).view(-1) | |||
acc = float(torch.sum(y_pred == y_true)) / y_true.size(0) | |||
return {"accuracy": acc} | |||
def _conver_numpy(x): | |||
"""convert input data to numpy array | |||
@@ -83,6 +83,7 @@ class Tester(object): | |||
truth_list.append(batch_y) | |||
eval_results = self.evaluate(output_list, truth_list) | |||
print("[tester] {}".format(self.print_eval_results(eval_results))) | |||
logger.info("[tester] {}".format(self.print_eval_results(eval_results))) | |||
def mode(self, model, is_test=False): | |||
"""Train mode or Test mode. This is for PyTorch currently. | |||
@@ -131,3 +132,10 @@ class ClassificationTester(Tester): | |||
print( | |||
"[FastNLP Warning] ClassificationTester will be deprecated. Please use Tester directly.") | |||
super(ClassificationTester, self).__init__(**test_args) | |||
class SNLITester(Tester): | |||
def __init__(self, **test_args): | |||
print( | |||
"[FastNLP Warning] SNLITester will be deprecated. Please use Tester directly.") | |||
super(SNLITester, self).__init__(**test_args) |
@@ -10,7 +10,7 @@ from fastNLP.core.loss import Loss | |||
from fastNLP.core.metrics import Evaluator | |||
from fastNLP.core.optimizer import Optimizer | |||
from fastNLP.core.sampler import RandomSampler | |||
from fastNLP.core.tester import SeqLabelTester, ClassificationTester | |||
from fastNLP.core.tester import SeqLabelTester, ClassificationTester, SNLITester | |||
from fastNLP.saver.logger import create_logger | |||
from fastNLP.saver.model_saver import ModelSaver | |||
@@ -162,7 +162,7 @@ class Trainer(object): | |||
if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0: | |||
end = time.time() | |||
diff = timedelta(seconds=round(end - kwargs["start"])) | |||
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.2} time: {}".format( | |||
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format( | |||
kwargs["epoch"], step, loss.data, diff) | |||
print(print_output) | |||
logger.info(print_output) | |||
@@ -292,3 +292,15 @@ class ClassificationTrainer(Trainer): | |||
def _create_validator(self, valid_args): | |||
return ClassificationTester(**valid_args) | |||
class SNLITrainer(Trainer): | |||
"""Trainer for text SNLI.""" | |||
def __init__(self, **train_args): | |||
print( | |||
"[FastNLP Warning] SNLITrainer will be deprecated. Please use Trainer directly.") | |||
super(SNLITrainer, self).__init__(**train_args) | |||
def _create_validator(self, valid_args): | |||
return SNLITester(**valid_args) |
@@ -18,6 +18,7 @@ def isiterable(p_object): | |||
return False | |||
return True | |||
def check_build_vocab(func): | |||
def _wrapper(self, *args, **kwargs): | |||
if self.word2idx is None: | |||
@@ -28,6 +29,7 @@ def check_build_vocab(func): | |||
return func(self, *args, **kwargs) | |||
return _wrapper | |||
class Vocabulary(object): | |||
"""Use for word and index one to one mapping | |||
@@ -52,7 +54,6 @@ class Vocabulary(object): | |||
self.word2idx = None | |||
self.idx2word = None | |||
def update(self, word): | |||
"""add word or list of words into Vocabulary | |||
@@ -71,7 +72,6 @@ class Vocabulary(object): | |||
self.word2idx = None | |||
return self | |||
def build_vocab(self): | |||
"""build 'word to index' dict, and filter the word using `max_size` and `min_freq` | |||
""" | |||
@@ -164,3 +164,11 @@ class Vocabulary(object): | |||
""" | |||
self.__dict__.update(state) | |||
self.idx2word = None | |||
def __contains__(self, item): | |||
"""Check if a word in vocabulary. | |||
:param item: the word | |||
:return: True or False | |||
""" | |||
return self.has_word(item) |
@@ -5,6 +5,7 @@ from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.instance import Instance | |||
from fastNLP.core.field import * | |||
def convert_seq_dataset(data): | |||
"""Create an DataSet instance that contains no labels. | |||
@@ -23,6 +24,7 @@ def convert_seq_dataset(data): | |||
dataset.append(Instance(word_seq=x)) | |||
return dataset | |||
def convert_seq2tag_dataset(data): | |||
"""Convert list of data into DataSet | |||
@@ -45,6 +47,7 @@ def convert_seq2tag_dataset(data): | |||
dataset.append(ins) | |||
return dataset | |||
def convert_seq2seq_dataset(data): | |||
"""Convert list of data into DataSet | |||
@@ -84,6 +87,7 @@ class DataSetLoader(BaseLoader): | |||
""" | |||
raise NotImplementedError | |||
@DataSet.set_reader('read_raw') | |||
class RawDataSetLoader(DataSetLoader): | |||
def __init__(self): | |||
@@ -99,6 +103,7 @@ class RawDataSetLoader(DataSetLoader): | |||
def convert(self, data): | |||
return convert_seq_dataset(data) | |||
@DataSet.set_reader('read_pos') | |||
class POSDataSetLoader(DataSetLoader): | |||
"""Dataset Loader for POS Tag datasets. | |||
@@ -168,6 +173,7 @@ class POSDataSetLoader(DataSetLoader): | |||
""" | |||
return convert_seq2seq_dataset(data) | |||
@DataSet.set_reader('read_tokenize') | |||
class TokenizeDataSetLoader(DataSetLoader): | |||
""" | |||
@@ -227,6 +233,7 @@ class TokenizeDataSetLoader(DataSetLoader): | |||
def convert(self, data): | |||
return convert_seq2seq_dataset(data) | |||
@DataSet.set_reader('read_class') | |||
class ClassDataSetLoader(DataSetLoader): | |||
"""Loader for classification data sets""" | |||
@@ -265,6 +272,7 @@ class ClassDataSetLoader(DataSetLoader): | |||
def convert(self, data): | |||
return convert_seq2tag_dataset(data) | |||
@DataSet.set_reader('read_conll') | |||
class ConllLoader(DataSetLoader): | |||
"""loader for conll format files""" | |||
@@ -306,6 +314,7 @@ class ConllLoader(DataSetLoader): | |||
def convert(self, data): | |||
pass | |||
@DataSet.set_reader('read_lm') | |||
class LMDataSetLoader(DataSetLoader): | |||
"""Language Model Dataset Loader | |||
@@ -342,6 +351,7 @@ class LMDataSetLoader(DataSetLoader): | |||
def convert(self, data): | |||
pass | |||
@DataSet.set_reader('read_people_daily') | |||
class PeopleDailyCorpusLoader(DataSetLoader): | |||
""" | |||
@@ -394,3 +404,72 @@ class PeopleDailyCorpusLoader(DataSetLoader): | |||
def convert(self, data): | |||
pass | |||
class SNLIDataSetLoader(DataSetLoader): | |||
"""A data set loader for SNLI data set. | |||
""" | |||
def __init__(self): | |||
super(SNLIDataSetLoader, self).__init__() | |||
def load(self, path_list): | |||
""" | |||
:param path_list: A list of file name, in the order of premise file, hypothesis file, and label file. | |||
:return: data_set: A DataSet object. | |||
""" | |||
assert len(path_list) == 3 | |||
line_set = [] | |||
for file in path_list: | |||
if not os.path.exists(file): | |||
raise FileNotFoundError("file {} NOT found".format(file)) | |||
with open(file, 'r', encoding='utf-8') as f: | |||
lines = f.readlines() | |||
line_set.append(lines) | |||
premise_lines, hypothesis_lines, label_lines = line_set | |||
assert len(premise_lines) == len(hypothesis_lines) and len(premise_lines) == len(label_lines) | |||
data_set = [] | |||
for premise, hypothesis, label in zip(premise_lines, hypothesis_lines, label_lines): | |||
p = premise.strip().split() | |||
h = hypothesis.strip().split() | |||
l = label.strip() | |||
data_set.append([p, h, l]) | |||
return self.convert(data_set) | |||
def convert(self, data): | |||
"""Convert a 3D list to a DataSet object. | |||
:param data: A 3D tensor. | |||
[ | |||
[ [premise_word_11, premise_word_12, ...], [hypothesis_word_11, hypothesis_word_12, ...], [label_1] ], | |||
[ [premise_word_21, premise_word_22, ...], [hypothesis_word_21, hypothesis_word_22, ...], [label_2] ], | |||
... | |||
] | |||
:return: data_set: A DataSet object. | |||
""" | |||
data_set = DataSet() | |||
for example in data: | |||
p, h, l = example | |||
# list, list, str | |||
x1 = TextField(p, is_target=False) | |||
x2 = TextField(h, is_target=False) | |||
x1_len = TextField([1] * len(p), is_target=False) | |||
x2_len = TextField([1] * len(h), is_target=False) | |||
y = LabelField(l, is_target=True) | |||
instance = Instance() | |||
instance.add_field("premise", x1) | |||
instance.add_field("hypothesis", x2) | |||
instance.add_field("premise_len", x1_len) | |||
instance.add_field("hypothesis_len", x2_len) | |||
instance.add_field("truth", y) | |||
data_set.append(instance) | |||
return data_set |
@@ -6,11 +6,12 @@ import torch | |||
from fastNLP.loader.base_loader import BaseLoader | |||
from fastNLP.core.vocabulary import Vocabulary | |||
class EmbedLoader(BaseLoader): | |||
"""docstring for EmbedLoader""" | |||
def __init__(self, data_path): | |||
super(EmbedLoader, self).__init__(data_path) | |||
def __init__(self): | |||
super(EmbedLoader, self).__init__() | |||
@staticmethod | |||
def _load_glove(emb_file): | |||
@@ -55,15 +56,15 @@ class EmbedLoader(BaseLoader): | |||
:param emb_type: str, the pre-trained embedding format, support glove now | |||
:param vocab: Vocabulary, a mapping from word to index, can be provided by user or built from pre-trained embedding | |||
:param emb_pkl: str, the embedding pickle file. | |||
:return embedding_np: numpy array of shape (len(word_dict), emb_dim) | |||
:return embedding_tensor: Tensor of shape (len(word_dict), emb_dim) | |||
vocab: input vocab or vocab built by pre-train | |||
TODO: fragile code | |||
""" | |||
# If the embedding pickle exists, load it and return. | |||
if os.path.exists(emb_pkl): | |||
with open(emb_pkl, "rb") as f: | |||
embedding_np, vocab = _pickle.load(f) | |||
return embedding_np, vocab | |||
embedding_tensor, vocab = _pickle.load(f) | |||
return embedding_tensor, vocab | |||
# Otherwise, load the pre-trained embedding. | |||
pretrain = EmbedLoader._load_pretrain(emb_file, emb_type) | |||
if vocab is None: | |||
@@ -71,14 +72,14 @@ class EmbedLoader(BaseLoader): | |||
vocab = Vocabulary() | |||
for w in pretrain.keys(): | |||
vocab.update(w) | |||
embedding_np = torch.randn(len(vocab), emb_dim) | |||
embedding_tensor = torch.randn(len(vocab), emb_dim) | |||
for w, v in pretrain.items(): | |||
if len(v.shape) > 1 or emb_dim != v.shape[0]: | |||
raise ValueError('pretrian embedding dim is {}, dismatching required {}'.format(v.shape, (emb_dim,))) | |||
if vocab.has_word(w): | |||
embedding_np[vocab[w]] = v | |||
embedding_tensor[vocab[w]] = v | |||
# save and return the result | |||
with open(emb_pkl, "wb") as f: | |||
_pickle.dump((embedding_np, vocab), f) | |||
return embedding_np, vocab | |||
_pickle.dump((embedding_tensor, vocab), f) | |||
return embedding_tensor, vocab |
@@ -0,0 +1,161 @@ | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from fastNLP.models.base_model import BaseModel | |||
from fastNLP.modules import decoder as Decoder, encoder as Encoder | |||
my_inf = 10e12 | |||
class SNLI(BaseModel): | |||
""" | |||
PyTorch Network for SNLI. | |||
""" | |||
def __init__(self, args, init_embedding=None): | |||
super(SNLI, self).__init__() | |||
self.vocab_size = args["vocab_size"] | |||
self.embed_dim = args["embed_dim"] | |||
self.hidden_size = args["hidden_size"] | |||
self.batch_first = args["batch_first"] | |||
self.dropout = args["dropout"] | |||
self.n_labels = args["num_classes"] | |||
self.gpu = args["gpu"] and torch.cuda.is_available() | |||
self.embedding = Encoder.embedding.Embedding(self.vocab_size, self.embed_dim, init_emb=init_embedding, | |||
dropout=self.dropout) | |||
self.embedding_layer = Encoder.Linear(self.embed_dim, self.hidden_size) | |||
self.encoder = Encoder.LSTM( | |||
input_size=self.embed_dim, hidden_size=self.hidden_size, num_layers=1, bias=True, | |||
batch_first=self.batch_first, bidirectional=True | |||
) | |||
self.inference_layer = Encoder.Linear(self.hidden_size * 4, self.hidden_size) | |||
self.decoder = Encoder.LSTM( | |||
input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=1, bias=True, | |||
batch_first=self.batch_first, bidirectional=True | |||
) | |||
self.output = Decoder.MLP([4 * self.hidden_size, self.hidden_size, self.n_labels], 'tanh') | |||
def forward(self, premise, hypothesis, premise_len, hypothesis_len): | |||
""" Forward function | |||
:param premise: A Tensor represents premise: [batch size(B), premise seq len(PL), hidden size(H)]. | |||
:param hypothesis: A Tensor represents hypothesis: [B, hypothesis seq len(HL), H]. | |||
:param premise_len: A Tensor record which is a real word and which is a padding word in premise: [B, PL]. | |||
:param hypothesis_len: A Tensor record which is a real word and which is a padding word in hypothesis: [B, HL]. | |||
:return: prediction: A Tensor of classification result: [B, n_labels(N)]. | |||
""" | |||
premise0 = self.embedding_layer(self.embedding(premise)) | |||
hypothesis0 = self.embedding_layer(self.embedding(hypothesis)) | |||
_BP, _PSL, _HP = premise0.size() | |||
_BH, _HSL, _HH = hypothesis0.size() | |||
_BPL, _PLL = premise_len.size() | |||
_HPL, _HLL = hypothesis_len.size() | |||
assert _BP == _BH and _BPL == _HPL and _BP == _BPL | |||
assert _HP == _HH | |||
assert _PSL == _PLL and _HSL == _HLL | |||
B, PL, H = premise0.size() | |||
B, HL, H = hypothesis0.size() | |||
# a0, (ah0, ac0) = self.encoder(premise) # a0: [B, PL, H * 2], ah0: [2, B, H] | |||
# b0, (bh0, bc0) = self.encoder(hypothesis) # b0: [B, HL, H * 2] | |||
a0 = self.encoder(premise0) # a0: [B, PL, H * 2] | |||
b0 = self.encoder(hypothesis0) # b0: [B, HL, H * 2] | |||
a = torch.mean(a0.view(B, PL, -1, H), dim=2) # a: [B, PL, H] | |||
b = torch.mean(b0.view(B, HL, -1, H), dim=2) # b: [B, HL, H] | |||
ai, bi = self.calc_bi_attention(a, b, premise_len, hypothesis_len) | |||
ma = torch.cat((a, ai, a - ai, a * ai), dim=2) # ma: [B, PL, 4 * H] | |||
mb = torch.cat((b, bi, b - bi, b * bi), dim=2) # mb: [B, HL, 4 * H] | |||
f_ma = self.inference_layer(ma) | |||
f_mb = self.inference_layer(mb) | |||
vat = self.decoder(f_ma) | |||
vbt = self.decoder(f_mb) | |||
va = torch.mean(vat.view(B, PL, -1, H), dim=2) # va: [B, PL, H] | |||
vb = torch.mean(vbt.view(B, HL, -1, H), dim=2) # vb: [B, HL, H] | |||
# va_ave = torch.mean(va, dim=1) # va_ave: [B, H] | |||
# va_max, va_arg_max = torch.max(va, dim=1) # va_max: [B, H] | |||
# vb_ave = torch.mean(vb, dim=1) # vb_ave: [B, H] | |||
# vb_max, vb_arg_max = torch.max(vb, dim=1) # vb_max: [B, H] | |||
va_ave = self.mean_pooling(va, premise_len, dim=1) # va_ave: [B, H] | |||
va_max, va_arg_max = self.max_pooling(va, premise_len, dim=1) # va_max: [B, H] | |||
vb_ave = self.mean_pooling(vb, hypothesis_len, dim=1) # vb_ave: [B, H] | |||
vb_max, vb_arg_max = self.max_pooling(vb, hypothesis_len, dim=1) # vb_max: [B, H] | |||
v = torch.cat((va_ave, va_max, vb_ave, vb_max), dim=1) # v: [B, 4 * H] | |||
# v_mlp = F.tanh(self.mlp_layer1(v)) # v_mlp: [B, H] | |||
# prediction = self.mlp_layer2(v_mlp) # prediction: [B, N] | |||
prediction = F.tanh(self.output(v)) # prediction: [B, N] | |||
return prediction | |||
@staticmethod | |||
def calc_bi_attention(in_x1, in_x2, x1_len, x2_len): | |||
# in_x1: [batch_size, x1_seq_len, hidden_size] | |||
# in_x2: [batch_size, x2_seq_len, hidden_size] | |||
# x1_len: [batch_size, x1_seq_len] | |||
# x2_len: [batch_size, x2_seq_len] | |||
assert in_x1.size()[0] == in_x2.size()[0] | |||
assert in_x1.size()[2] == in_x2.size()[2] | |||
# The batch size and hidden size must be equal. | |||
assert in_x1.size()[1] == x1_len.size()[1] and in_x2.size()[1] == x2_len.size()[1] | |||
# The seq len in in_x and x_len must be equal. | |||
assert in_x1.size()[0] == x1_len.size()[0] and x1_len.size()[0] == x2_len.size()[0] | |||
batch_size = in_x1.size()[0] | |||
x1_max_len = in_x1.size()[1] | |||
x2_max_len = in_x2.size()[1] | |||
in_x2_t = torch.transpose(in_x2, 1, 2) # [batch_size, hidden_size, x2_seq_len] | |||
attention_matrix = torch.bmm(in_x1, in_x2_t) # [batch_size, x1_seq_len, x2_seq_len] | |||
a_mask = x1_len.le(0.5).float() * -my_inf # [batch_size, x1_seq_len] | |||
a_mask = a_mask.view(batch_size, x1_max_len, -1) | |||
a_mask = a_mask.expand(-1, -1, x2_max_len) # [batch_size, x1_seq_len, x2_seq_len] | |||
b_mask = x2_len.le(0.5).float() * -my_inf | |||
b_mask = b_mask.view(batch_size, -1, x2_max_len) | |||
b_mask = b_mask.expand(-1, x1_max_len, -1) # [batch_size, x1_seq_len, x2_seq_len] | |||
attention_a = F.softmax(attention_matrix + a_mask, dim=2) # [batch_size, x1_seq_len, x2_seq_len] | |||
attention_b = F.softmax(attention_matrix + b_mask, dim=1) # [batch_size, x1_seq_len, x2_seq_len] | |||
out_x1 = torch.bmm(attention_a, in_x2) # [batch_size, x1_seq_len, hidden_size] | |||
attention_b_t = torch.transpose(attention_b, 1, 2) | |||
out_x2 = torch.bmm(attention_b_t, in_x1) # [batch_size, x2_seq_len, hidden_size] | |||
return out_x1, out_x2 | |||
@staticmethod | |||
def mean_pooling(tensor, mask, dim=0): | |||
masks = mask.view(mask.size(0), mask.size(1), -1).float() | |||
return torch.sum(tensor * masks, dim=dim) / torch.sum(masks, dim=1) | |||
@staticmethod | |||
def max_pooling(tensor, mask, dim=0): | |||
masks = mask.view(mask.size(0), mask.size(1), -1) | |||
masks = masks.expand(-1, -1, tensor.size(2)).float() | |||
return torch.max(tensor + masks.le(0.5).float() * -my_inf, dim=dim) |
@@ -1,12 +1,15 @@ | |||
import torch | |||
import torch.nn as nn | |||
from fastNLP.modules.utils import initial_parameter | |||
class MLP(nn.Module): | |||
def __init__(self, size_layer, activation='relu' , initial_method = None): | |||
def __init__(self, size_layer, activation='relu', initial_method=None): | |||
"""Multilayer Perceptrons as a decoder | |||
:param size_layer: list of int, define the size of MLP layers | |||
:param activation: str or function, the activation function for hidden layers | |||
:param size_layer: list of int, define the size of MLP layers. | |||
:param activation: str or function, the activation function for hidden layers. | |||
:param initial_method: str, the name of init method. | |||
.. note:: | |||
There is no activation function applying on output layer. | |||
@@ -23,7 +26,7 @@ class MLP(nn.Module): | |||
actives = { | |||
'relu': nn.ReLU(), | |||
'tanh': nn.Tanh() | |||
'tanh': nn.Tanh(), | |||
} | |||
if activation in actives: | |||
self.hidden_active = actives[activation] | |||
@@ -31,7 +34,7 @@ class MLP(nn.Module): | |||
self.hidden_active = activation | |||
else: | |||
raise ValueError("should set activation correctly: {}".format(activation)) | |||
initial_parameter(self, initial_method ) | |||
initial_parameter(self, initial_method) | |||
def forward(self, x): | |||
for layer in self.hiddens: | |||
@@ -40,13 +43,11 @@ class MLP(nn.Module): | |||
return x | |||
if __name__ == '__main__': | |||
net1 = MLP([5,10,5]) | |||
net2 = MLP([5,10,5], 'tanh') | |||
net1 = MLP([5, 10, 5]) | |||
net2 = MLP([5, 10, 5], 'tanh') | |||
for net in [net1, net2]: | |||
x = torch.randn(5, 5) | |||
y = net(x) | |||
print(x) | |||
print(y) | |||
@@ -1,6 +1,8 @@ | |||
import torch.nn as nn | |||
from fastNLP.modules.utils import initial_parameter | |||
class Linear(nn.Module): | |||
""" | |||
Linear module | |||
@@ -12,10 +14,11 @@ class Linear(nn.Module): | |||
bidirectional : If True, becomes a bidirectional RNN | |||
""" | |||
def __init__(self, input_size, output_size, bias=True,initial_method = None ): | |||
def __init__(self, input_size, output_size, bias=True, initial_method=None): | |||
super(Linear, self).__init__() | |||
self.linear = nn.Linear(input_size, output_size, bias) | |||
initial_parameter(self, initial_method) | |||
def forward(self, x): | |||
x = self.linear(x) | |||
return x |
@@ -14,16 +14,23 @@ class LSTM(nn.Module): | |||
bidirectional : If True, becomes a bidirectional RNN. Default: False. | |||
""" | |||
def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, bidirectional=False, | |||
initial_method=None): | |||
def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True, | |||
bidirectional=False, bias=True, initial_method=None, get_hidden=False): | |||
super(LSTM, self).__init__() | |||
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True, | |||
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, | |||
dropout=dropout, bidirectional=bidirectional) | |||
self.get_hidden = get_hidden | |||
initial_parameter(self, initial_method) | |||
def forward(self, x): | |||
x, _ = self.lstm(x) | |||
return x | |||
def forward(self, x, h0=None, c0=None): | |||
if h0 is not None and c0 is not None: | |||
x, (ht, ct) = self.lstm(x, (h0, c0)) | |||
else: | |||
x, (ht, ct) = self.lstm(x) | |||
if self.get_hidden: | |||
return x, (ht, ct) | |||
else: | |||
return x | |||
if __name__ == "__main__": | |||
@@ -45,3 +45,28 @@ use_cuda = true | |||
learn_rate = 1e-3 | |||
momentum = 0.9 | |||
model_name = "class_model.pkl" | |||
[snli_trainer] | |||
epochs = 5 | |||
batch_size = 32 | |||
validate = true | |||
save_best_dev = true | |||
use_cuda = true | |||
learn_rate = 1e-4 | |||
loss = "cross_entropy" | |||
print_every_step = 1000 | |||
[snli_tester] | |||
batch_size = 512 | |||
use_cuda = true | |||
[snli_model] | |||
model_name = "snli_model.pkl" | |||
embed_dim = 300 | |||
hidden_size = 300 | |||
batch_first = true | |||
dropout = 0.5 | |||
gpu = true | |||
embed_file = "./../data_for_tests/glove.840B.300d.txt" | |||
embed_pkl = "./snli/embed.pkl" | |||
examples = 0 |