
add evaluate

commit 049420c207 by choocewhatulike, 6 years ago (tags/v0.1.0)
3 changed files with 78 additions and 2 deletions:

  1. model_inplement/code/evaluate.py (+61, -0)
  2. model_inplement/code/model.py (+2, -0)
  3. model_inplement/code/train.py (+15, -2)

model_inplement/code/evaluate.py (+61, -0)

@@ -0,0 +1,61 @@
from model import *
from train import *

def evaluate(net, dataset, batch_size=64, use_cuda=False):
    dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate, num_workers=0)
    count = 0
    if use_cuda:
        net.cuda()
    for i, batch_samples in enumerate(dataloader):
        x, y = batch_samples
        doc_list = []
        for sample in x:
            doc = []
            for sent_vec in sample:
                # print(sent_vec.size())
                if use_cuda:
                    sent_vec = sent_vec.cuda()
                # volatile: inference only, no gradient history is kept
                doc.append(Variable(sent_vec, volatile=True))
            doc_list.append(pack_sequence(doc))
        if use_cuda:
            y = y.cuda()
        predicts = net(doc_list)
        # idx = []
        # for p in predicts.data:
        #     idx.append(np.random.choice(5, p=torch.exp(p).numpy()))
        # idx = torch.LongTensor(idx)
        p, idx = torch.max(predicts, dim=1)
        idx = idx.data
        # count correct predictions in this batch
        count += torch.sum(torch.eq(idx, y))
    return count

def visualize_attention(doc, alpha_vec):
    pass

if __name__ == '__main__':
    from gensim.models import Word2Vec
    embed_model = Word2Vec.load('yelp.word2vec')
    embedding = Embedding_layer(embed_model.wv, embed_model.wv.vector_size)
    del embed_model

    net = HAN(input_size=200, output_size=5,
              word_hidden_size=50, word_num_layers=1, word_context_size=100,
              sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)
    net.load_state_dict(torch.load('model.dict'))
    test_dataset = YelpDocSet('reviews', 199, 4, embedding)
    correct = evaluate(net, test_dataset, use_cuda=True)
    print('accuracy {}'.format(correct / len(test_dataset)))
    # data_idx = 121
    # x, y = test_dataset[data_idx]
    # doc = []
    # for sent_vec in x:
    #     doc.append(Variable(sent_vec, volatile=True))
    # input_vec = [pack_sequence(doc)]
    # predict = net(input_vec)
    # p, idx = torch.max(predict, dim=1)
    # print(net.word_layer.last_alpha.squeeze())
    # print(net.sent_layer.last_alpha)
    # print(test_dataset.get_doc(data_idx)[0])
    # print('predict: {}, true: {}'.format(int(idx), y))
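
Note: visualize_attention is still a stub, and the commented-out block at the end of __main__ hints at how it would be fed: net.word_layer.last_alpha and net.sent_layer.last_alpha hold the attention weights after a forward pass. A minimal sketch of one way to render sentence-level attention as text, assuming doc is the newline-joined string returned by YelpDocSet.get_doc and alpha_vec is a 1-D tensor of per-sentence weights (the bar rendering and scaling factor are illustrative choices, not part of this commit):

def visualize_attention(doc, alpha_vec):
    sentences = doc.split('\n')
    weights = alpha_vec.squeeze().tolist()
    for w, sent in zip(weights, sentences):
        # crude text heatmap: more '#' marks mean more attention on that sentence
        bar = '#' * int(round(w * 50))
        print('{:.3f} {:<12} {}'.format(w, bar, sent))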

model_inplement/code/model.py (+2, -0)

@@ -55,6 +55,7 @@ class AttentionNet(nn.Module):
        self.gru_hidden_size = gru_hidden_size
        self.gru_num_layers = gru_num_layers
        self.context_vec_size = context_vec_size
        # filled in on each forward pass with the latest attention weights
        self.last_alpha = None

        # Encoder
        self.gru = nn.GRU(input_size=input_size,
@@ -76,6 +77,7 @@ class AttentionNet(nn.Module):
        u = self.tanh(self.fc(h_t))
        # u's dim (batch_size, seq_len, context_vec_size)
        alpha = self.softmax(torch.matmul(u, self.context_vec))
        # stash a copy so the attention weights can be inspected after forward
        self.last_alpha = alpha.data
        # alpha's dim (batch_size, seq_len, 1)
        output = torch.bmm(torch.transpose(h_t, 1, 2), alpha)
        # output's dim (batch_size, 2*hidden_size, 1)
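
As a shape sanity check, here is a self-contained sketch of this attention step with dummy tensors, written against current PyTorch without the Variable wrapper used in the repo; the sizes are arbitrary, and fc and context_vec stand in for the module's learned parameters:

import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, seq_len, hidden_size, context_size = 4, 7, 50, 100
h_t = torch.randn(batch_size, seq_len, 2 * hidden_size)   # bidirectional GRU output
fc = nn.Linear(2 * hidden_size, context_size)
context_vec = torch.randn(context_size, 1)                 # learned context vector

u = torch.tanh(fc(h_t))                                    # (batch, seq_len, context_size)
alpha = F.softmax(torch.matmul(u, context_vec), dim=1)     # (batch, seq_len, 1), sums to 1 over seq_len
output = torch.bmm(h_t.transpose(1, 2), alpha)             # (batch, 2*hidden_size, 1)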


model_inplement/code/train.py (+15, -2)

@@ -78,6 +78,20 @@ class YelpDocSet(Dataset):
        self.embedding = embedding
        # (file_id, data) pairs; up to 5 pickle files stay in memory
        self._cache = [(-1, None) for i in range(5)]

    def get_doc(self, n):
        file_id = n // 5000
        idx = file_id % 5
        if self._cache[idx][0] != file_id:
            print('load {} to {}'.format(file_id, idx))
            with open(os.path.join(self.dirname, self._files[file_id]), 'rb') as f:
                self._cache[idx] = (file_id, pickle.load(f))
        y, x = self._cache[idx][1][n % 5000]
        sents = []
        for s_list in x:
            sents.append(' '.join(s_list))
        x = '\n'.join(sents)
        # ratings are 1-5 stars; shift to 0-4 class labels
        return x, y - 1

    def __len__(self):
        return len(self._files) * 5000

@@ -166,7 +180,6 @@ if __name__ == '__main__':
    embed_model = Word2Vec.load('yelp.word2vec')
    embedding = Embedding_layer(embed_model.wv, embed_model.wv.vector_size)
    del embed_model
-    # for start_file in range(11, 24):
    start_file = 0
    dataset = YelpDocSet('reviews', start_file, 120-start_file, embedding)
    print('start_file %d' % start_file)
@@ -176,4 +189,4 @@ if __name__ == '__main__':
              sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)
    net.load_state_dict(torch.load('model.dict'))
-    train(net, dataset, num_epoch=1, batch_size=64, use_cuda=True)
+    train(net, dataset, num_epoch=5, batch_size=64, use_cuda=True)
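
The new get_doc rounds a sample index down to its pickle file and keeps up to five loaded files in a small cache, so sequential access pays one pickle.load per 5000 samples. The same lookup logic in isolation (the class name and constructor arguments here are made up for illustration):

import os
import pickle

class FiveSlotCache:
    """Round-robin cache over pickled shards of 5000 samples each."""
    def __init__(self, dirname, files, per_file=5000, slots=5):
        self.dirname = dirname
        self.files = files
        self.per_file = per_file
        self._cache = [(-1, None) for _ in range(slots)]

    def get(self, n):
        file_id = n // self.per_file
        slot = file_id % len(self._cache)
        if self._cache[slot][0] != file_id:  # miss: load the shard from disk
            with open(os.path.join(self.dirname, self.files[file_id]), 'rb') as f:
                self._cache[slot] = (file_id, pickle.load(f))
        return self._cache[slot][1][n % self.per_file]

Since file_id % 5 maps files 0 and 5 to the same slot, random access across many shards will thrash the cache; the scheme only pays off when the DataLoader reads mostly sequentially, as it does here.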
