|
|
@@ -22,19 +22,16 @@ class HAN(nn.Module): |
|
|
|
self.output_layer = nn.Linear(2* sent_hidden_size, output_size) |
|
|
|
self.softmax = nn.Softmax() |
|
|
|
|
|
|
|
def forward(self, x, level='w'): |
|
|
|
def forward(self, doc): |
|
|
|
# input is a sequence of vector |
|
|
|
# if level == w, a seq of words (a sent); level == s, a seq of sents (a doc) |
|
|
|
if level == 's': |
|
|
|
v = self.sent_layer(x) |
|
|
|
output = self.softmax(self.output_layer(v)) |
|
|
|
return output |
|
|
|
elif level == 'w': |
|
|
|
s = self.word_layer(x) |
|
|
|
return s |
|
|
|
else: |
|
|
|
print('unknow level in Parameter!') |
|
|
|
|
|
|
|
s_list = [] |
|
|
|
for sent in doc: |
|
|
|
s_list.append(self.word_layer(sent)) |
|
|
|
s_vec = torch.cat(s_list, dim=1).t() |
|
|
|
doc_vec = self.sent_layer(s_vec) |
|
|
|
output = self.softmax(self.output_layer(doc_vec)) |
|
|
|
return output |
|
|
|
|
|
|
|
class AttentionNet(nn.Module): |
|
|
|
def __init__(self, input_size, gru_hidden_size, gru_num_layers, context_vec_size): |
|
|
@@ -60,11 +57,53 @@ class AttentionNet(nn.Module): |
|
|
|
self.context_vec.data.uniform_(-0.1, 0.1) |
|
|
|
|
|
|
|
def forward(self, inputs): |
|
|
|
# inputs's dim seq_len*word_dim |
|
|
|
# inputs's dim (seq_len, word_dim) |
|
|
|
inputs = torch.unsqueeze(inputs, 1) |
|
|
|
h_t, hidden = self.gru(inputs) |
|
|
|
h_t = torch.squeeze(h_t, 1) |
|
|
|
u = self.tanh(self.fc(h_t)) |
|
|
|
alpha = self.softmax(torch.mm(u, self.context_vec)) |
|
|
|
output = torch.mm(h_t.t(), alpha) |
|
|
|
# output's dim (2*hidden_size, 1) |
|
|
|
return output |
|
|
|
|
|
|
|
|
|
|
|
''' |
|
|
|
Train process |
|
|
|
''' |
|
|
|
import math |
|
|
|
import os |
|
|
|
import copy |
|
|
|
import pickle |
|
|
|
|
|
|
|
import matplotlib.pyplot as plt |
|
|
|
import matplotlib.ticker as ticker |
|
|
|
import numpy as np |
|
|
|
import json |
|
|
|
import nltk |
|
|
|
|
|
|
|
optimizer = torch.optim.SGD(lr=0.01) |
|
|
|
criterion = nn.NLLLoss() |
|
|
|
epoch = 1 |
|
|
|
batch_size = 10 |
|
|
|
|
|
|
|
net = HAN(input_size=100, output_size=5, |
|
|
|
word_hidden_size=50, word_num_layers=1, word_context_size=100, |
|
|
|
sent_hidden_size=50, sent_num_layers=1, sent_context_size=100) |
|
|
|
|
|
|
|
def dataloader(filename): |
|
|
|
samples = pickle.load(open(filename, 'rb')) |
|
|
|
return samples |
|
|
|
|
|
|
|
def gen_doc(text): |
|
|
|
pass |
|
|
|
|
|
|
|
class SampleDoc: |
|
|
|
def __init__(self, doc, label): |
|
|
|
self.doc = doc |
|
|
|
self.label = label |
|
|
|
|
|
|
|
def __iter__(self): |
|
|
|
for sent in self.doc: |
|
|
|
for word in sent: |
|
|
|
|