
add init parameters in modules; optimize Preprocessor; add metrics.py

tags/v0.1.0
FengZiYjun 7 years ago
commit 92be0722dc
13 changed files with 148 additions and 158 deletions
  1. fastNLP/action/inference.py (+0, -3)
  2. fastNLP/action/metrics.py (+7, -0)
  3. fastNLP/action/optimizor.py (+14, -13)
  4. fastNLP/action/tester.py (+5, -5)
  5. fastNLP/action/trainer.py (+13, -5)
  6. fastNLP/loader/preprocess.py (+70, -98)
  7. fastNLP/modules/aggregation/kmax_pool.py (+1, -1)
  8. fastNLP/modules/aggregation/linear_attention.py (+0, -9)
  9. fastNLP/modules/aggregation/self_attention.py (+7, -6)
  10. fastNLP/modules/encoder/char_embedding.py (+20, -12)
  11. fastNLP/modules/encoder/conv.py (+3, -0)
  12. fastNLP/modules/encoder/embedding.py (+3, -2)
  13. fastNLP/modules/encoder/lstm.py (+5, -4)

fastNLP/action/inference.py (+0, -3)

@@ -3,9 +3,6 @@ class Inference(object):
     This is an interface focusing on predicting output based on trained models.
     It does not care about evaluations of the model.

-    Possible improvements:
-    - use batch to make use of GPU
-
     """

     def __init__(self):


fastNLP/action/metrics.py (+7, -0)

@@ -0,0 +1,7 @@
+"""
+To do:
+Refer to http://scikit-learn.org/stable/modules/classes.html#classification-metrics
+Each metric is suggested to be implemented as a standalone function (called by Tester's evaluate method).
+The parameter list only needs to cover the basic parameters; it does not need as many configuration options as scikit-learn's.
+
+"""

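As a sketch of the direction proposed in that to-do note, one metric written as a standalone function might look like the following; the function name and signature are assumptions for illustration, not part of this commit:

import torch

def accuracy_score(y_true, y_pred):
    """Hypothetical example of a one-function-per-metric style, as suggested above.

    Only the two basic arguments (ground truth and predictions) are taken,
    mirroring scikit-learn's classification metrics without the extra options.
    """
    if torch.is_tensor(y_true):
        y_true = y_true.view(-1).tolist()
    if torch.is_tensor(y_pred):
        y_pred = y_pred.view(-1).tolist()
    correct = sum(int(t == p) for t, p in zip(y_true, y_pred))
    return correct / len(y_true)

# e.g. accuracy_score([1, 0, 1, 1], [1, 0, 0, 1]) -> 0.75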
fastNLP/action/optimizor.py (+14, -13)

@@ -1,15 +1,16 @@
 from torch import optim


-def get_torch_optimizor(params, alg_name='sgd', **args):
-    '''
-    construct pytorch optimizor by algorithm's name
-    optimizor's argurments can be splicified, for different optimizor's argurments, please see pytorch doc
+def get_torch_optimizer(params, alg_name='sgd', **args):
+    """
+    construct PyTorch optimizer by algorithm's name
+    optimizer's arguments can be specified, for different optimizer's arguments, please see PyTorch doc

     usage:
-        optimizor = get_torch_optimizor(model.parameters(), 'SGD', lr=0.01)
+        optimizer = get_torch_optimizer(model.parameters(), 'SGD', lr=0.01)

-    '''
+    """
     name = alg_name.lower()
     if name == 'adadelta':
         return optim.Adadelta(params, **args)

@@ -28,22 +29,22 @@ def get_torch_optimizor(params, alg_name='sgd', **args):
     elif name == 'rprop':
         return optim.Rprop(params, **args)
     elif name == 'sgd':
-        #SGD's parameter lr is required
+        # SGD's parameter lr is required
         if 'lr' not in args:
             args['lr'] = 0.01
         return optim.SGD(params, **args)
     elif name == 'sparseadam':
         return optim.SparseAdam(params, **args)
     else:
-        raise TypeError('no such optimizor named {}'.format(alg_name))
+        raise TypeError('no such optimizer named {}'.format(alg_name))


+# example usage
 if __name__ == '__main__':
     from torch.nn.modules import Linear

     net = Linear(2, 5)

-    test1 = get_torch_optimizor(net.parameters(),'adam', lr=1e-2, weight_decay=1e-3)
+    test1 = get_torch_optimizer(net.parameters(), 'adam', lr=1e-2, weight_decay=1e-3)
     print(test1)
-    test2 = get_torch_optimizor(net.parameters(), 'SGD')
-    print(test2)
+    test2 = get_torch_optimizer(net.parameters(), 'SGD')
+    print(test2)

fastNLP/action/tester.py (+5, -5)

@@ -1,8 +1,8 @@
 import _pickle
+import os

 import numpy as np
 import torch
-import os

 from fastNLP.action.action import Action
 from fastNLP.action.action import RandomSampler, Batchifier

@@ -108,7 +108,7 @@ class BaseTester(Action):
         raise NotImplementedError

     @property
-    def matrices(self):
+    def metrics(self):
         raise NotImplementedError

     def mode(self, model, test=True):

@@ -163,7 +163,7 @@ class POSTester(BaseTester):
         accuracy = float(torch.sum(results == truth.view((-1,)))) / results.shape[0]
         return [loss.data, accuracy]

-    def matrices(self):
+    def metrics(self):
         batch_loss = np.mean([x[0] for x in self.eval_history])
         batch_accuracy = np.mean([x[1] for x in self.eval_history])
         return batch_loss, batch_accuracy

@@ -173,7 +173,7 @@ class POSTester(BaseTester):
         This is called by Trainer to print evaluation on dev set.
         :return print_str: str
         """
-        loss, accuracy = self.matrices()
+        loss, accuracy = self.metrics()
         return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy)

@@ -309,7 +309,7 @@ class ClassTester(BaseTester):
         y_prob = torch.nn.functional.softmax(y_logit, dim=-1)
         return [y_prob, y_true]

-    def matrices(self):
+    def metrics(self):
         """Compute accuracy."""
         y_prob, y_true = zip(*self.eval_history)
         y_prob = torch.cat(y_prob, dim=0)


fastNLP/action/trainer.py (+13, -5)

@@ -181,7 +181,7 @@ class BaseTrainer(Action):
         """
         raise NotImplementedError

-    def batchify(self, data):
+    def batchify(self, data, output_length=True):
         """
         1. Perform batching from data and produce a batch of training data.
         2. Add padding.

@@ -194,13 +194,18 @@
             ]
         :return batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
                 batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
+                seq_len: list. The length of the pre-padded sequence, if output_length is True.
         """
         indices = next(self.iterator)
         batch = [data[idx] for idx in indices]
         batch_x = [sample[0] for sample in batch]
         batch_y = [sample[1] for sample in batch]
-        batch_x = self.pad(batch_x)
-        return batch_x, batch_y
+        batch_x_pad = self.pad(batch_x)
+        if output_length:
+            seq_len = [len(x) for x in batch_x]
+            return batch_x_pad, batch_y, seq_len
+        else:
+            return batch_x_pad, batch_y

     @staticmethod
     def pad(batch, fill=0):

@@ -245,7 +250,10 @@ class ToyTrainer(BaseTrainer):
         return data_train, data_dev, 0, 1

     def mode(self, test=False):
-        self.model.mode(test)
+        if test:
+            self.model.eval()
+        else:
+            self.model.train()

     def data_forward(self, network, x):
         return network(x)

@@ -333,7 +341,7 @@ class POSTrainer(BaseTrainer):
         return loss

     def best_eval_result(self, validator):
-        loss, accuracy = validator.matrices()
+        loss, accuracy = validator.metrics()
        if accuracy > self.best_accuracy:
            self.best_accuracy = accuracy
            return True


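The new batchify contract (padded features, labels, and pre-padding lengths) can be illustrated with a standalone sketch; the free-function form, the toy data, and the reproduced pad helper are illustrative only and are not the trainer itself:

def pad(batch, fill=0):
    # pad every sequence to the length of the longest one, as BaseTrainer.pad does
    max_len = max(len(x) for x in batch)
    return [x + [fill] * (max_len - len(x)) for x in batch]

def batchify(batch, output_length=True):
    batch_x = [sample[0] for sample in batch]
    batch_y = [sample[1] for sample in batch]
    batch_x_pad = pad(batch_x)
    if output_length:
        seq_len = [len(x) for x in batch_x]  # lengths before padding
        return batch_x_pad, batch_y, seq_len
    return batch_x_pad, batch_y

# two samples of different length
samples = [([4, 7, 9], [1]), ([3, 5], [0])]
print(batchify(samples))
# ([[4, 7, 9], [3, 5, 0]], [[1], [0]], [3, 2])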
fastNLP/loader/preprocess.py (+70, -98)

@@ -7,6 +7,10 @@ DEFAULT_RESERVED_LABEL = ['<reserved-2>',
                           '<reserved-3>',
                           '<reserved-4>']  # dict index = 2~4
+DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
+                         DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
+                         DEFAULT_RESERVED_LABEL[2]: 4}
+
 # the first vocab in dict with the index = 5

@@ -41,52 +45,79 @@ class POSPreprocess(BasePreprocess):
     to label5.
     """

-    def __init__(self, data, pickle_path):
+    def __init__(self, data, pickle_path, train_dev_split=0):
+        """
+        Preprocess pipeline, including building mapping from words to index, from index to words,
+        from labels/classes to index, from index to labels/classes.
+        :param data:
+        :param pickle_path:
+        :param train_dev_split: float in [0, 1]. The ratio of dev data split from training data. Default: 0.
+
+        To do:
+            1. use @contextmanager to handle pickle dumps and loads
+        """
         super(POSPreprocess, self).__init__(data, pickle_path)
-        self.word_dict = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
-                          DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
-                          DEFAULT_RESERVED_LABEL[2]: 4}
-        self.label_dict = None
-        self.data = data
         self.pickle_path = pickle_path
-        self.build_dict(data)
-        if not self.pickle_exist("word2id.pkl"):
-            self.word_dict.update(self.word2id(data))
-            file_name = os.path.join(self.pickle_path, "word2id.pkl")
-            with open(file_name, "wb") as f:
-                _pickle.dump(self.word_dict, f)
-        self.vocab_size = self.id2word()
-        self.class2id()
-        self.num_classes = self.id2class()
-        self.embedding()
-        self.data_train()
-        self.data_dev()
-        self.data_test()
+        if self.pickle_exist("word2id.pkl"):
+            # load word2index because the construction of the following objects needs it
+            with open(os.path.join(self.pickle_path, "word2id.pkl"), "rb") as f:
+                self.word2index = _pickle.load(f)
+        else:
+            self.word2index, self.label2index = self.build_dict(data)
+            with open(os.path.join(self.pickle_path, "word2id.pkl"), "wb") as f:
+                _pickle.dump(self.word2index, f)
+
+        if self.pickle_exist("class2id.pkl"):
+            with open(os.path.join(self.pickle_path, "class2id.pkl"), "rb") as f:
+                self.label2index = _pickle.load(f)
+        else:
+            with open(os.path.join(self.pickle_path, "class2id.pkl"), "wb") as f:
+                _pickle.dump(self.label2index, f)
+
+        if not self.pickle_exist("id2word.pkl"):
+            index2word = self.build_reverse_dict(self.word2index)
+            with open(os.path.join(self.pickle_path, "id2word.pkl"), "wb") as f:
+                _pickle.dump(index2word, f)
+
+        if not self.pickle_exist("id2class.pkl"):
+            index2label = self.build_reverse_dict(self.label2index)
+            with open(os.path.join(self.pickle_path, "word2id.pkl"), "wb") as f:
+                _pickle.dump(index2label, f)
+
+        if not self.pickle_exist("data_train.pkl"):
+            data_train = self.to_index(data)
+            if train_dev_split > 0 and not self.pickle_exist("data_dev.pkl"):
+                data_dev = data_train[: int(len(data_train) * train_dev_split)]
+                with open(os.path.join(self.pickle_path, "data_dev.pkl"), "wb") as f:
+                    _pickle.dump(data_dev, f)
+            with open(os.path.join(self.pickle_path, "data_train.pkl"), "wb") as f:
+                _pickle.dump(data_train, f)

     def build_dict(self, data):
         """
         Add new words with indices into self.word_dict, new labels with indices into self.label_dict.
         :param data: list of list [word, label]
+        :return word2index: dict of (str, int)
+                label2index: dict of (str, int)
         """
-        self.label_dict = {}
+        label2index = {}
+        word2index = DEFAULT_WORD_TO_INDEX
         for line in data:
             line = line.strip()
             if len(line) <= 1:
                 continue
             tokens = line.split('\t')
-            if tokens[0] not in self.word_dict:
+            if tokens[0] not in word2index:
                 # add (word, index) into the dict
-                self.word_dict[tokens[0]] = len(self.word_dict)
+                word2index[tokens[0]] = len(word2index)
             # for label in tokens[1: ]:
-            if tokens[1] not in self.label_dict:
-                self.label_dict[tokens[1]] = len(self.label_dict)
+            if tokens[1] not in label2index:
+                label2index[tokens[1]] = len(label2index)
+        return word2index, label2index

     def pickle_exist(self, pickle_name):
         """

@@ -101,90 +132,31 @@ class POSPreprocess(BasePreprocess):
         else:
             return False

-    def word2id(self):
-        if self.pickle_exist("word2id.pkl"):
-            return
-        # nothing will be done if word2id.pkl exists
-
-        file_name = os.path.join(self.pickle_path, "word2id.pkl")
-        with open(file_name, "wb") as f:
-            _pickle.dump(self.word_dict, f)
-
-    def id2word(self):
-        if self.pickle_exist("id2word.pkl"):
-            file_name = os.path.join(self.pickle_path, "id2word.pkl")
-            id2word_dict = _pickle.load(open(file_name, "rb"))
-            return len(id2word_dict)
-        # nothing will be done if id2word.pkl exists
-
-        id2word_dict = {}
-        for word in self.word_dict:
-            id2word_dict[self.word_dict[word]] = word
-        file_name = os.path.join(self.pickle_path, "id2word.pkl")
-        with open(file_name, "wb") as f:
-            _pickle.dump(id2word_dict, f)
-        return len(id2word_dict)
-
-    def class2id(self):
-        if self.pickle_exist("class2id.pkl"):
-            return
-        # nothing will be done if class2id.pkl exists
-
-        file_name = os.path.join(self.pickle_path, "class2id.pkl")
-        with open(file_name, "wb") as f:
-            _pickle.dump(self.label_dict, f)
-
-    def id2class(self):
-        if self.pickle_exist("id2class.pkl"):
-            file_name = os.path.join(self.pickle_path, "id2class.pkl")
-            id2class_dict = _pickle.load(open(file_name, "rb"))
-            return len(id2class_dict)
-        # nothing will be done if id2class.pkl exists
-
-        id2class_dict = {}
-        for label in self.label_dict:
-            id2class_dict[self.label_dict[label]] = label
-        file_name = os.path.join(self.pickle_path, "id2class.pkl")
-        with open(file_name, "wb") as f:
-            _pickle.dump(id2class_dict, f)
-        return len(id2class_dict)
-
-    def embedding(self):
-        if self.pickle_exist("embedding.pkl"):
-            return
-        # nothing will be done if embedding.pkl exists
-
-    def data_train(self):
-        if self.pickle_exist("data_train.pkl"):
-            return
-        # nothing will be done if data_train.pkl exists
-
+    def build_reverse_dict(self, word_dict):
+        id2word = {word_dict[w]: w for w in word_dict}
+        return id2word
+
+    def to_index(self, data):
+        """
+        Convert word strings and label strings into indices.
+        :param data: list of str. Each string is a line, described above.
+        :return data_index: list of tuple (word index, label index)
+        """
         data_train = []
         sentence = []
-        for w in self.data:
+        for w in data:
             w = w.strip()
             if len(w) <= 1:
                 wid = []
                 lid = []
                 for i in range(len(sentence)):
-                    # if sentence[i][0]=="":
-                    #     print("")
-                    wid.append(self.word_dict[sentence[i][0]])
-                    lid.append(self.label_dict[sentence[i][1]])
+                    wid.append(self.word2index[sentence[i][0]])
+                    lid.append(self.label2index[sentence[i][1]])
                 data_train.append((wid, lid))
                 sentence = []
                 continue
             sentence.append(w.split('\t'))
-        file_name = os.path.join(self.pickle_path, "data_train.pkl")
-        with open(file_name, "wb") as f:
-            _pickle.dump(data_train, f)
-
-    def data_dev(self):
-        pass
-
-    def data_test(self):
-        pass
+        return data_train


 class ClassPreprocess(BasePreprocess):


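The indexing pipeline that POSPreprocess now runs (build_dict followed by to_index) can be sketched standalone. This assumes the same tab-separated "word<TAB>label" lines with near-empty lines as sentence boundaries; the placeholder strings below stand in for DEFAULT_PADDING_LABEL, DEFAULT_UNKNOWN_LABEL, and the reserved labels, and the code is illustrative rather than the class itself:

# Illustrative stand-ins for the default special tokens (indices 0-4),
# so that real words start at index 5 as the comment in the diff states.
DEFAULT_WORD_TO_INDEX = {"<pad>": 0, "<unk>": 1,
                         "<reserved-2>": 2, "<reserved-3>": 3, "<reserved-4>": 4}

def build_dict(lines):
    word2index = dict(DEFAULT_WORD_TO_INDEX)   # copy, so the module-level default is not mutated
    label2index = {}
    for line in lines:
        line = line.strip()
        if len(line) <= 1:
            continue                            # (near-)empty line marks a sentence boundary
        word, label = line.split('\t')
        if word not in word2index:
            word2index[word] = len(word2index)
        if label not in label2index:
            label2index[label] = len(label2index)
    return word2index, label2index

lines = ["I\tPRP", "love\tVBP", "NLP\tNNP", ""]
w2i, l2i = build_dict(lines)
print(w2i["I"], w2i["NLP"], l2i["VBP"])   # 5 7 1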
fastNLP/modules/aggregation/kmax_pool.py (+1, -1)

@@ -9,7 +9,7 @@ import torch.nn as nn
 class KMaxPool(nn.Module):
     """K max-pooling module."""

-    def __init__(self, k):
+    def __init__(self, k=1):
         super(KMaxPool, self).__init__()
         self.k = k




fastNLP/modules/aggregation/linear_attention.py (+0, -9) (file deleted)

@@ -1,9 +0,0 @@
-from fastNLP.modules.aggregation.attention import Attention
-
-
-class LinearAttention(Attention):
-    def __init__(self, normalize=False):
-        super(LinearAttention, self).__init__(normalize)
-
-    def _atten_forward(self, query, memory):
-        raise NotImplementedError

fastNLP/modules/aggregation/self_attention.py (+7, -6)

@@ -8,14 +8,15 @@ class SelfAttention(nn.Module):
     Self Attention Module.

     Args:
-        input_size : the size for the input vector
-        d_a : the width of weight matrix
-        r : the number of encoded vectors
+        input_size: int, the size for the input vector
+        dim: int, the width of weight matrix.
+        num_vec: int, the number of encoded vectors
     """
-    def __init__(self, input_size, d_a, r):
+
+    def __init__(self, input_size, dim=10, num_vec=10):
         super(SelfAttention, self).__init__()
-        self.W_s1 = nn.Parameter(torch.randn(d_a, input_size), requires_grad=True)
-        self.W_s2 = nn.Parameter(torch.randn(r, d_a), requires_grad=True)
+        self.W_s1 = nn.Parameter(torch.randn(dim, input_size), requires_grad=True)
+        self.W_s2 = nn.Parameter(torch.randn(num_vec, dim), requires_grad=True)
         self.softmax = nn.Softmax(dim=2)
         self.tanh = nn.Tanh()




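For context on the renamed parameters: W_s1 has shape [dim, input_size] and W_s2 has shape [num_vec, dim], which corresponds to the structured self-attention of Lin et al. (2017). The forward pass is not part of this diff, so the sketch below is an assumption based only on those shapes:

import torch

# Assumed computation: A = softmax(W_s2 · tanh(W_s1 · H^T)), M = A · H
input_size, dim, num_vec = 100, 10, 10
W_s1 = torch.randn(dim, input_size)       # nn.Parameter inside the module
W_s2 = torch.randn(num_vec, dim)

H = torch.randn(4, 30, input_size)        # [batch, seq_len, input_size]
A = torch.softmax(W_s2 @ torch.tanh(W_s1 @ H.transpose(1, 2)), dim=2)
M = A @ H                                 # num_vec encoded vectors per sample
print(A.shape, M.shape)                   # [4, 10, 30] [4, 10, 100]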
fastNLP/modules/encoder/char_embedding.py (+20, -12)

@@ -5,13 +5,15 @@ from torch import nn
 class ConvCharEmbedding(nn.Module):

-    def __init__(self, char_emb_size, feature_maps=(40, 30, 30), kernels=(3, 4, 5)):
+    def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5)):
         """
         Character Level Word Embedding
-        :param char_emb_size: the size of character level embedding,
+        :param char_emb_size: the size of character level embedding. Default: 50
             say 26 characters, each embedded to 50 dim vector, then the input_size is 50.
-        :param feature_maps: table of feature maps (for each kernel width)
-        :param kernels: table of kernel widths
+        :param feature_maps: tuple of int. The length of the tuple is the number of convolution operations
+            over characters. The i-th integer is the number of filters (dim of out channels) for the i-th
+            convolution.
+        :param kernels: tuple of int. The width of each kernel.
         """
         super(ConvCharEmbedding, self).__init__()
         self.convs = nn.ModuleList([

@@ -23,29 +25,35 @@ class ConvCharEmbedding(nn.Module):
         :param x: [batch_size * sent_length, word_length, char_emb_size]
         :return: [batch_size * sent_length, sum(feature_maps), 1]
         """
-        x = x.contiguous().view(x.size(0), 1, x.size(1), x.size(2))  # [batch_size*sent_length, channel, width, height]
-        x = x.transpose(2, 3)  # [batch_size*sent_length, channel, height, width]
+        x = x.contiguous().view(x.size(0), 1, x.size(1), x.size(2))
+        # [batch_size*sent_length, channel, width, height]
+        x = x.transpose(2, 3)
+        # [batch_size*sent_length, channel, height, width]
         return self.convolute(x).unsqueeze(2)

     def convolute(self, x):
         feats = []
         for conv in self.convs:
-            y = conv(x)  # [batch_size*sent_length, feature_maps[i], 1, width - kernels[i] + 1]
-            y = torch.squeeze(y, 2)  # [batch_size*sent_length, feature_maps[i], width - kernels[i] + 1]
+            y = conv(x)
+            # [batch_size*sent_length, feature_maps[i], 1, width - kernels[i] + 1]
+            y = torch.squeeze(y, 2)
+            # [batch_size*sent_length, feature_maps[i], width - kernels[i] + 1]
             y = F.tanh(y)
-            y, __ = torch.max(y, 2)  # [batch_size*sent_length, feature_maps[i]]
+            y, __ = torch.max(y, 2)
+            # [batch_size*sent_length, feature_maps[i]]
             feats.append(y)
         return torch.cat(feats, 1)  # [batch_size*sent_length, sum(feature_maps)]


 class LSTMCharEmbedding(nn.Module):
     """
-    Character Level Word Embedding with LSTM
-    :param char_emb_size: the size of character level embedding,
+    Character Level Word Embedding with LSTM with a single layer.
+    :param char_emb_size: int, the size of character level embedding. Default: 50
         say 26 characters, each embedded to 50 dim vector, then the input_size is 50.
+    :param hidden_size: int, the number of hidden units. Default: equal to char_emb_size.
     """

-    def __init__(self, char_emb_size, hidden_size=None):
+    def __init__(self, char_emb_size=50, hidden_size=None):
         super(LSTMCharEmbedding, self).__init__()
         self.hidden_size = char_emb_size if hidden_size is None else hidden_size




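A shape walk-through of the convolute() path above, rebuilt with plain Conv2d layers; the kernel shape (char_emb_size, kernels[i]) is an assumption, chosen because it is what makes the squeeze(2) and the documented output shape work out:

import torch
import torch.nn as nn

char_emb_size, feature_maps, kernels = 50, (40, 30, 30), (3, 4, 5)
convs = nn.ModuleList([nn.Conv2d(1, feature_maps[i], (char_emb_size, kernels[i]))
                       for i in range(len(kernels))])

x = torch.randn(32, 7, char_emb_size)   # [batch_size*sent_length, word_length, char_emb_size]
x = x.unsqueeze(1).transpose(2, 3)      # [N, 1, char_emb_size, word_length]
feats = []
for conv in convs:
    y = torch.tanh(conv(x)).squeeze(2)  # [N, feature_maps[i], word_length - kernels[i] + 1]
    feats.append(torch.max(y, 2)[0])    # max over positions -> [N, feature_maps[i]]
out = torch.cat(feats, 1)
print(out.shape)                        # torch.Size([32, 100]); 100 == sum(feature_maps)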
fastNLP/modules/encoder/conv.py (+3, -0)

@@ -2,12 +2,14 @@
 # encoding: utf-8

 import torch.nn as nn
+from torch.nn.init import xavier_uniform
 # import torch.nn.functional as F


 class Conv(nn.Module):
     """
     Basic 1-d convolution module.
+    initialize with xavier_uniform
     """

     def __init__(self, in_channels, out_channels, kernel_size,

@@ -23,6 +25,7 @@ class Conv(nn.Module):
                               dilation=dilation,
                               groups=groups,
                               bias=bias)
+        xavier_uniform(self.conv.weight)

     def forward(self, x):
         return self.conv(x)  # [N,C,L]

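The added line simply applies Xavier/Glorot uniform initialization to the convolution weights; in recent PyTorch releases the function is named xavier_uniform_ rather than xavier_uniform, but the effect is the same:

import torch.nn as nn
from torch.nn.init import xavier_uniform_   # xavier_uniform in the older PyTorch used here

conv = nn.Conv1d(in_channels=50, out_channels=100, kernel_size=3)
xavier_uniform_(conv.weight)   # samples from U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out))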
fastNLP/modules/encoder/embedding.py (+3, -2)

@@ -7,12 +7,13 @@ class Lookuptable(nn.Module):

     Args:
         nums : the size of the lookup table
-        dims : the size of each vector
+        dims : the size of each vector. Default: 50.
         padding_idx : pads the tensor with zeros whenever it encounters this index
         sparse : If True, gradient matrix will be a sparse tensor. In this case,
             only optim.SGD(cuda and cpu) and optim.Adagrad(cpu) can be used
     """
-    def __init__(self, nums, dims, padding_idx=0, sparse=False):
+
+    def __init__(self, nums, dims=50, padding_idx=0, sparse=False):
         super(Lookuptable, self).__init__()
         self.embed = nn.Embedding(nums, dims, padding_idx, sparse=sparse)


fastNLP/modules/encoder/lstm.py (+5, -4)

@@ -8,11 +8,12 @@ class Lstm(nn.Module):
     Args:
         input_size : input size
         hidden_size : hidden size
-        num_layers : number of hidden layers
-        dropout : dropout rate
-        bidirectional : If True, becomes a bidirectional RNN
+        num_layers : number of hidden layers. Default: 1
+        dropout : dropout rate. Default: 0.5
+        bidirectional : If True, becomes a bidirectional RNN. Default: False.
     """
-    def __init__(self, input_size, hidden_size, num_layers, dropout, bidirectional):
+
+    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.5, bidirectional=False):
         super(Lstm, self).__init__()
         self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True,
                             dropout=dropout, bidirectional=bidirectional)


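Since the Lstm wrapper's forward method is not shown in this diff, here is a plain nn.LSTM constructed exactly as in __init__ above with the new defaults, to illustrate the expected tensor shapes:

import torch
import torch.nn as nn

# num_layers=1, dropout=0.5, bidirectional=False, batch_first=True, as in Lstm.__init__
lstm = nn.LSTM(100, 64, 1, bias=True, batch_first=True, dropout=0.5, bidirectional=False)
# note: with a single layer, PyTorch warns that dropout has no effect (it is applied between layers)

x = torch.randn(8, 20, 100)      # [batch, seq_len, input_size] because batch_first=True
output, (h_n, c_n) = lstm(x)
print(output.shape)              # torch.Size([8, 20, 64])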