
add init parameters in modules; refactor optimizer and Preprocessor; add metrics.py

tags/v0.1.0
FengZiYjun 7 years ago
parent
commit
92be0722dc
13 changed files with 148 additions and 158 deletions
  1. +0 -3 fastNLP/action/inference.py
  2. +7 -0 fastNLP/action/metrics.py
  3. +14 -13 fastNLP/action/optimizor.py
  4. +5 -5 fastNLP/action/tester.py
  5. +13 -5 fastNLP/action/trainer.py
  6. +70 -98 fastNLP/loader/preprocess.py
  7. +1 -1 fastNLP/modules/aggregation/kmax_pool.py
  8. +0 -9 fastNLP/modules/aggregation/linear_attention.py
  9. +7 -6 fastNLP/modules/aggregation/self_attention.py
  10. +20 -12 fastNLP/modules/encoder/char_embedding.py
  11. +3 -0 fastNLP/modules/encoder/conv.py
  12. +3 -2 fastNLP/modules/encoder/embedding.py
  13. +5 -4 fastNLP/modules/encoder/lstm.py

+0 -3 fastNLP/action/inference.py

@@ -3,9 +3,6 @@ class Inference(object):
This is an interface focusing on predicting output based on trained models.
It does not care about evaluations of the model.

Possible improvements:
- use batch to make use of GPU

"""

def __init__(self):


+7 -0 fastNLP/action/metrics.py

@@ -0,0 +1,7 @@
"""
To do:
Refer to http://scikit-learn.org/stable/modules/classes.html#classification-metrics
Suggestion: write each metric as a separate function (called by Tester's evaluate function).
The parameter list only needs to cover the basic parameters; it does not need as many configuration options as scikit-learn provides.

"""

+14 -13 fastNLP/action/optimizor.py

@@ -1,15 +1,16 @@
from torch import optim

def get_torch_optimizor(params, alg_name='sgd', **args):
'''
construct pytorch optimizor by algorithm's name
optimizor's argurments can be splicified, for different optimizor's argurments, please see pytorch doc

def get_torch_optimizer(params, alg_name='sgd', **args):
"""
construct PyTorch optimizer by algorithm's name
optimizer's arguments can be specified, for different optimizer's arguments, please see PyTorch doc

usage:
optimizor = get_torch_optimizor(model.parameters(), 'SGD', lr=0.01)
optimizer = get_torch_optimizer(model.parameters(), 'SGD', lr=0.01)

"""

'''
name = alg_name.lower()
if name == 'adadelta':
return optim.Adadelta(params, **args)
@@ -28,22 +29,22 @@ def get_torch_optimizor(params, alg_name='sgd', **args):
elif name == 'rprop':
return optim.Rprop(params, **args)
elif name == 'sgd':
#SGD's parameter lr is required
# SGD's parameter lr is required
if 'lr' not in args:
args['lr'] = 0.01
return optim.SGD(params, **args)
elif name == 'sparseadam':
return optim.SparseAdam(params, **args)
else:
raise TypeError('no such optimizor named {}'.format(alg_name))
raise TypeError('no such optimizer named {}'.format(alg_name))


# example usage
if __name__ == '__main__':
from torch.nn.modules import Linear

net = Linear(2, 5)

test1 = get_torch_optimizor(net.parameters(),'adam', lr=1e-2, weight_decay=1e-3)
test1 = get_torch_optimizer(net.parameters(), 'adam', lr=1e-2, weight_decay=1e-3)
print(test1)
test2 = get_torch_optimizor(net.parameters(), 'SGD')
print(test2)
test2 = get_torch_optimizer(net.parameters(), 'SGD')
print(test2)

+5 -5 fastNLP/action/tester.py

@@ -1,8 +1,8 @@
import _pickle
import os

import numpy as np
import torch
import os

from fastNLP.action.action import Action
from fastNLP.action.action import RandomSampler, Batchifier
@@ -108,7 +108,7 @@ class BaseTester(Action):
raise NotImplementedError

@property
def matrices(self):
def metrics(self):
raise NotImplementedError

def mode(self, model, test=True):
@@ -163,7 +163,7 @@ class POSTester(BaseTester):
accuracy = float(torch.sum(results == truth.view((-1,)))) / results.shape[0]
return [loss.data, accuracy]

def matrices(self):
def metrics(self):
batch_loss = np.mean([x[0] for x in self.eval_history])
batch_accuracy = np.mean([x[1] for x in self.eval_history])
return batch_loss, batch_accuracy
@@ -173,7 +173,7 @@ class POSTester(BaseTester):
This is called by Trainer to print evaluation on dev set.
:return print_str: str
"""
loss, accuracy = self.matrices()
loss, accuracy = self.metrics()
return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy)


@@ -309,7 +309,7 @@ class ClassTester(BaseTester):
y_prob = torch.nn.functional.softmax(y_logit, dim=-1)
return [y_prob, y_true]

def matrices(self):
def metrics(self):
"""Compute accuracy."""
y_prob, y_true = zip(*self.eval_history)
y_prob = torch.cat(y_prob, dim=0)


+13 -5 fastNLP/action/trainer.py

@@ -181,7 +181,7 @@ class BaseTrainer(Action):
"""
raise NotImplementedError

def batchify(self, data):
def batchify(self, data, output_length=True):
"""
1. Perform batching from data and produce a batch of training data.
2. Add padding.
@@ -194,13 +194,18 @@ class BaseTrainer(Action):
]
:return batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
seq_len: list. The length of the pre-padded sequence, if output_length is True.
"""
indices = next(self.iterator)
batch = [data[idx] for idx in indices]
batch_x = [sample[0] for sample in batch]
batch_y = [sample[1] for sample in batch]
batch_x = self.pad(batch_x)
return batch_x, batch_y
batch_x_pad = self.pad(batch_x)
if output_length:
seq_len = [len(x) for x in batch_x]
return batch_x_pad, batch_y, seq_len
else:
return batch_x_pad, batch_y

@staticmethod
def pad(batch, fill=0):
@@ -245,7 +250,10 @@ class ToyTrainer(BaseTrainer):
return data_train, data_dev, 0, 1

def mode(self, test=False):
self.model.mode(test)
if test:
self.model.eval()
else:
self.model.train()

def data_forward(self, network, x):
return network(x)
@@ -333,7 +341,7 @@ class POSTrainer(BaseTrainer):
return loss

def best_eval_result(self, validator):
loss, accuracy = validator.matrices()
loss, accuracy = validator.metrics()
if accuracy > self.best_accuracy:
self.best_accuracy = accuracy
return True
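batchify now returns the pre-padding sequence lengths alongside the padded batch. The pad body is only shown as a signature in this diff, so the sketch below is a guess at its behaviour based on the docstring, with made-up toy data:

def pad(batch, fill=0):
    # assumed behaviour: right-pad every sample to the longest length with `fill`
    max_len = max(len(x) for x in batch)
    return [x + [fill] * (max_len - len(x)) for x in batch]

batch_x = [[4, 7, 2], [9, 1]]
batch_x_pad = pad(batch_x)            # [[4, 7, 2], [9, 1, 0]]
seq_len = [len(x) for x in batch_x]   # [3, 2] -- lengths before padding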


+70 -98 fastNLP/loader/preprocess.py

@@ -7,6 +7,10 @@ DEFAULT_RESERVED_LABEL = ['<reserved-2>',
'<reserved-3>',
'<reserved-4>'] # dict index = 2~4
DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
DEFAULT_RESERVED_LABEL[2]: 4}
# the first vocab in dict with the index = 5
@@ -41,52 +45,79 @@ class POSPreprocess(BasePreprocess):
to label5.
"""
def __init__(self, data, pickle_path):
def __init__(self, data, pickle_path, train_dev_split=0):
"""
Preprocess pipeline, including building mapping from words to index, from index to words,
from labels/classes to index, from index to labels/classes.
:param data:
:param pickle_path:
:param train_dev_split: float in [0, 1]. The ratio of dev data split from training data. Default: 0.
To do:
1. use @contextmanager to handle pickle dumps and loads
"""
super(POSPreprocess, self).__init__(data, pickle_path)
self.word_dict = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
DEFAULT_RESERVED_LABEL[2]: 4}
self.label_dict = None
self.data = data
self.pickle_path = pickle_path
self.build_dict(data)
if not self.pickle_exist("word2id.pkl"):
self.word_dict.update(self.word2id(data))
file_name = os.path.join(self.pickle_path, "word2id.pkl")
with open(file_name, "wb") as f:
_pickle.dump(self.word_dict, f)
if self.pickle_exist("word2id.pkl"):
# load word2index because the construction of the following objects needs it
with open(os.path.join(self.pickle_path, "word2id.pkl"), "rb") as f:
self.word2index = _pickle.load(f)
else:
self.word2index, self.label2index = self.build_dict(data)
with open(os.path.join(self.pickle_path, "word2id.pkl"), "wb") as f:
_pickle.dump(self.word2index, f)
self.vocab_size = self.id2word()
self.class2id()
self.num_classes = self.id2class()
self.embedding()
self.data_train()
self.data_dev()
self.data_test()
if self.pickle_exist("class2id.pkl"):
with open(os.path.join(self.pickle_path, "class2id.pkl"), "rb") as f:
self.label2index = _pickle.load(f)
else:
with open(os.path.join(self.pickle_path, "class2id.pkl"), "wb") as f:
_pickle.dump(self.label2index, f)
if not self.pickle_exist("id2word.pkl"):
index2word = self.build_reverse_dict(self.word2index)
with open(os.path.join(self.pickle_path, "id2word.pkl"), "wb") as f:
_pickle.dump(index2word, f)
if not self.pickle_exist("id2class.pkl"):
index2label = self.build_reverse_dict(self.label2index)
with open(os.path.join(self.pickle_path, "id2class.pkl"), "wb") as f:
_pickle.dump(index2label, f)
if not self.pickle_exist("data_train.pkl"):
data_train = self.to_index(data)
if train_dev_split > 0 and not self.pickle_exist("data_dev.pkl"):
data_dev = data_train[: int(len(data_train) * train_dev_split)]
with open(os.path.join(self.pickle_path, "data_dev.pkl"), "wb") as f:
_pickle.dump(data_dev, f)
with open(os.path.join(self.pickle_path, "data_train.pkl"), "wb") as f:
_pickle.dump(data_train, f)
def build_dict(self, data):
"""
Add new words with indices into self.word_dict, new labels with indices into self.label_dict.
:param data: list of list [word, label]
:return word2index: dict of (str, int)
label2index: dict of (str, int)
"""
self.label_dict = {}
label2index = {}
word2index = DEFAULT_WORD_TO_INDEX
for line in data:
line = line.strip()
if len(line) <= 1:
continue
tokens = line.split('\t')
if tokens[0] not in self.word_dict:
if tokens[0] not in word2index:
# add (word, index) into the dict
self.word_dict[tokens[0]] = len(self.word_dict)
word2index[tokens[0]] = len(word2index)
# for label in tokens[1: ]:
if tokens[1] not in self.label_dict:
self.label_dict[tokens[1]] = len(self.label_dict)
if tokens[1] not in label2index:
label2index[tokens[1]] = len(label2index)
return word2index, label2index
def pickle_exist(self, pickle_name):
"""
@@ -101,90 +132,31 @@ class POSPreprocess(BasePreprocess):
else:
return False
def word2id(self):
if self.pickle_exist("word2id.pkl"):
return
# nothing will be done if word2id.pkl exists
file_name = os.path.join(self.pickle_path, "word2id.pkl")
with open(file_name, "wb") as f:
_pickle.dump(self.word_dict, f)
def id2word(self):
if self.pickle_exist("id2word.pkl"):
file_name = os.path.join(self.pickle_path, "id2word.pkl")
id2word_dict = _pickle.load(open(file_name, "rb"))
return len(id2word_dict)
# nothing will be done if id2word.pkl exists
id2word_dict = {}
for word in self.word_dict:
id2word_dict[self.word_dict[word]] = word
file_name = os.path.join(self.pickle_path, "id2word.pkl")
with open(file_name, "wb") as f:
_pickle.dump(id2word_dict, f)
return len(id2word_dict)
def class2id(self):
if self.pickle_exist("class2id.pkl"):
return
# nothing will be done if class2id.pkl exists
file_name = os.path.join(self.pickle_path, "class2id.pkl")
with open(file_name, "wb") as f:
_pickle.dump(self.label_dict, f)
def id2class(self):
if self.pickle_exist("id2class.pkl"):
file_name = os.path.join(self.pickle_path, "id2class.pkl")
id2class_dict = _pickle.load(open(file_name, "rb"))
return len(id2class_dict)
# nothing will be done if id2class.pkl exists
id2class_dict = {}
for label in self.label_dict:
id2class_dict[self.label_dict[label]] = label
file_name = os.path.join(self.pickle_path, "id2class.pkl")
with open(file_name, "wb") as f:
_pickle.dump(id2class_dict, f)
return len(id2class_dict)
def embedding(self):
if self.pickle_exist("embedding.pkl"):
return
# nothing will be done if embedding.pkl exists
def data_train(self):
if self.pickle_exist("data_train.pkl"):
return
# nothing will be done if data_train.pkl exists
def build_reverse_dict(self, word_dict):
id2word = {word_dict[w]: w for w in word_dict}
return id2word
def to_index(self, data):
"""
Convert word strings and label strings into indices.
:param data: list of str. Each string is a line, described above.
:return data_index: list of tuple (word index, label index)
"""
data_train = []
sentence = []
for w in self.data:
for w in data:
w = w.strip()
if len(w) <= 1:
wid = []
lid = []
for i in range(len(sentence)):
# if sentence[i][0]=="":
# print("")
wid.append(self.word_dict[sentence[i][0]])
lid.append(self.label_dict[sentence[i][1]])
wid.append(self.word2index[sentence[i][0]])
lid.append(self.label2index[sentence[i][1]])
data_train.append((wid, lid))
sentence = []
continue
sentence.append(w.split('\t'))
file_name = os.path.join(self.pickle_path, "data_train.pkl")
with open(file_name, "wb") as f:
_pickle.dump(data_train, f)
def data_dev(self):
pass
def data_test(self):
pass
return data_train
class ClassPreprocess(BasePreprocess):
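The rewritten __init__ caches each mapping as a pickle and, when train_dev_split > 0, carves a dev set off the front of the training data. A small illustration of the two helpers it relies on, with made-up data:

word2index = {"<pad>": 0, "<unk>": 1, "the": 5, "cat": 6}
index2word = {word2index[w]: w for w in word2index}   # what build_reverse_dict does
# -> {0: "<pad>", 1: "<unk>", 5: "the", 6: "cat"}

data_train = [([5, 6], [0, 1])] * 10                  # 10 (word indices, label indices) samples
train_dev_split = 0.2
data_dev = data_train[: int(len(data_train) * train_dev_split)]   # first 2 samples become the dev set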


+1 -1 fastNLP/modules/aggregation/kmax_pool.py

@@ -9,7 +9,7 @@ import torch.nn as nn
class KMaxPool(nn.Module):
"""K max-pooling module."""

def __init__(self, k):
def __init__(self, k=1):
super(KMaxPool, self).__init__()
self.k = k
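The forward pass of KMaxPool is not part of this diff; k-max pooling is commonly implemented with torch.topk, roughly as sketched below (an assumption, not the module's actual code):

import torch

x = torch.tensor([[[3., 1., 4., 1., 5., 9.]]])   # [batch=1, channels=1, length=6]
values, _ = torch.topk(x, k=2, dim=2)            # keep the 2 largest values along the length
# values: tensor([[[9., 5.]]]); note topk sorts descending, while k-max pooling
# usually keeps the selected values in their original order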



+0 -9 fastNLP/modules/aggregation/linear_attention.py

@@ -1,9 +0,0 @@
from fastNLP.modules.aggregation.attention import Attention


class LinearAttention(Attention):
def __init__(self, normalize=False):
super(LinearAttention, self).__init__(normalize)

def _atten_forward(self, query, memory):
raise NotImplementedError

+7 -6 fastNLP/modules/aggregation/self_attention.py

@@ -8,14 +8,15 @@ class SelfAttention(nn.Module):
Self Attention Module.

Args:
input_size : the size for the input vector
d_a : the width of weight matrix
r : the number of encoded vectors
input_size: int, the size for the input vector
dim: int, the width of weight matrix.
num_vec: int, the number of encoded vectors
"""
def __init__(self, input_size, d_a, r):

def __init__(self, input_size, dim=10, num_vec=10):
super(SelfAttention, self).__init__()
self.W_s1 = nn.Parameter(torch.randn(d_a, input_size), requires_grad=True)
self.W_s2 = nn.Parameter(torch.randn(r, d_a), requires_grad=True)
self.W_s1 = nn.Parameter(torch.randn(dim, input_size), requires_grad=True)
self.W_s2 = nn.Parameter(torch.randn(num_vec, dim), requires_grad=True)
self.softmax = nn.Softmax(dim=2)
self.tanh = nn.Tanh()
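The renamed parameters keep the shapes of the structured self-attention formulation (W_s1: dim x input_size, W_s2: num_vec x dim). The forward pass is not part of this diff; a sketch of the computation those shapes imply, with made-up sizes:

import torch

batch_size, seq_len, input_size, dim, num_vec = 2, 7, 16, 10, 10
x = torch.randn(batch_size, seq_len, input_size)
W_s1 = torch.randn(dim, input_size)
W_s2 = torch.randn(num_vec, dim)

attention = torch.softmax(W_s2 @ torch.tanh(W_s1 @ x.transpose(1, 2)), dim=2)
# attention: [batch_size, num_vec, seq_len], one weight distribution per encoded vector
encoded = attention @ x   # [batch_size, num_vec, input_size]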



+20 -12 fastNLP/modules/encoder/char_embedding.py

@@ -5,13 +5,15 @@ from torch import nn

class ConvCharEmbedding(nn.Module):

def __init__(self, char_emb_size, feature_maps=(40, 30, 30), kernels=(3, 4, 5)):
def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5)):
"""
Character Level Word Embedding
:param char_emb_size: the size of character level embedding,
:param char_emb_size: the size of character level embedding. Default: 50
say 26 characters, each embedded to 50 dim vector, then the input_size is 50.
:param feature_maps: table of feature maps (for each kernel width)
:param kernels: table of kernel widths
:param feature_maps: tuple of int. The length of the tuple is the number of convolution operations
over characters. The i-th integer is the number of filters (dim of out channels) for the i-th
convolution.
:param kernels: tuple of int. The width of each kernel.
"""
super(ConvCharEmbedding, self).__init__()
self.convs = nn.ModuleList([
@@ -23,29 +25,35 @@ class ConvCharEmbedding(nn.Module):
:param x: [batch_size * sent_length, word_length, char_emb_size]
:return: [batch_size * sent_length, sum(feature_maps), 1]
"""
x = x.contiguous().view(x.size(0), 1, x.size(1), x.size(2)) # [batch_size*sent_length, channel, width, height]
x = x.transpose(2, 3) # [batch_size*sent_length, channel, height, width]
x = x.contiguous().view(x.size(0), 1, x.size(1), x.size(2))
# [batch_size*sent_length, channel, width, height]
x = x.transpose(2, 3)
# [batch_size*sent_length, channel, height, width]
return self.convolute(x).unsqueeze(2)

def convolute(self, x):
feats = []
for conv in self.convs:
y = conv(x) # [batch_size*sent_length, feature_maps[i], 1, width - kernels[i] + 1]
y = torch.squeeze(y, 2) # [batch_size*sent_length, feature_maps[i], width - kernels[i] + 1]
y = conv(x)
# [batch_size*sent_length, feature_maps[i], 1, width - kernels[i] + 1]
y = torch.squeeze(y, 2)
# [batch_size*sent_length, feature_maps[i], width - kernels[i] + 1]
y = F.tanh(y)
y, __ = torch.max(y, 2) # [batch_size*sent_length, feature_maps[i]]
y, __ = torch.max(y, 2)
# [batch_size*sent_length, feature_maps[i]]
feats.append(y)
return torch.cat(feats, 1) # [batch_size*sent_length, sum(feature_maps)]


class LSTMCharEmbedding(nn.Module):
"""
Character Level Word Embedding with LSTM
:param char_emb_size: the size of character level embedding,
Character Level Word Embedding with LSTM with a single layer.
:param char_emb_size: int, the size of character level embedding. Default: 50
say 26 characters, each embedded to 50 dim vector, then the input_size is 50.
:param hidden_size: int, the number of hidden units. Default: equal to char_emb_size.
"""

def __init__(self, char_emb_size, hidden_size=None):
def __init__(self, char_emb_size=50, hidden_size=None):
super(LSTMCharEmbedding, self).__init__()
self.hidden_size = char_emb_size if hidden_size is None else hidden_size
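A usage sketch for ConvCharEmbedding, following the shapes given in its forward docstring above (the import path assumes the module layout of this commit):

import torch
from fastNLP.modules.encoder.char_embedding import ConvCharEmbedding

# 32 words (batch_size * sent_length), each up to 10 characters, 50-dim char vectors
x = torch.randn(32, 10, 50)
emb = ConvCharEmbedding(char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5))
out = emb(x)   # [32, sum(feature_maps) = 100, 1], per the forward docstring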



+3 -0 fastNLP/modules/encoder/conv.py

@@ -2,12 +2,14 @@
# encoding: utf-8

import torch.nn as nn
from torch.nn.init import xavier_uniform
# import torch.nn.functional as F


class Conv(nn.Module):
"""
Basic 1-d convolution module.
initialize with xavier_uniform
"""

def __init__(self, in_channels, out_channels, kernel_size,
@@ -23,6 +25,7 @@ class Conv(nn.Module):
dilation=dilation,
groups=groups,
bias=bias)
xavier_uniform(self.conv.weight)

def forward(self, x):
return self.conv(x) # [N,C,L]
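For reference, xavier_uniform draws the weights uniformly from [-bound, bound] with bound = gain * sqrt(6 / (fan_in + fan_out)); a quick check of that property, using the in-place nn.init.xavier_uniform_ variant and sizes chosen purely for illustration:

import math
import torch.nn as nn

conv = nn.Conv1d(in_channels=50, out_channels=100, kernel_size=3)
fan_in = 50 * 3        # in_channels * kernel_size
fan_out = 100 * 3      # out_channels * kernel_size
bound = math.sqrt(6.0 / (fan_in + fan_out))
nn.init.xavier_uniform_(conv.weight)
print(conv.weight.min().item() >= -bound, conv.weight.max().item() <= bound)   # True True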

+3 -2 fastNLP/modules/encoder/embedding.py

@@ -7,12 +7,13 @@ class Lookuptable(nn.Module):

Args:
nums : the size of the lookup table
dims : the size of each vector
dims : the size of each vector. Default: 50.
padding_idx : pads the tensor with zeros whenever it encounters this index
sparse : If True, gradient matrix will be a sparse tensor. In this case,
only optim.SGD(cuda and cpu) and optim.Adagrad(cpu) can be used
"""
def __init__(self, nums, dims, padding_idx=0, sparse=False):

def __init__(self, nums, dims=50, padding_idx=0, sparse=False):
super(Lookuptable, self).__init__()
self.embed = nn.Embedding(nums, dims, padding_idx, sparse=sparse)
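A minimal illustration of the sparse=True constraint noted in the docstring (sizes are made up):

import torch.nn as nn
from torch import optim

embed = nn.Embedding(10000, 50, padding_idx=0, sparse=True)
optimizer = optim.SGD(embed.parameters(), lr=0.01)   # SGD handles sparse gradients
# optim.Adam(embed.parameters(), lr=0.01) would fail at step() time on sparse gradients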


+5 -4 fastNLP/modules/encoder/lstm.py

@@ -8,11 +8,12 @@ class Lstm(nn.Module):
Args:
input_size : input size
hidden_size : hidden size
num_layers : number of hidden layers
dropout : dropout rate
bidirectional : If True, becomes a bidirectional RNN
num_layers : number of hidden layers. Default: 1
dropout : dropout rate. Default: 0.5
bidirectional : If True, becomes a bidirectional RNN. Default: False.
"""
def __init__(self, input_size, hidden_size, num_layers, dropout, bidirectional):

def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.5, bidirectional=False):
super(Lstm, self).__init__()
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True,
dropout=dropout, bidirectional=bidirectional)
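The wrapper above only configures nn.LSTM with batch_first=True, so with the new defaults it is roughly equivalent to the following (input_size=50 and hidden_size=100 are example values):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=50, hidden_size=100, num_layers=1, bias=True,
               batch_first=True, dropout=0.5, bidirectional=False)
x = torch.randn(4, 20, 50)            # [batch_size, seq_len, input_size]
output, (h_n, c_n) = lstm(x)          # output: [4, 20, 100]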

