|
|
@@ -18,8 +18,9 @@ class SampleIter: |
|
|
|
|
|
|
|
def __iter__(self): |
|
|
|
for f in os.listdir(self.dirname): |
|
|
|
for y, x in pickle.load(open(os.path.join(self.dirname, f), 'rb')): |
|
|
|
yield x, y |
|
|
|
with open(os.path.join(self.dirname, f), 'rb') as f: |
|
|
|
for y, x in pickle.load(f): |
|
|
|
yield x, y |
|
|
|
|
|
|
|
class SentIter: |
|
|
|
def __init__(self, dirname, count, vocab=None): |
|
|
@@ -29,17 +30,18 @@ class SentIter: |
|
|
|
|
|
|
|
def __iter__(self): |
|
|
|
for f in os.listdir(self.dirname)[:self.count]: |
|
|
|
for y, x in pickle.load(open(os.path.join(self.dirname, f), 'rb')): |
|
|
|
for sent in x: |
|
|
|
if self.vocab is not None: |
|
|
|
_sent = [] |
|
|
|
for w in sent: |
|
|
|
if w in self.vocab: |
|
|
|
_sent.append(w) |
|
|
|
else: |
|
|
|
_sent.append(UNK_token) |
|
|
|
sent = _sent |
|
|
|
yield sent |
|
|
|
with open(os.path.join(self.dirname, f), 'rb') as f: |
|
|
|
for y, x in pickle.load(f): |
|
|
|
for sent in x: |
|
|
|
if self.vocab is not None: |
|
|
|
_sent = [] |
|
|
|
for w in sent: |
|
|
|
if w in self.vocab: |
|
|
|
_sent.append(w) |
|
|
|
else: |
|
|
|
_sent.append(UNK_token) |
|
|
|
sent = _sent |
|
|
|
yield sent |
|
|
|
|
|
|
|
def train_word_vec(): |
|
|
|
# load data |
|
|
@@ -69,31 +71,37 @@ class Embedding_layer: |
|
|
|
|
|
|
|
from torch.utils.data import DataLoader, Dataset |
|
|
|
class YelpDocSet(Dataset): |
|
|
|
def __init__(self, dirname, num_files, embedding): |
|
|
|
def __init__(self, dirname, start_file, num_files, embedding): |
|
|
|
self.dirname = dirname |
|
|
|
self.num_files = num_files |
|
|
|
self._len = num_files*5000 |
|
|
|
self._files = os.listdir(dirname)[:num_files] |
|
|
|
self._files = os.listdir(dirname)[start_file:start_file + num_files] |
|
|
|
self.embedding = embedding |
|
|
|
|
|
|
|
self._cache = [(-1, None) for i in range(5)] |
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
return self._len |
|
|
|
return len(self._files)*5000 |
|
|
|
|
|
|
|
def __getitem__(self, n): |
|
|
|
file_id = n // 5000 |
|
|
|
sample_list = pickle.load(open( |
|
|
|
os.path.join(self.dirname, self._files[file_id]), 'rb')) |
|
|
|
y, x = sample_list[n % 5000] |
|
|
|
idx = file_id % 5 |
|
|
|
if self._cache[idx][0] != file_id: |
|
|
|
print('load {} to {}'.format(file_id, idx)) |
|
|
|
with open(os.path.join(self.dirname, self._files[file_id]), 'rb') as f: |
|
|
|
self._cache[idx] = (file_id, pickle.load(f)) |
|
|
|
y, x = self._cache[idx][1][n % 5000] |
|
|
|
doc = [] |
|
|
|
for sent in x: |
|
|
|
if len(sent) == 0: |
|
|
|
continue |
|
|
|
sent_vec = [] |
|
|
|
for word in sent: |
|
|
|
vec = self.embedding.get_vec(word) |
|
|
|
vec = torch.Tensor(vec.reshape((1, -1))) |
|
|
|
sent_vec.append(vec) |
|
|
|
sent_vec = torch.cat(sent_vec, dim=0) |
|
|
|
sent_vec.append(vec.tolist()) |
|
|
|
sent_vec = torch.Tensor(sent_vec) |
|
|
|
# print(sent_vec.size()) |
|
|
|
doc.append(sent_vec) |
|
|
|
if len(doc) == 0: |
|
|
|
doc = [torch.zeros(1,200)] |
|
|
|
return doc, y-1 |
|
|
|
|
|
|
|
def collate(iterable): |
|
|
@@ -105,7 +113,7 @@ def collate(iterable): |
|
|
|
return x_list, torch.LongTensor(y_list) |
|
|
|
|
|
|
|
def train(net, dataset, num_epoch, batch_size, print_size=10, use_cuda=False): |
|
|
|
optimizer = torch.optim.SGD(net.parameters(), lr=0.01) |
|
|
|
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9) |
|
|
|
criterion = nn.NLLLoss() |
|
|
|
|
|
|
|
dataloader = DataLoader(dataset, |
|
|
@@ -116,6 +124,7 @@ def train(net, dataset, num_epoch, batch_size, print_size=10, use_cuda=False): |
|
|
|
|
|
|
|
if use_cuda: |
|
|
|
net.cuda() |
|
|
|
print('start training') |
|
|
|
for epoch in range(num_epoch): |
|
|
|
for i, batch_samples in enumerate(dataloader): |
|
|
|
x, y = batch_samples |
|
|
@@ -157,11 +166,14 @@ if __name__ == '__main__': |
|
|
|
# Entry point: train the HAN classifier on the Yelp review shards.
# (Merges diff residue: the pre-patch dataset construction without
# start_file and the commented-out checkpoint load are dropped in favor
# of the post-patch variants.)

# Load the pre-trained word2vec model, wrap its vectors in the project's
# embedding layer, then drop the full model to free memory.
embed_model = Word2Vec.load('yelp.word2vec')
embedding = Embedding_layer(embed_model.wv, embed_model.wv.vector_size)
del embed_model

# for start_file in range(11, 24):  # optional shard-windowed restart loop
start_file = 0
dataset = YelpDocSet('reviews', start_file, 120 - start_file, embedding)
print('start_file %d' % start_file)
print(len(dataset))

net = HAN(input_size=200, output_size=5,
          word_hidden_size=50, word_num_layers=1, word_context_size=100,
          sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)
# Resume from the previous checkpoint; requires 'model.dict' to exist.
net.load_state_dict(torch.load('model.dict'))
train(net, dataset, num_epoch=1, batch_size=64, use_cuda=True)