
remove useless code

tags/v0.1.0
choocewhatulike 6 years ago
commit f451754cd3
5 changed files with 11 additions and 41 deletions
  1. model_inplement/evaluate.py (+3, -15)
  2. model_inplement/model.py (+5, -0)
  3. model_inplement/preprocess.py (+0, -0)
  4. model_inplement/readme.md (+0, -0)
  5. model_inplement/train.py (+3, -26)

model_inplement/code/evaluate.py → model_inplement/evaluate.py

@@ -29,10 +29,10 @@ def evaluate(net, dataset, bactch_size=64, use_cuda=False):
         count += torch.sum(torch.eq(idx, y))
     return count

+def visualize_attention(doc, alpha_vec):
+    pass
+
 if __name__ == '__main__':
     '''
     Evaluate the performance of model
     '''
-    from gensim.models import Word2Vec
-    import gensim
-    from gensim import models
@@ -47,15 +47,3 @@ if __name__ == '__main__':
     test_dataset = YelpDocSet('reviews', 199, 4, embedding)
     correct = evaluate(net, test_dataset, True)
     print('accuracy {}'.format(correct/len(test_dataset)))
-    # data_idx = 121
-    # x, y = test_dataset[data_idx]
-    # doc = []
-    # for sent_vec in x:
-    #     doc.append(Variable(sent_vec, volatile=True))
-    # input_vec = [pack_sequence(doc)]
-    # predict = net(input_vec)
-    # p, idx = torch.max(predict, dim=1)
-    # print(net.word_layer.last_alpha.squeeze())
-    # print(net.sent_layer.last_alpha)
-    # print(test_dataset.get_doc(data_idx)[0])
-    # print('predict: {}, true: {}'.format(int(idx), y))
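
The deleted block was ad-hoc attention inspection; the new visualize_attention stub is its intended home. A minimal hypothetical sketch of how the stub could be filled in, assuming doc is a list of tokenized sentences and alpha_vec holds the matching per-word weights (for example the net.word_layer.last_alpha values the removed prints inspected):

    def visualize_attention(doc, alpha_vec):
        # Print each word beside its attention weight, one sentence per block.
        for sent, alphas in zip(doc, alpha_vec):
            for word, alpha in zip(sent, alphas):
                print('{:>15s} {:.3f}'.format(word, float(alpha)))
            print()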

model_inplement/code/model.py → model_inplement/model.py

@@ -73,8 +73,10 @@ class AttentionNet(nn.Module):

     def forward(self, inputs):
         # inputs's dim (batch_size, seq_len, word_dim)
+        # GRU part
         h_t, hidden = self.gru(inputs)
         u = self.tanh(self.fc(h_t))
+        # Attention part
         # u's dim (batch_size, seq_len, context_vec_size)
         alpha = self.softmax(torch.matmul(u, self.context_vec))
         self.last_alpha = alpha.data
@@ -85,6 +87,9 @@ class AttentionNet(nn.Module):


 if __name__ == '__main__':
+    '''
+    Test the model correctness
+    '''
    import numpy as np
    use_cuda = True
    net = HAN(input_size=200, output_size=5,
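
The forward pass annotated above is the standard HAN attention step: u_t = tanh(W·h_t + b), then alpha_t = softmax(u_t · u_w) against a learned context vector u_w, followed by a weighted sum of the GRU states. A self-contained sketch with hypothetical sizes (not the repo's exact module):

    import torch

    batch, seq_len, hidden, ctx = 2, 7, 100, 100    # hypothetical sizes
    h_t = torch.randn(batch, seq_len, hidden)       # stands in for the GRU outputs
    fc = torch.nn.Linear(hidden, ctx)
    context_vec = torch.randn(ctx)                  # the learned u_w

    u = torch.tanh(fc(h_t))                                     # (batch, seq_len, ctx)
    alpha = torch.softmax(torch.matmul(u, context_vec), dim=1)  # (batch, seq_len)
    sent_vec = torch.sum(alpha.unsqueeze(-1) * h_t, dim=1)      # weighted sum, (batch, hidden)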

model_inplement/code/preprocess.py → model_inplement/preprocess.py


model_inplement/readme → model_inplement/readme.md


model_inplement/code/train.py → model_inplement/train.py

@@ -10,37 +10,16 @@ import torch

 from model import *

-UNK_token = '/unk'
-
-class SampleIter:
-    def __init__(self, dirname):
-        self.dirname = dirname
-
-    def __iter__(self):
-        for f in os.listdir(self.dirname):
-            with open(os.path.join(self.dirname, f), 'rb') as f:
-                for y, x in pickle.load(f):
-                    yield x, y
-
 class SentIter:
-    def __init__(self, dirname, count, vocab=None):
+    def __init__(self, dirname, count):
         self.dirname = dirname
         self.count = int(count)
-        self.vocab = None

     def __iter__(self):
         for f in os.listdir(self.dirname)[:self.count]:
             with open(os.path.join(self.dirname, f), 'rb') as f:
                 for y, x in pickle.load(f):
                     for sent in x:
-                        if self.vocab is not None:
-                            _sent = []
-                            for w in sent:
-                                if w in self.vocab:
-                                    _sent.append(w)
-                                else:
-                                    _sent.append(UNK_token)
-                            sent = _sent
                         yield sent

 def train_word_vec():
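
Aside: SentIter stays a class with __iter__ rather than a generator function because gensim scans the corpus once for build_vocab and again for each training epoch, so the corpus must be restartable. A minimal illustration of the difference (hypothetical names):

    class Restartable:
        def __init__(self, items):
            self.items = items
        def __iter__(self):
            # A fresh pass starts on every call, unlike a one-shot generator object.
            return iter(self.items)

    corpus = Restartable([['good', 'food'], ['bad', 'service']])
    assert list(corpus) == list(corpus)  # two full passes both see the data
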
@@ -50,7 +29,6 @@ def train_word_vec():
     # define model and train
     model = models.Word2Vec(size=200, sg=0, workers=4, min_count=5)
     model.build_vocab(sents)
-    sents.vocab = model.wv.vocab
     model.train(sents, total_examples=model.corpus_count, epochs=10)
     model.save('yelp.word2vec')
     print(model.wv.similarity('woman', 'man'))
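
A hedged usage sketch of the vectors saved above (the filename comes from the model.save call; the query word is a placeholder):

    from gensim import models

    # Load the embeddings written by train_word_vec().
    embed_model = models.Word2Vec.load('yelp.word2vec')
    print(embed_model.wv.most_similar('food', topn=5))  # hypothetical query word
    print(embed_model.wv['food'].shape)                 # (200,) given size=200 above
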
@@ -82,7 +60,7 @@ class YelpDocSet(Dataset):
         file_id = n // 5000
         idx = file_id % 5
         if self._cache[idx][0] != file_id:
-            print('load {} to {}'.format(file_id, idx))
+            # print('load {} to {}'.format(file_id, idx))
             with open(os.path.join(self.dirname, self._files[file_id]), 'rb') as f:
                 self._cache[idx] = (file_id, pickle.load(f))
         y, x = self._cache[idx][1][n % 5000]
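
The lines above implement a small direct-mapped cache over pickled shards of 5000 reviews each: item n lives in shard n // 5000, and each shard maps to cache slot file_id % 5. A standalone sketch of the same scheme (a hypothetical class, not the repo's):

    import os
    import pickle

    class ShardCache:
        def __init__(self, dirname, files, slots=5, shard_size=5000):
            self.dirname, self.files = dirname, files
            self.shard_size = shard_size
            self._cache = [(-1, None)] * slots           # (file_id, data) per slot
        def __getitem__(self, n):
            file_id = n // self.shard_size               # which shard holds item n
            slot = file_id % len(self._cache)            # direct-mapped slot
            if self._cache[slot][0] != file_id:          # miss: load shard, evict slot
                path = os.path.join(self.dirname, self.files[file_id])
                with open(path, 'rb') as f:
                    self._cache[slot] = (file_id, pickle.load(f))
            return self._cache[slot][1][n % self.shard_size]
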
@@ -182,8 +160,7 @@ if __name__ == '__main__':
     del embed_model
     start_file = 0
     dataset = YelpDocSet('reviews', start_file, 120-start_file, embedding)
-    print('start_file %d'% start_file)
-    print(len(dataset))
+    print('training data size {}'.format(len(dataset)))
     net = HAN(input_size=200, output_size=5,
               word_hidden_size=50, word_num_layers=1, word_context_size=100,
               sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)
