|
@@ -40,7 +40,7 @@ test_loader = torch.utils.data.DataLoader(dataset=test_dataset, |
|
|
batch_size=batch_size, |
|
|
batch_size=batch_size, |
|
|
shuffle=False) |
|
|
shuffle=False) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#cnn |
|
|
|
|
|
|
|
|
cnn = CNN_text(embed_num=len(dataset.word2id()), pretrained_embeddings=dataset.word_embeddings()) |
|
|
cnn = CNN_text(embed_num=len(dataset.word2id()), pretrained_embeddings=dataset.word_embeddings()) |
|
|
if cuda: |
|
|
if cuda: |
|
@@ -51,6 +51,8 @@ if cuda: |
|
|
criterion = nn.CrossEntropyLoss() |
|
|
criterion = nn.CrossEntropyLoss() |
|
|
optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate) |
|
|
optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#train and test |
|
|
best_acc = None |
|
|
best_acc = None |
|
|
|
|
|
|
|
|
for epoch in range(num_epochs): |
|
|
for epoch in range(num_epochs): |
|
@@ -59,9 +61,9 @@ for epoch in range(num_epochs): |
|
|
for i, (sents,labels) in enumerate(train_loader): |
|
|
for i, (sents,labels) in enumerate(train_loader): |
|
|
sents = Variable(sents) |
|
|
sents = Variable(sents) |
|
|
labels = Variable(labels) |
|
|
labels = Variable(labels) |
|
|
if cuda: |
|
|
|
|
|
sents = sents.cuda() |
|
|
|
|
|
labels = labels.cuda() |
|
|
|
|
|
|
|
|
if cuda: |
|
|
|
|
|
sents = sents.cuda() |
|
|
|
|
|
labels = labels.cuda() |
|
|
optimizer.zero_grad() |
|
|
optimizer.zero_grad() |
|
|
outputs = cnn(sents) |
|
|
outputs = cnn(sents) |
|
|
loss = criterion(outputs, labels) |
|
|
loss = criterion(outputs, labels) |
|
@@ -78,8 +80,8 @@ for epoch in range(num_epochs): |
|
|
total = 0 |
|
|
total = 0 |
|
|
for sents, labels in test_loader: |
|
|
for sents, labels in test_loader: |
|
|
sents = Variable(sents) |
|
|
sents = Variable(sents) |
|
|
if cuda: |
|
|
|
|
|
sents = sents.cuda() |
|
|
|
|
|
|
|
|
if cuda: |
|
|
|
|
|
sents = sents.cuda() |
|
|
labels = labels.cuda() |
|
|
labels = labels.cuda() |
|
|
outputs = cnn(sents) |
|
|
outputs = cnn(sents) |
|
|
_, predicted = torch.max(outputs.data, 1) |
|
|
_, predicted = torch.max(outputs.data, 1) |
|
@@ -90,8 +92,8 @@ for epoch in range(num_epochs): |
|
|
|
|
|
|
|
|
if best_acc is None or acc > best_acc: |
|
|
if best_acc is None or acc > best_acc: |
|
|
best_acc = acc |
|
|
best_acc = acc |
|
|
if os.path.exists("models") is False: |
|
|
|
|
|
os.makedirs("models") |
|
|
|
|
|
|
|
|
if os.path.exists("models") is False: |
|
|
|
|
|
os.makedirs("models") |
|
|
torch.save(cnn.state_dict(), 'models/cnn.pkl') |
|
|
torch.save(cnn.state_dict(), 'models/cnn.pkl') |
|
|
else: |
|
|
else: |
|
|
learning_rate = learning_rate * 0.8 |
|
|
learning_rate = learning_rate * 0.8 |
|
|