From f3b1e6d72a82f2d3ffe1722cf99d316fd28aa766 Mon Sep 17 00:00:00 2001
From: choocewhatulike <1901722105@qq.com>
Date: Mon, 12 Mar 2018 22:41:38 +0800
Subject: [PATCH] add gpu

---
 .../code/__pycache__/model.cpython-36.pyc |  Bin 2342 -> 0 bytes
 model_inplement/code/train.py             |   92 ++++++++++--------
 2 files changed, 51 insertions(+), 41 deletions(-)
 delete mode 100644 model_inplement/code/__pycache__/model.cpython-36.pyc

diff --git a/model_inplement/code/__pycache__/model.cpython-36.pyc b/model_inplement/code/__pycache__/model.cpython-36.pyc
deleted file mode 100644
index 167f6ff9d4401d8077c23d66909e6c96b6049a11..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 2342
[base85-encoded payload of the deleted compiled bytecode omitted]
diff --git a/model_inplement/code/train.py b/model_inplement/code/train.py
index ae7ee925..4b22f69a 100644
--- a/model_inplement/code/train.py
+++ b/model_inplement/code/train.py
@@ -1,8 +1,14 @@
-import gensim
-from gensim import models
-
 import os
 import pickle
+
+import matplotlib.pyplot as plt
+import matplotlib.ticker as ticker
+
+import nltk
+import numpy as np
+import torch
+
+from model import *
 
 class SampleIter:
     def __init__(self, dirname):
@@ -32,35 +38,7 @@ def train_word_vec():
+    from gensim import models
     model = models.Word2Vec(sentences=sents, size=200, sg=0, workers=4, min_count=5)
     model.save('yelp.word2vec')
 
-
-'''
-Train process
-'''
-import math
-import os
-import copy
-import pickle
-
-import matplotlib.pyplot as plt
-import matplotlib.ticker as ticker
-import numpy as np
-import json
-import nltk
-from gensim.models import Word2Vec
-import torch
-from torch.utils.data import DataLoader, Dataset
-
-from model import *
-
-net = HAN(input_size=200, output_size=5,
-          word_hidden_size=50, word_num_layers=1, word_context_size=100,
-          sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)
-
-optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
-criterion = nn.NLLLoss()
-num_epoch = 1
-batch_size = 64
-
 class Embedding_layer:
     def __init__(self, wv, vector_size):
         self.wv = wv
@@ -73,10 +51,8 @@ class Embedding_layer:
             v = np.zeros(self.vector_size)
         return v
 
-embed_model = Word2Vec.load('yelp.word2vec')
-embedding = Embedding_layer(embed_model.wv, embed_model.wv.vector_size)
-del embed_model
 
+from torch.utils.data import DataLoader, Dataset
 class YelpDocSet(Dataset):
     def __init__(self, dirname, num_files, embedding):
         self.dirname = dirname
@@ -103,12 +79,28 @@ def collate(iterable):
         x_list.append(x)
     return x_list, torch.LongTensor(y_list)
 
-if __name__ == '__main__':
+def train(net, num_epoch, batch_size, print_size=10, use_cuda=False):
+    from gensim.models import Word2Vec
+    import torch
+    import gensim
+    from gensim import models
+
+    embed_model = Word2Vec.load('yelp.word2vec')
+    embedding = Embedding_layer(embed_model.wv, embed_model.wv.vector_size)
+    del embed_model
+
+    optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
+    criterion = nn.NLLLoss()
+
     dirname = 'reviews'
-    dataloader = DataLoader(YelpDocSet(dirname, 238, embedding), batch_size=batch_size, collate_fn=collate)
+    dataloader = DataLoader(YelpDocSet(dirname, 238, embedding),
+                            batch_size=batch_size,
+                            collate_fn=collate,
+                            num_workers=4)
     running_loss = 0.0
-    print_size = 10
 
+    if use_cuda:
+        net.cuda()
     for epoch in range(num_epoch):
         for i, batch_samples in enumerate(dataloader):
             x, y = batch_samples
@@ -119,11 +111,16 @@ if __name__ == '__main__':
                 sent_vec = []
                 for word in sent:
                     vec = embedding.get_vec(word)
-                    sent_vec.append(torch.Tensor(vec.reshape((1, -1))))
+                    vec = torch.Tensor(vec.reshape((1, -1)))
+                    if use_cuda:
+                        vec = vec.cuda()
+                    sent_vec.append(vec)
                 sent_vec = torch.cat(sent_vec, dim=0)
                 # print(sent_vec.size())
                 doc.append(Variable(sent_vec))
             doc_list.append(doc)
+            if use_cuda:
+                y = y.cuda()
             y = Variable(y)
             predict = net(doc_list)
             loss = criterion(predict, y)
@@ -131,8 +128,21 @@ if __name__ == '__main__':
             loss.backward()
             optimizer.step()
             running_loss += loss.data[0]
-            print(loss.data[0])
             if i % print_size == print_size-1:
                 print(running_loss/print_size)
                 running_loss = 0.0
-
+                torch.save(net.state_dict(), 'model.dict')
+    torch.save(net.state_dict(), 'model.dict')
+
+
+if __name__ == '__main__':
+    '''
+    Train process
+    '''
+
+
+    net = HAN(input_size=200, output_size=5,
+              word_hidden_size=50, word_num_layers=1, word_context_size=100,
+              sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)
+
+    train(net, num_epoch=1, batch_size=64, use_cuda=True)
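
Note on the pattern this patch introduces: the GPU support is one idea applied consistently. Move the model's parameters onto the GPU once (net.cuda()), then move every input tensor (vec.cuda(), y.cuda()) onto the same device before the forward pass. Below is a minimal, self-contained sketch of that pattern in isolation. It is illustrative only: TinyNet and the random batch are hypothetical stand-ins for this repo's HAN model and Yelp document vectors, and the Variable/.cuda() style matches the PyTorch 0.3-era API this March 2018 patch targets. Unlike the hard-coded use_cuda=True above, which raises an error on a CPU-only machine, the sketch guards with torch.cuda.is_available().

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class TinyNet(nn.Module):
    """Hypothetical stand-in for HAN; any nn.Module moves to the GPU the same way."""
    def __init__(self):
        super(TinyNet, self).__init__()
        self.fc = nn.Linear(200, 5)

    def forward(self, x):
        return F.log_softmax(self.fc(x), dim=1)

use_cuda = torch.cuda.is_available()  # guard instead of assuming a GPU exists

net = TinyNet()
if use_cuda:
    net.cuda()                        # move parameters to the GPU once, up front

criterion = nn.NLLLoss()
x = torch.randn(64, 200)              # fake batch standing in for document vectors
y = torch.LongTensor(64).random_(5)   # fake 5-class labels, like the Yelp ratings
if use_cuda:
    x, y = x.cuda(), y.cuda()         # inputs must live on the same device as net

loss = criterion(net(Variable(x)), Variable(y))

Since PyTorch 0.4 the same idea is usually written device-agnostically: device = torch.device('cuda' if torch.cuda.is_available() else 'cpu'), then net.to(device) and x.to(device), with no Variable wrapper.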