@@ -18,8 +18,10 @@ class SampleIter:
     def __iter__(self):
         for f in os.listdir(self.dirname):
-            for y, x in pickle.load(open(os.path.join(self.dirname, f), 'rb')):
-                yield x, y
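+            # open via a context manager so each pickle file is closed after reading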
+            with open(os.path.join(self.dirname, f), 'rb') as fp:
+                for y, x in pickle.load(fp):
+                    yield x, y

 class SentIter:
     def __init__(self, dirname, count, vocab=None):
@@ -29,17 +31,19 @@ class SentIter:
     def __iter__(self):
         for f in os.listdir(self.dirname)[:self.count]:
-            for y, x in pickle.load(open(os.path.join(self.dirname, f), 'rb')):
-                for sent in x:
-                    if self.vocab is not None:
-                        _sent = []
-                        for w in sent:
-                            if w in self.vocab:
-                                _sent.append(w)
-                            else:
-                                _sent.append(UNK_token)
-                        sent = _sent
-                    yield sent
+            with open(os.path.join(self.dirname, f), 'rb') as fp:
+                for y, x in pickle.load(fp):
+                    for sent in x:
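+                        # keep in-vocabulary words; map everything else to UNK_token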
+                        if self.vocab is not None:
+                            _sent = []
+                            for w in sent:
+                                if w in self.vocab:
+                                    _sent.append(w)
+                                else:
+                                    _sent.append(UNK_token)
+                            sent = _sent
+                        yield sent

 def train_word_vec():
     # load data
@@ -69,31 +73,40 @@ class Embedding_layer:
 from torch.utils.data import DataLoader, Dataset

 class YelpDocSet(Dataset):
-    def __init__(self, dirname, num_files, embedding):
+    def __init__(self, dirname, start_file, num_files, embedding):
         self.dirname = dirname
         self.num_files = num_files
-        self._len = num_files*5000
-        self._files = os.listdir(dirname)[:num_files]
+        self._files = os.listdir(dirname)[start_file:start_file + num_files]
         self.embedding = embedding
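+        # five cache slots, each holding (file_id, samples) for one loaded pickle file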
+        self._cache = [(-1, None) for i in range(5)]

     def __len__(self):
-        return self._len
+        return len(self._files)*5000

     def __getitem__(self, n):
         file_id = n // 5000
-        sample_list = pickle.load(open(
-            os.path.join(self.dirname, self._files[file_id]), 'rb'))
-        y, x = sample_list[n % 5000]
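+        # slot file_id % 5 acts as the cache key; reload from disk only on a miss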
+        idx = file_id % 5
+        if self._cache[idx][0] != file_id:
+            print('load {} to {}'.format(file_id, idx))
+            with open(os.path.join(self.dirname, self._files[file_id]), 'rb') as f:
+                self._cache[idx] = (file_id, pickle.load(f))
+        y, x = self._cache[idx][1][n % 5000]
         doc = []
         for sent in x:
+            if len(sent) == 0:
+                continue
             sent_vec = []
             for word in sent:
                 vec = self.embedding.get_vec(word)
-                vec = torch.Tensor(vec.reshape((1, -1)))
-                sent_vec.append(vec)
-            sent_vec = torch.cat(sent_vec, dim=0)
+                sent_vec.append(vec.tolist())
+            sent_vec = torch.Tensor(sent_vec)
             # print(sent_vec.size())
             doc.append(sent_vec)
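+        # a document with no non-empty sentences still needs one sentence tensor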
+        if len(doc) == 0:
+            doc = [torch.zeros(1, 200)]
         return doc, y-1

 def collate(iterable):
@@ -105,7 +118,7 @@ def collate(iterable):
     return x_list, torch.LongTensor(y_list)

 def train(net, dataset, num_epoch, batch_size, print_size=10, use_cuda=False):
-    optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
+    optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
     criterion = nn.NLLLoss()
     dataloader = DataLoader(dataset,
@@ -116,6 +129,7 @@ def train(net, dataset, num_epoch, batch_size, print_size=10, use_cuda=False):
     if use_cuda:
         net.cuda()
+    print('start training')
     for epoch in range(num_epoch):
         for i, batch_samples in enumerate(dataloader):
             x, y = batch_samples
@@ -157,11 +171,15 @@ if __name__ == '__main__':
     embed_model = Word2Vec.load('yelp.word2vec')
     embedding = Embedding_layer(embed_model.wv, embed_model.wv.vector_size)
     del embed_model
-    dataset = YelpDocSet('reviews', 120, embedding)
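+    # start_file selects the first review file, so a run can start partway in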
+    # for start_file in range(11, 24):
+    start_file = 0
+    dataset = YelpDocSet('reviews', start_file, 120-start_file, embedding)
+    print('start_file %d' % start_file)
+    print(len(dataset))
     net = HAN(input_size=200, output_size=5,
               word_hidden_size=50, word_num_layers=1, word_context_size=100,
               sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)
-    # net.load_state_dict(torch.load('model.dict'))
+    net.load_state_dict(torch.load('model.dict'))
     train(net, dataset, num_epoch=1, batch_size=64, use_cuda=True)