|
-
- import torch
- from torch.autograd import Variable
- import torch.nn as nn
- import torch.nn.functional as F
-
-
class Highway(nn.Module):
    """Highway network layer.

    Computes ``y = t * relu(fc2(x)) + (1 - t) * x`` where the transform gate
    is ``t = sigmoid(fc1(x))``. Input and output share the same last-dimension
    size, so several layers can be stacked.

    Args:
        input_size: size of the last dimension of the input (and output).
    """

    def __init__(self, input_size):
        super(Highway, self).__init__()
        # Transform-gate projection: produces t.
        self.fc1 = nn.Linear(input_size, input_size, bias=True)
        # Non-linear transform branch.
        self.fc2 = nn.Linear(input_size, input_size, bias=True)

    def forward(self, x):
        """Apply the layer to ``x`` of shape ``(..., input_size)``.

        Returns a tensor of the same shape as ``x``.
        """
        # torch.sigmoid replaces the deprecated F.sigmoid.
        t = torch.sigmoid(self.fc1(x))
        # Mix the transformed branch (weight t) with the untouched carry
        # path (weight 1 - t).
        return t * F.relu(self.fc2(x)) + (1 - t) * x
-
-
class charLM(nn.Module):
    """Character-level word encoder (CNN + highway network) feeding an LSTM LM.

    Every word is embedded character by character, passed through a bank of
    convolutions of different widths with max pooling over the character
    axis, batch-normalised, sent through two highway layers and a 2-layer
    LSTM, and finally projected onto the word vocabulary.

    # Input (forward):
        x: LongTensor of character indices, shape
           [num_seq, seq_len, max_word_len+2]
        hidden: initial LSTM state forwarded to nn.LSTM ((h_0, c_0) or None)
    # Output (forward):
        scores: 2D Tensor with shape [num_seq*seq_len, vocab_size]
        hidden: final LSTM state returned by nn.LSTM
    # Arguments:
        char_emb_dim: the size of each character's embedding
        word_emb_dim: the size of each word's embedding (LSTM hidden size)
        vocab_size: num of unique words
        num_char: num of characters
        use_gpu: if truthy, the whole model is moved to the GPU
    """

    def __init__(self, char_emb_dim, word_emb_dim,
                 vocab_size, num_char, use_gpu):
        super(charLM, self).__init__()
        self.char_emb_dim = char_emb_dim
        self.word_emb_dim = word_emb_dim
        self.vocab_size = vocab_size

        # char embedding layer
        self.char_embed = nn.Embedding(num_char, char_emb_dim)

        # list of tuples: (number of filters, filter width)
        self.filter_num_width = [(25, 1), (50, 2), (75, 3), (100, 4), (125, 5), (150, 6)]

        # Convolutions of filters with different widths. Each kernel spans the
        # full embedding height, so every output has height 1.
        # nn.ModuleList (not a plain list) registers the layers as submodules,
        # so their weights show up in parameters()/state_dict() and follow
        # .cuda()/.to()/.train()/.eval(). Note: checkpoints saved by the old
        # plain-list version lack the "convolutions.*" keys.
        self.convolutions = nn.ModuleList([
            nn.Conv2d(
                1,  # in_channel
                out_channel,  # out_channel
                kernel_size=(char_emb_dim, filter_width),  # (height, width)
                bias=True,
            )
            for out_channel, filter_width in self.filter_num_width
        ])

        # Total number of filters = width of the pooled CNN feature vector.
        self.highway_input_dim = sum(num for num, _ in self.filter_num_width)

        self.batch_norm = nn.BatchNorm1d(self.highway_input_dim, affine=False)

        # highway net
        self.highway1 = Highway(self.highway_input_dim)
        self.highway2 = Highway(self.highway_input_dim)

        # LSTM
        self.lstm_num_layers = 2

        self.lstm = nn.LSTM(input_size=self.highway_input_dim,
                            hidden_size=self.word_emb_dim,
                            num_layers=self.lstm_num_layers,
                            bias=True,
                            dropout=0.5,
                            batch_first=True)

        # output layer
        self.dropout = nn.Dropout(p=0.5)
        self.linear = nn.Linear(self.word_emb_dim, self.vocab_size)

        if use_gpu:
            # Every layer is a registered submodule, so one call moves them all.
            self.cuda()

    def forward(self, x, hidden):
        """Compute vocabulary scores for every word position.

        Args:
            x: character indices, shape [num_seq, seq_len, max_word_len+2].
            hidden: initial LSTM state passed unchanged to nn.LSTM: an
                (h_0, c_0) tuple, each [lstm_num_layers, num_seq, word_emb_dim],
                or None for zero state.

        Returns:
            (scores, hidden): scores with shape [num_seq*seq_len, vocab_size]
            and the LSTM's final (h_n, c_n).
        """
        lstm_batch_size = x.size(0)
        lstm_seq_len = x.size(1)

        x = x.contiguous().view(-1, x.size(2))
        # [num_seq*seq_len, max_word_len+2]

        x = self.char_embed(x)
        # [num_seq*seq_len, max_word_len+2, char_emb_dim]

        x = torch.transpose(x.view(x.size(0), 1, x.size(1), -1), 2, 3)
        # [num_seq*seq_len, 1, char_emb_dim, max_word_len+2]

        x = self.conv_layers(x)
        # [num_seq*seq_len, total_num_filters]

        x = self.batch_norm(x)
        # [num_seq*seq_len, total_num_filters]

        x = self.highway1(x)
        x = self.highway2(x)
        # [num_seq*seq_len, total_num_filters]

        x = x.contiguous().view(lstm_batch_size, lstm_seq_len, -1)
        # [num_seq, seq_len, total_num_filters]

        x, hidden = self.lstm(x, hidden)
        # [num_seq, seq_len, word_emb_dim]  (batch_first=True)

        x = self.dropout(x)
        # [num_seq, seq_len, word_emb_dim]

        x = x.contiguous().view(lstm_batch_size * lstm_seq_len, -1)
        # [num_seq*seq_len, word_emb_dim]

        x = self.linear(x)
        # [num_seq*seq_len, vocab_size]
        return x, hidden

    def conv_layers(self, x):
        """Run every convolution, max-pool over width, and concatenate.

        Args:
            x: [batch_size, 1, char_emb_dim, word_len]. word_len must be at
                least the widest filter (6), otherwise the convolution fails.

        Returns:
            [batch_size, total_num_filters]
        """
        batch_size = x.size(0)
        chosen_list = []
        for conv in self.convolutions:
            feature_map = torch.tanh(conv(x))
            # (batch_size, out_channel, 1, word_len-width+1)
            chosen = torch.max(feature_map, 3)[0]
            # Reshape explicitly to (batch_size, out_channel). A bare
            # squeeze() would also drop the batch dim when batch_size == 1
            # and break the concatenation below.
            chosen = chosen.view(batch_size, -1)
            chosen_list.append(chosen)

        # (batch_size, total_num_filters)
        return torch.cat(chosen_list, 1)
|