
update and add readme

tags/v0.1.0
HENRY L 7 years ago
commit 561305e03d
8 changed files with 146 additions and 38 deletions
  1. +41  -0   fastNLP/modules/prototype/README.md
  2. +10  -9   fastNLP/modules/prototype/Word2Idx.py
  3. +2   -3   fastNLP/modules/prototype/aggregation.py
  4. +6   -7   fastNLP/modules/prototype/dataloader.py
  5. +0   -3   fastNLP/modules/prototype/encoder.py
  6. +36  -15  fastNLP/modules/prototype/example.py
  7. +1   -1   fastNLP/modules/prototype/predict.py
  8. +50  -0   fastNLP/modules/prototype/prepare.py

+41  -0   fastNLP/modules/prototype/README.md

@@ -0,0 +1,41 @@
# Prototype

## Word2Idx.py
A bidirectional mapping between words and integer indices.
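A minimal usage sketch (the file name is a placeholder; the methods shown are the ones defined in this class):

```python
import Word2Idx

# build a vocabulary from a flat list of word strings
vocab = Word2Idx.Word2Idx()
vocab.build(["the", "food", "was", "great", "the", "service", "was", "great"])

print(vocab.num)            # vocabulary size, including <PAD> and <UNK>
print(vocab.w2i("great"))   # index assigned to a known word
print(vocab.i2w(0))         # "<PAD>"
vocab.save("wordidx.pkl")   # persist the mapping, e.g. for prepare.py to reuse
```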

## embedding.py
Embedding modules.

Contains a thin wrapper around torch.nn.Embedding.

## encoder.py
Encoder modules.

Contains a thin wrapper around torch.nn.LSTM.

## aggregation.py
Aggregation modules.

Contains a self-attention module, following the paper "A Structured Self-attentive Sentence Embedding" (https://arxiv.org/abs/1703.03130).
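It computes A = softmax(W_s2 · tanh(W_s1 · Hᵀ)) over the encoder states H and returns the flattened attention output together with the penalization term from the paper. A hypothetical usage sketch; the constructor arguments are assumed to be the input size, d_a, and the number of attention hops r, and all sizes below are placeholders:

```python
import torch
from torch.autograd import Variable
from aggregation import Selfattention

attn = Selfattention(100, 10, 20)        # assumed order: input size, d_a, r
H = Variable(torch.randn(32, 40, 100))   # (batch, sequence length, hidden size)
out, penalty = attn(H)                   # out: (batch, r * hidden size), penalty: regularization term
```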

## predict.py
Prediction modules.

Contains a two-layer perceptron for classification.
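A hypothetical sketch; the constructor is assumed to take the input, hidden, and output sizes (matching the two Linear layers), and the sizes below are placeholders:

```python
import torch
from torch.autograd import Variable
from predict import MLP

clf = MLP(2000, 500, 5)                        # e.g. attention output size -> 5 Yelp classes
scores = clf(Variable(torch.randn(32, 2000)))  # per-class scores, shape (32, 5)
```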

## example.py
An example showing how to use the above modules to build a model.

Contains a model for sentiment analysis on the Yelp dataset, together with its training and testing procedures. See https://arxiv.org/abs/1703.03130 for more details.
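A sketch of driving it from Python; the keyword arguments are the ones defined by train(), and the model file name is a placeholder produced by an earlier run:

```python
import torch
import example

use_cuda = torch.cuda.is_available()
example.train(using_cuda=use_cuda, batch_size=32, epochs=5)   # saves model_dict_<timestamp>.dict
example.test("model_dict_2018-01-01-00:00:00.dict", using_cuda=use_cuda)
```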

## prepare.py
A script that uses Word2Idx to build the Yelp datasets.
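A sketch of the expected workflow; it assumes a tuples.pkl file of (token list, label) pairs already sits in the working directory:

```python
import prepare

prepare.build_wordidx()   # builds the vocabulary and writes wordidx.pkl
prepare.build_sets()      # maps words to indices, writes train_set.pkl and test_set.pkl
```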

## dataloader.py
A dataloader for the Yelp dataset.

It is an iterable object that returns a zero-padded batch on each iteration.
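A minimal iteration sketch; the file name comes from prepare.py, while the batch size and keyword names are assumptions:

```python
import dataloader

loader = dataloader.DataLoader("train_set.pkl", batch_size=32, using_cuda=False)
for batch in loader:
    X = batch["feature"]   # zero-padded batch of word-index sequences
    y = batch["class"]     # the corresponding class labels
```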





+10  -9   fastNLP/modules/prototype/Word2Idx.py

@@ -4,15 +4,15 @@ import pickle
class Word2Idx():
"""
Build a word index according to word frequency.

If "min_freq" is given, then only words with a frequncy not lesser than min_freq will be kept.
If "max_num" is given, then at most the most frequent $max_num words will be kept.
"words" should be a list [ w_1,w_2,...,w_i,...,w_n ] where each w_i is a string representing a word.
num is the size of the lookup table.
w2i is a lookup table assigning each word an index.
Note that index 0 will be returned for any unregistered words.
i2w is a vector which serves as an invert mapping of w2i.
Token "<UNK>" will be returned for index 0
Note that index 0 is token "<PAD>" for padding
index 1 is token "<UNK>" for unregistered words
e.g. i2w[w2i["word"]] == "word"
"""
def __init__(self):
@@ -29,29 +29,30 @@ class Word2Idx():
else:
most_common = counter.most_common()
self.__w2i = dict((w[0],i + 1) for i,w in enumerate(most_common) if w[1] >= min_freq)
self.__w2i["<UNK>"] = 0
self.__i2w = ["<UNK>"] + [ w[0] for w in most_common if w[1] >= min_freq ]
self.__w2i["<PAD>"] = 0
self.__w2i["<UNK>"] = 1
self.__i2w = ["<PAD>", "<UNK>"] + [ w[0] for w in most_common if w[1] >= min_freq ]
self.num = len(self.__i2w)

def w2i(self,word):
def w2i(self, word):
"""word to index"""
if word in self.__w2i:
return self.__w2i[word]
return 0

def i2w(self,idx):
def i2w(self, idx):
"""index to word"""
if idx >= self.num:
raise Exception("out of range\n")
return self.__i2w[idx]

def save(self,addr):
def save(self, addr):
"""save the model to a file with address "addr" """
f = open(addr,"wb")
pickle.dump([self.__i2w, self.__w2i, self.num], f)
f.close()

def load(self,addr):
def load(self, addr):
"""load a model from a file with address "addr" """
f = open(addr,"rb")
paras = pickle.load(f)


+2  -3   fastNLP/modules/prototype/aggregation.py

@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
from torch.autograd import Variable

class Selfattention(nn.Module):
"""
@@ -32,10 +33,8 @@ class Selfattention(nn.Module):
def forward(self, x):
inter = self.tanh(torch.matmul(self.W_s1, torch.transpose(x, 1, 2)))
A = self.softmax(torch.matmul(self.W_s2, inter))
out = torch.matmul(A, H)
out = torch.matmul(A, x)
out = out.view(out.size(0), -1)
penalty = self.penalization(A)
return out, penalty

if __name__ == "__main__":
model = Selfattention(100, 10, 20)

+6  -7   fastNLP/modules/prototype/dataloader.py

@@ -32,10 +32,10 @@ def pad(X, using_cuda):
padlen = maxlen - x.size(0)
if padlen > 0:
if using_cuda:
paddings = torch.zeros(padlen).cuda()
paddings = Variable(torch.zeros(padlen).long()).cuda()
else:
paddings = torch.zeros(padlen)
x_ = torch.cat(x, paddings)
paddings = Variable(torch.zeros(padlen).long())
x_ = torch.cat((x, paddings), 0)
Y.append(x_)
else:
Y.append(x)
@@ -71,12 +71,11 @@ class DataLoader(object):
random.shuffle(self.data)
raise StopIteration()
else:
X = self.data[self.count * self.batch_size : (self.count + 1) * self.batch_size]
batch = self.data[self.count * self.batch_size : (self.count + 1) * self.batch_size]
self.count += 1
X = [long_wrapper(x["sent"], using_cuda=self.using_cuda) for x in X]
X = [long_wrapper(x["sent"], using_cuda=self.using_cuda, requires_grad=False) for x in batch]
X = pad(X, self.using_cuda)
y = [long_wrapper(x["class"], using_cuda=self.using_cuda) for x in X]
y = torch.stack(y)
y = long_wrapper([x["class"] for x in batch], using_cuda=self.using_cuda, requires_grad=False)
return {"feature" : X, "class" : y}


+0  -3   fastNLP/modules/prototype/encoder.py

@@ -20,6 +20,3 @@ class Lstm(nn.Module):
def forward(self, x):
x, _ = self.lstm(x)
return x

if __name__ == "__main__":
model = Lstm(20, 30, 1, 0.5, False)

+36  -15  fastNLP/modules/prototype/example.py

@@ -8,13 +8,13 @@ import torch.optim as optim
import time
import dataloader

WORD_NUM = 357361
WORD_SIZE = 100
HIDDEN_SIZE = 300
D_A = 350
R = 20
R = 10
MLP_HIDDEN = 2000
CLASSES_NUM = 5
WORD_NUM = 357361

class Net(nn.Module):
"""
@@ -32,7 +32,7 @@ class Net(nn.Module):
x = self.encoder(x)
x, penalty = self.aggregation(x)
x = self.predict(x)
return r, x
return x, penalty

def train(model_dict=None, using_cuda=True, learning_rate=0.06,\
momentum=0.3, batch_size=32, epochs=5, coef=1.0, interval=10):
@@ -50,7 +50,7 @@ def train(model_dict=None, using_cuda=True, learning_rate=0.06,\
the result will be saved with a form "model_dict_+current time", which could be used for further training
"""
if using_cuda == True:
if using_cuda:
net = Net().cuda()
else:
net = Net()
@@ -60,7 +60,7 @@ def train(model_dict=None, using_cuda=True, learning_rate=0.06,\

optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum)
criterion = nn.CrossEntropyLoss()
dataset = dataloader.DataLoader("trainset.pkl", using_cuda=using_cuda)
dataset = dataloader.DataLoader("test_set.pkl", batch_size, using_cuda=using_cuda)

#statistics
loss_count = 0
@@ -69,6 +69,7 @@ def train(model_dict=None, using_cuda=True, learning_rate=0.06,\
count = 0

for epoch in range(epochs):
print("epoch: %d"%(epoch))
for i, batch in enumerate(dataset):
t1 = time.time()
X = batch["feature"]
@@ -86,23 +87,43 @@ def train(model_dict=None, using_cuda=True, learning_rate=0.06,\
loss_count += torch.sum(y_penl).data[0]
prepare_time += (t2 - t1)
run_time += (t3 - t2)
p, idx = torch.max(y_pred, dim=1)
idx = idx.data
count += torch.sum(torch.eq(idx.cpu(), y))
p, idx = torch.max(y_pred.data, dim=1)
count += torch.sum(torch.eq(idx.cpu(), y.data.cpu()))

if i % interval == 0:
print(i)
print("loss count:" + str(loss_count / batch_size))
print("acuracy:" + str(count / batch_size))
if (i + 1) % interval == 0:
print("epoch : %d, iters: %d"%(epoch, i + 1))
print("loss count:" + str(loss_count / (interval * batch_size)))
print("acuracy:" + str(count / (interval * batch_size)))
print("penalty:" + str(torch.sum(y_penl).data[0] / batch_size))
print("prepare time:" + str(prepare_time / batch_size))
print("run time:" + str(run_time / batch_size))
print("prepare time:" + str(prepare_time))
print("run time:" + str(run_time))
prepare_time = 0
run_time = 0
loss_count = 0
count = 0
torch.save(net.state_dict(), "model_dict_%s.pkl"%(str(time.time())))
string = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
torch.save(net.state_dict(), "model_dict_%s.dict"%(string))

def test(model_dict, using_cuda=True):
if using_cuda:
net = Net().cuda()
else:
net = Net()
net.load_state_dict(torch.load(model_dict))
dataset = dataloader.DataLoader("test_set.pkl", batch_size=1, using_cuda=using_cuda)
count = 0
for i, batch in enumerate(dataset):
X = batch["feature"]
y = batch["class"]
y_pred, _ = net(X)
p, idx = torch.max(y_pred.data, dim=1)
count += torch.sum(torch.eq(idx.cpu(), y.data.cpu()))
print("accuracy: %f"%(count / dataset.num))

if __name__ == "__main__":
train(using_cuda=torch.cuda.is_available())


+1  -1   fastNLP/modules/prototype/predict.py

@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
"""
@@ -15,7 +16,6 @@ class MLP(nn.Module):
super(MLP,self).__init__()
self.L1 = nn.Linear(input_size, hidden_size)
self.L2 = nn.Linear(hidden_size, output_size)
self.softmax = nn.Softmax(dim=1)

def forward(self, x):
out = self.L2(F.relu(self.L1(x)))


+50  -0   fastNLP/modules/prototype/prepare.py

@@ -0,0 +1,50 @@
import pickle
import Word2Idx

def get_sets(m, n):
"""
get a train set containing m samples and a test set containing n samples
"""
samples = pickle.load(open("tuples.pkl","rb"))
if m+n > len(samples):
print("asking for too many tuples\n")
return
train_samples = samples[ : m]
test_samples = samples[m: m+n]
return train_samples, test_samples

def build_wordidx():
"""
build wordidx using word2idx
"""
train, test = get_sets(500000, 2000)
words = []
for x in train:
words += x[0]
wordidx = Word2Idx.Word2Idx()
wordidx.build(words)
print(wordidx.num)
print(wordidx.i2w(0))
wordidx.save("wordidx.pkl")

def build_sets():
"""
build train set and test set, transform word to index
"""
train, test = get_sets(500000, 2000)
wordidx = Word2Idx.Word2Idx()
wordidx.load("wordidx.pkl")
train_set = []
for x in train:
sent = [wordidx.w2i(w) for w in x[0]]
train_set.append({"sent" : sent, "class" : x[1]})
test_set = []
for x in test:
sent = [wordidx.w2i(w) for w in x[0]]
test_set.append({"sent" : sent, "class" : x[1]})
pickle.dump(train_set, open("train_set.pkl", "wb"))
pickle.dump(test_set, open("test_set.pkl", "wb"))

if __name__ == "__main__":
build_wordidx()
build_sets()
