
fix data loader

tags/v0.1.0
choocewhatulike committed 6 years ago
parent commit 9b8d8c451e
1 changed file with 41 additions and 29 deletions
model_inplement/code/train.py (+41, -29)

@@ -18,8 +18,9 @@ class SampleIter:
     def __iter__(self):
         for f in os.listdir(self.dirname):
-            for y, x in pickle.load(open(os.path.join(self.dirname, f), 'rb')):
-                yield x, y
+            with open(os.path.join(self.dirname, f), 'rb') as f:
+                for y, x in pickle.load(f):
+                    yield x, y

 class SentIter:
     def __init__(self, dirname, count, vocab=None):
@@ -29,17 +30,18 @@ class SentIter:

     def __iter__(self):
         for f in os.listdir(self.dirname)[:self.count]:
-            for y, x in pickle.load(open(os.path.join(self.dirname, f), 'rb')):
-                for sent in x:
-                    if self.vocab is not None:
-                        _sent = []
-                        for w in sent:
-                            if w in self.vocab:
-                                _sent.append(w)
-                            else:
-                                _sent.append(UNK_token)
-                        sent = _sent
-                    yield sent
+            with open(os.path.join(self.dirname, f), 'rb') as f:
+                for y, x in pickle.load(f):
+                    for sent in x:
+                        if self.vocab is not None:
+                            _sent = []
+                            for w in sent:
+                                if w in self.vocab:
+                                    _sent.append(w)
+                                else:
+                                    _sent.append(UNK_token)
+                            sent = _sent
+                        yield sent

 def train_word_vec():
     # load data
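
Note: the old pattern pickle.load(open(path, 'rb')) left the file handle to be closed whenever the garbage collector got around to it; the with block closes it deterministically. One caveat in the new code: the loop variable f (a filename) is rebound to the file object inside the with block, which happens to be harmless here because os.listdir was already evaluated. A minimal sketch that keeps the fix without the shadowing (the iter_samples name is hypothetical, not in this commit):

    import os
    import pickle

    def iter_samples(dirname):
        # Yield (x, y) pairs from every pickled (label, document) list
        # in dirname, closing each file deterministically.
        for fname in os.listdir(dirname):
            with open(os.path.join(dirname, fname), 'rb') as fp:
                for y, x in pickle.load(fp):
                    yield x, y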
@@ -69,31 +71,37 @@ class Embedding_layer:

 from torch.utils.data import DataLoader, Dataset
 class YelpDocSet(Dataset):
-    def __init__(self, dirname, num_files, embedding):
+    def __init__(self, dirname, start_file, num_files, embedding):
         self.dirname = dirname
         self.num_files = num_files
-        self._len = num_files*5000
-        self._files = os.listdir(dirname)[:num_files]
+        self._files = os.listdir(dirname)[start_file:start_file + num_files]
         self.embedding = embedding
+        self._cache = [(-1, None) for i in range(5)]

     def __len__(self):
-        return self._len
+        return len(self._files)*5000

     def __getitem__(self, n):
         file_id = n // 5000
-        sample_list = pickle.load(open(
-            os.path.join(self.dirname, self._files[file_id]), 'rb'))
-        y, x = sample_list[n % 5000]
+        idx = file_id % 5
+        if self._cache[idx][0] != file_id:
+            print('load {} to {}'.format(file_id, idx))
+            with open(os.path.join(self.dirname, self._files[file_id]), 'rb') as f:
+                self._cache[idx] = (file_id, pickle.load(f))
+        y, x = self._cache[idx][1][n % 5000]
         doc = []
         for sent in x:
             if len(sent) == 0:
                 continue
             sent_vec = []
             for word in sent:
                 vec = self.embedding.get_vec(word)
-                vec = torch.Tensor(vec.reshape((1, -1)))
-                sent_vec.append(vec)
-            sent_vec = torch.cat(sent_vec, dim=0)
+                sent_vec.append(vec.tolist())
+            sent_vec = torch.Tensor(sent_vec)
             # print(sent_vec.size())
             doc.append(sent_vec)
+        if len(doc) == 0:
+            doc = [torch.zeros(1,200)]
         return doc, y-1

 def collate(iterable):
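
Note: __getitem__ now keeps up to five unpickled files in memory. Each file_id maps to slot file_id % 5, and a file is re-read from disk only when its slot holds a different file, so a sequential (shuffle=False) pass over the dataset loads each 5000-sample file once. A minimal sketch of the same direct-mapped cache in isolation, assuming a load_file callable (hypothetical name) that reads one pickled file:

    class DirectMappedCache:
        # Hold `slots` loaded files; file i occupies slot i % slots.
        def __init__(self, load_file, slots=5):
            self.load_file = load_file          # callable: file_id -> data
            self._cache = [(-1, None)] * slots  # (file_id, data) per slot

        def get(self, file_id):
            idx = file_id % len(self._cache)
            if self._cache[idx][0] != file_id:  # slot miss: reload from disk
                self._cache[idx] = (file_id, self.load_file(file_id))
            return self._cache[idx][1]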
@@ -105,7 +113,7 @@ def collate(iterable):
     return x_list, torch.LongTensor(y_list)

 def train(net, dataset, num_epoch, batch_size, print_size=10, use_cuda=False):
-    optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
+    optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
     criterion = nn.NLLLoss()

     dataloader = DataLoader(dataset,
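
Note: momentum=0.9 gives SGD a velocity buffer: with PyTorch defaults (dampening=0, nesterov=False) the update is v = momentum*v + grad, then p = p - lr*v. A toy check of that rule on a single parameter:

    import torch

    p = torch.nn.Parameter(torch.tensor([1.0]))
    opt = torch.optim.SGD([p], lr=0.01, momentum=0.9)
    p.grad = torch.tensor([1.0])
    opt.step()      # v = 1.0            -> p = 1.0  - 0.01*1.0 = 0.99
    p.grad = torch.tensor([1.0])
    opt.step()      # v = 0.9*1.0 + 1.0  -> p = 0.99 - 0.019    = 0.971
    print(p.data)   # tensor([0.9710])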
@@ -116,6 +124,7 @@ def train(net, dataset, num_epoch, batch_size, print_size=10, use_cuda=False):

     if use_cuda:
         net.cuda()
+    print('start training')
     for epoch in range(num_epoch):
         for i, batch_samples in enumerate(dataloader):
             x, y = batch_samples
@@ -157,11 +166,14 @@ if __name__ == '__main__':
     embed_model = Word2Vec.load('yelp.word2vec')
     embedding = Embedding_layer(embed_model.wv, embed_model.wv.vector_size)
     del embed_model
-    dataset = YelpDocSet('reviews', 120, embedding)
+
+    # for start_file in range(11, 24):
+    start_file = 0
+    dataset = YelpDocSet('reviews', start_file, 120-start_file, embedding)
+    print('start_file %d'% start_file)
+    print(len(dataset))
     net = HAN(input_size=200, output_size=5,
               word_hidden_size=50, word_num_layers=1, word_context_size=100,
               sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)
-    # net.load_state_dict(torch.load('model.dict'))
+    net.load_state_dict(torch.load('model.dict'))
     train(net, dataset, num_epoch=1, batch_size=64, use_cuda=True)
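
Note: with the new start_file argument the 120 pickled review files can be trained in slices, reloading the checkpoint between slices (the commented-out range(11, 24) loop hints at this workflow). A hedged sketch of such a driver, assuming model.dict already exists from an earlier run; the names mirror this file, and the chunk size is an illustrative choice:

    # Hypothetical slice-by-slice driver around the code above.
    chunk = 10
    for start_file in range(0, 120, chunk):
        dataset = YelpDocSet('reviews', start_file, chunk, embedding)
        net.load_state_dict(torch.load('model.dict'))  # resume previous slice
        train(net, dataset, num_epoch=1, batch_size=64, use_cuda=True)
        torch.save(net.state_dict(), 'model.dict')     # checkpoint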
