- import torch.nn as nn
- import torch
- from fastNLP.core.utils import seq_len_to_mask
- from utils import better_init_rnn
- import numpy as np
-
-
- class WordLSTMCell_yangjie(nn.Module):
-
- """A basic LSTM cell."""
-
- def __init__(self, input_size, hidden_size, use_bias=True,debug=False, left2right=True):
- """
- Most parts are copied from torch.nn.LSTMCell.
- """
-
- super().__init__()
- self.left2right = left2right
- self.debug = debug
- self.input_size = input_size
- self.hidden_size = hidden_size
- self.use_bias = use_bias
- self.weight_ih = nn.Parameter(
- torch.FloatTensor(input_size, 3 * hidden_size))
- self.weight_hh = nn.Parameter(
- torch.FloatTensor(hidden_size, 3 * hidden_size))
- if use_bias:
- self.bias = nn.Parameter(torch.FloatTensor(3 * hidden_size))
- else:
- self.register_parameter('bias', None)
- self.reset_parameters()
-
- def reset_parameters(self):
- """
- Initialize parameters following the way proposed in the paper.
- """
- nn.init.orthogonal_(self.weight_ih.data)
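- # The recurrent weight starts as identity blocks stacked along the gate dimension (one identity per gate).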
- weight_hh_data = torch.eye(self.hidden_size)
- weight_hh_data = weight_hh_data.repeat(1, 3)
- with torch.no_grad():
- self.weight_hh.set_(weight_hh_data)
- # The bias is just set to zero vectors.
- if self.use_bias:
- nn.init.constant_(self.bias.data, val=0)
-
- def forward(self, input_, hx):
- """
- Args:
- input_: A (batch, input_size) tensor containing input
- features.
- hx: A tuple (h_0, c_0), which contains the initial hidden
- and cell state, where the size of both states is
- (batch, hidden_size).
- Returns:
- c_1: Tensor containing the next cell state. The word cell returns only a cell state; its output is consumed by the character-level cell.
- """
-
- h_0, c_0 = hx
-
-
-
- batch_size = h_0.size(0)
- bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size()))
- wh_b = torch.addmm(bias_batch, h_0, self.weight_hh)
- wi = torch.mm(input_, self.weight_ih)
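- # Split the pre-activations into forget gate f, input gate i and candidate g; the word cell has no output gate, since only its cell state is passed on.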
- f, i, g = torch.split(wh_b + wi, split_size_or_sections=self.hidden_size, dim=1)
- c_1 = torch.sigmoid(f)*c_0 + torch.sigmoid(i)*torch.tanh(g)
-
- return c_1
-
- def __repr__(self):
- s = '{name}({input_size}, {hidden_size})'
- return s.format(name=self.__class__.__name__, **self.__dict__)
-
-
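- # Character-level multi-input LSTM cell: fuses the character candidate with the cell states coming from lexicon (word) paths through a normalized gate, as in Lattice LSTM.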
- class MultiInputLSTMCell_V0(nn.Module):
- def __init__(self, char_input_size, hidden_size, use_bias=True,debug=False):
- super().__init__()
- self.char_input_size = char_input_size
- self.hidden_size = hidden_size
- self.use_bias = use_bias
-
- self.weight_ih = nn.Parameter(
- torch.FloatTensor(char_input_size, 3 * hidden_size)
- )
-
- self.weight_hh = nn.Parameter(
- torch.FloatTensor(hidden_size, 3 * hidden_size)
- )
-
- self.alpha_weight_ih = nn.Parameter(
- torch.FloatTensor(char_input_size, hidden_size)
- )
-
- self.alpha_weight_hh = nn.Parameter(
- torch.FloatTensor(hidden_size, hidden_size)
- )
-
- if self.use_bias:
- self.bias = nn.Parameter(torch.FloatTensor(3 * hidden_size))
- self.alpha_bias = nn.Parameter(torch.FloatTensor(hidden_size))
- else:
- self.register_parameter('bias', None)
- self.register_parameter('alpha_bias', None)
-
- self.debug = debug
- self.reset_parameters()
-
- def reset_parameters(self):
- """
- Initialize parameters following the way proposed in the paper.
- """
- nn.init.orthogonal_(self.weight_ih.data)
- nn.init.orthogonal_(self.alpha_weight_ih.data)
-
- weight_hh_data = torch.eye(self.hidden_size)
- weight_hh_data = weight_hh_data.repeat(1, 3)
- with torch.no_grad():
- self.weight_hh.set_(weight_hh_data)
-
- alpha_weight_hh_data = torch.eye(self.hidden_size)
- with torch.no_grad():
- self.alpha_weight_hh.set_(alpha_weight_hh_data)
-
- # The bias is just set to zero vectors.
- if self.use_bias:
- nn.init.constant_(self.bias.data, val=0)
- nn.init.constant_(self.alpha_bias.data, val=0)
-
- def forward(self, inp, skip_c, skip_count, hx):
- '''
-
- :param inp: character input at the current position, batch * char_input_size
- :param skip_c: cell states produced by the word (skip) paths, batch * X * hidden_size
- :param skip_count: number of word paths ending at the current position for each example in the batch, used for masking
- :param hx: tuple (h_0, c_0) of the previous character-level hidden and cell state, each batch * hidden_size
- :return: (h_1, c_1), the next character-level hidden and cell state
- '''
- max_skip_count = torch.max(skip_count).item()
-
-
-
- if True:  # NOTE: the else branch below (a plain LSTM update without word paths) is never taken
- h_0, c_0 = hx
- batch_size = h_0.size(0)
-
- bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size()))
-
- wi = torch.matmul(inp, self.weight_ih)
- wh = torch.matmul(h_0, self.weight_hh)
-
-
-
- i, o, g = torch.split(wh + wi + bias_batch, split_size_or_sections=self.hidden_size, dim=1)
-
- i = torch.sigmoid(i).unsqueeze(1)
- o = torch.sigmoid(o).unsqueeze(1)
- g = torch.tanh(g).unsqueeze(1)
-
-
-
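- # Word-path gating: each word-path cell state in skip_c gets its own gate alpha; together with the character input gate i, these gates are normalized below so the new cell state is a convex combination of the candidate g and the word-path cell states.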
- alpha_wi = torch.matmul(inp, self.alpha_weight_ih)
- alpha_wi.unsqueeze_(1)
-
- # alpha_wi = alpha_wi.expand(1,skip_count,self.hidden_size)
- alpha_wh = torch.matmul(skip_c, self.alpha_weight_hh)
-
- alpha_bias_batch = self.alpha_bias.unsqueeze(0)
-
- alpha = torch.sigmoid(alpha_wi + alpha_wh + alpha_bias_batch)
-
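- # Mask padded word-path slots: entries beyond skip_count receive a large negative score, so exp() drives their weight to zero in the normalization.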
- skip_mask = seq_len_to_mask(skip_count, max_len=skip_c.size()[1]).float()
-
- skip_mask = 1 - skip_mask
-
-
- skip_mask = skip_mask.unsqueeze(-1).expand(*skip_mask.size(), self.hidden_size)
-
- skip_mask = (skip_mask).float()*1e20
-
- alpha = alpha - skip_mask
-
- alpha = torch.exp(torch.cat([i, alpha], dim=1))
-
-
-
- alpha_sum = torch.sum(alpha, dim=1, keepdim=True)
-
- alpha = torch.div(alpha, alpha_sum)
-
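- # New cell state: gate-weighted sum of the character candidate g and the word-path cell states.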
- merge_i_c = torch.cat([g, skip_c], dim=1)
-
- c_1 = merge_i_c * alpha
-
- c_1 = c_1.sum(1, keepdim=True)
- # h_1 = o * c_1
- h_1 = o * torch.tanh(c_1)
-
- return h_1.squeeze(1), c_1.squeeze(1)
-
- else:
-
- h_0, c_0 = hx
- batch_size = h_0.size(0)
-
- bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size()))
-
- wi = torch.matmul(inp, self.weight_ih)
- wh = torch.matmul(h_0, self.weight_hh)
-
- i, o, g = torch.split(wh + wi + bias_batch, split_size_or_sections=self.hidden_size, dim=1)
-
- i = torch.sigmoid(i).unsqueeze(1)
- o = torch.sigmoid(o).unsqueeze(1)
- g = torch.tanh(g).unsqueeze(1)
-
- c_1 = g
- h_1 = o * c_1
-
- return h_1,c_1
-
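- # Same as MultiInputLSTMCell_V0, except that positions with no matched lexicon word fall back to a plain coupled-gate LSTM update.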
- class MultiInputLSTMCell_V1(nn.Module):
- def __init__(self, char_input_size, hidden_size, use_bias=True,debug=False):
- super().__init__()
- self.char_input_size = char_input_size
- self.hidden_size = hidden_size
- self.use_bias = use_bias
-
- self.weight_ih = nn.Parameter(
- torch.FloatTensor(char_input_size, 3 * hidden_size)
- )
-
- self.weight_hh = nn.Parameter(
- torch.FloatTensor(hidden_size, 3 * hidden_size)
- )
-
- self.alpha_weight_ih = nn.Parameter(
- torch.FloatTensor(char_input_size, hidden_size)
- )
-
- self.alpha_weight_hh = nn.Parameter(
- torch.FloatTensor(hidden_size, hidden_size)
- )
-
- if self.use_bias:
- self.bias = nn.Parameter(torch.FloatTensor(3 * hidden_size))
- self.alpha_bias = nn.Parameter(torch.FloatTensor(hidden_size))
- else:
- self.register_parameter('bias', None)
- self.register_parameter('alpha_bias', None)
-
- self.debug = debug
- self.reset_parameters()
-
- def reset_parameters(self):
- """
- Initialize parameters following the way proposed in the paper.
- """
- nn.init.orthogonal_(self.weight_ih.data)
- nn.init.orthogonal_(self.alpha_weight_ih.data)
-
- weight_hh_data = torch.eye(self.hidden_size)
- weight_hh_data = weight_hh_data.repeat(1, 3)
- with torch.no_grad():
- self.weight_hh.set_(weight_hh_data)
-
- alpha_weight_hh_data = torch.eye(self.hidden_size)
- with torch.no_grad():
- self.alpha_weight_hh.set_(alpha_weight_hh_data)
-
- # The bias is just set to zero vectors.
- if self.use_bias:
- nn.init.constant_(self.bias.data, val=0)
- nn.init.constant_(self.alpha_bias.data, val=0)
-
- def forward(self, inp, skip_c, skip_count, hx):
- '''
-
- :param inp: character input at the current position, batch * char_input_size
- :param skip_c: cell states produced by the word (skip) paths, batch * X * hidden_size
- :param skip_count: number of word paths ending at the current position for each example in the batch, used for masking
- :param hx: tuple (h_0, c_0) of the previous character-level hidden and cell state, each batch * hidden_size
- :return: (h_1, c_1), the next character-level hidden and cell state
- '''
- max_skip_count = torch.max(skip_count).item()
-
-
-
- if True:
- h_0, c_0 = hx
- batch_size = h_0.size(0)
-
- bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size()))
-
- wi = torch.matmul(inp, self.weight_ih)
- wh = torch.matmul(h_0, self.weight_hh)
-
-
- i, o, g = torch.split(wh + wi + bias_batch, split_size_or_sections=self.hidden_size, dim=1)
-
- i = torch.sigmoid(i).unsqueeze(1)
- o = torch.sigmoid(o).unsqueeze(1)
- g = torch.tanh(g).unsqueeze(1)
-
-
-
- # Plain coupled-gate LSTM update (f = 1 - i), used as a fallback when no lexicon word ends at this position (see count_select below).
-
- f = 1 - i
- c_1_basic = f*c_0.unsqueeze(1) + i*g
- c_1_basic = c_1_basic.squeeze(1)
-
-
-
-
-
- alpha_wi = torch.matmul(inp, self.alpha_weight_ih)
- alpha_wi.unsqueeze_(1)
-
-
- alpha_wh = torch.matmul(skip_c, self.alpha_weight_hh)
-
- alpha_bias_batch = self.alpha_bias.unsqueeze(0)
-
- alpha = torch.sigmoid(alpha_wi + alpha_wh + alpha_bias_batch)
-
- skip_mask = seq_len_to_mask(skip_count,max_len=skip_c.size()[1]).float()
-
- skip_mask = 1 - skip_mask
-
-
- skip_mask = skip_mask.unsqueeze(-1).expand(*skip_mask.size(), self.hidden_size)
-
- skip_mask = (skip_mask).float()*1e20
-
- alpha = alpha - skip_mask
-
- alpha = torch.exp(torch.cat([i, alpha], dim=1))
-
-
-
- alpha_sum = torch.sum(alpha, dim=1, keepdim=True)
-
- alpha = torch.div(alpha, alpha_sum)
-
- merge_i_c = torch.cat([g, skip_c], dim=1)
-
- c_1 = merge_i_c * alpha
-
- c_1 = c_1.sum(1, keepdim=True)
- # h_1 = o * c_1
- c_1 = c_1.squeeze(1)
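- # Positions where at least one lexicon word ends use the lattice update; the rest fall back to the plain coupled-gate LSTM state computed above.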
- count_select = (skip_count != 0).float().unsqueeze(-1)
-
-
-
-
- c_1 = c_1*count_select + c_1_basic*(1-count_select)
-
-
- o = o.squeeze(1)
- h_1 = o * torch.tanh(c_1)
-
- return h_1, c_1
-
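- # Single-direction lattice LSTM layer; the direction is selected by the left2right flag.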
- class LatticeLSTMLayer_sup_back_V0(nn.Module):
- def __init__(self, char_input_size, word_input_size, hidden_size, left2right,
- bias=True,device=None,debug=False,skip_before_head=False):
- super().__init__()
-
- self.skip_before_head = skip_before_head
-
- self.hidden_size = hidden_size
-
- self.char_cell = MultiInputLSTMCell_V0(char_input_size, hidden_size, bias,debug)
-
- self.word_cell = WordLSTMCell_yangjie(word_input_size, hidden_size, bias, debug=debug)
-
- self.word_input_size = word_input_size
- self.left2right = left2right
- self.bias = bias
- self.device = device
- self.debug = debug
-
- def forward(self, inp, seq_len, skip_sources, skip_words, skip_count, init_state=None):
- '''
-
- :param inp: character embeddings, batch * seq_len * char_input_size
- :param seq_len: lengths of the character sequences, batch
- :param skip_sources: start positions of the skip (word) edges, batch * seq_len * X
- :param skip_words: embeddings of the words on the skip edges, batch * seq_len * X * word_input_size
- :param skip_count: number of lexicon words matched at each position of each example, batch * seq_len
- :param init_state: initial (h, c) state of the rnn (currently unused)
- :return: (h, c), each batch * max_seq_len * hidden_size, the hidden and cell states at every character position
- '''
-
-
- if self.left2right:
-
- max_seq_len = max(seq_len)
- batch_size = inp.size(0)
- c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)
- h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)
-
- for i in range(max_seq_len):
- max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1)
- h_0, c_0 = h_[:, i, :], c_[:, i, :]
-
- skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous()
-
- skip_word_flat = skip_word_flat.view(batch_size*max_lexicon_count,self.word_input_size)
- skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count)
-
-
- index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size,max_lexicon_count)
- index_1 = skip_source_flat
-
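- # h_ and c_ store the zero initial state at index 0, so the state after character j sits at index j+1; skip_before_head selects whether a word path starts from the state after (False) or before (True) its first character.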
- if not self.skip_before_head:
- c_x = c_[[index_0, index_1+1]]
- h_x = h_[[index_0, index_1+1]]
- else:
- c_x = c_[[index_0,index_1]]
- h_x = h_[[index_0,index_1]]
-
- c_x_flat = c_x.view(batch_size*max_lexicon_count,self.hidden_size)
- h_x_flat = h_x.view(batch_size*max_lexicon_count,self.hidden_size)
-
-
-
-
- c_1_flat = self.word_cell(skip_word_flat,(h_x_flat,c_x_flat))
-
- c_1_skip = c_1_flat.view(batch_size,max_lexicon_count,self.hidden_size)
-
- h_1,c_1 = self.char_cell(inp[:,i,:],c_1_skip,skip_count[:,i],(h_0,c_0))
-
-
- h_ = torch.cat([h_,h_1.unsqueeze(1)],dim=1)
- c_ = torch.cat([c_, c_1.unsqueeze(1)], dim=1)
-
- return h_[:,1:],c_[:,1:]
- else:
- mask_for_seq_len = seq_len_to_mask(seq_len)
-
- max_seq_len = max(seq_len)
- batch_size = inp.size(0)
- c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)
- h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)
-
- for i in reversed(range(max_seq_len)):
- max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1)
-
-
-
- h_0, c_0 = h_[:, 0, :], c_[:, 0, :]
-
- skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous()
-
- skip_word_flat = skip_word_flat.view(batch_size*max_lexicon_count,self.word_input_size)
- skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count)
-
-
- index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size,max_lexicon_count)
- index_1 = skip_source_flat-i
-
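- # In the right-to-left pass states are prepended, so index 0 holds the most recent state; subtracting i maps the absolute source position into this reversed buffer.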
- if not self.skip_before_head:
- c_x = c_[[index_0, index_1-1]]
- h_x = h_[[index_0, index_1-1]]
- else:
- c_x = c_[[index_0,index_1]]
- h_x = h_[[index_0,index_1]]
-
- c_x_flat = c_x.view(batch_size*max_lexicon_count,self.hidden_size)
- h_x_flat = h_x.view(batch_size*max_lexicon_count,self.hidden_size)
-
-
-
-
- c_1_flat = self.word_cell(skip_word_flat,(h_x_flat,c_x_flat))
-
- c_1_skip = c_1_flat.view(batch_size,max_lexicon_count,self.hidden_size)
-
- h_1,c_1 = self.char_cell(inp[:,i,:],c_1_skip,skip_count[:,i],(h_0,c_0))
-
-
- h_1_mask = h_1.masked_fill(~mask_for_seq_len[:, i].unsqueeze(-1), 0)
- c_1_mask = c_1.masked_fill(~mask_for_seq_len[:, i].unsqueeze(-1), 0)
-
-
- h_ = torch.cat([h_1_mask.unsqueeze(1),h_],dim=1)
- c_ = torch.cat([c_1_mask.unsqueeze(1),c_], dim=1)
-
- return h_[:,:-1],c_[:,:-1]
-
- class LatticeLSTMLayer_sup_back_V1(nn.Module):
- # V1 differs from V0 in that, when no lexicon word matches at the current position, V1 uses the standard
- # LSTM update; the standard LSTM update differs from Yang Jie's lattice LSTM implementation when the lexicon count is 0.
- def __init__(self, char_input_size, word_input_size, hidden_size, left2right,
- bias=True,device=None,debug=False,skip_before_head=False):
- super().__init__()
-
- self.debug = debug
-
- self.skip_before_head = skip_before_head
-
- self.hidden_size = hidden_size
-
- self.char_cell = MultiInputLSTMCell_V1(char_input_size, hidden_size, bias,debug)
-
- self.word_cell = WordLSTMCell_yangjie(word_input_size,hidden_size,bias,debug=self.debug)
-
- self.word_input_size = word_input_size
- self.left2right = left2right
- self.bias = bias
- self.device = device
-
- def forward(self, inp, seq_len, skip_sources, skip_words, skip_count, init_state=None):
- '''
-
- :param inp: character embeddings, batch * seq_len * char_input_size
- :param seq_len: lengths of the character sequences, batch
- :param skip_sources: start positions of the skip (word) edges, batch * seq_len * X
- :param skip_words: embeddings of the words on the skip edges, batch * seq_len * X * word_input_size
- :param skip_count: batch * seq_len,
- skip_count[i, j] is the number of lexicon words in example i that end at position j
- :param init_state: initial (h, c) state of the rnn (currently unused)
- :return: (h, c), each batch * max_seq_len * hidden_size, the hidden and cell states at every character position
- '''
-
-
- if self.left2right:
-
- max_seq_len = max(seq_len)
- batch_size = inp.size(0)
- c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)
- h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)
-
- for i in range(max_seq_len):
- max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1)
- h_0, c_0 = h_[:, i, :], c_[:, i, :]
-
- # Reshape the B * lexicon_count * embedding_size tensor into 2D so the word cell can process it.
- # Reshaping into 2D is also needed to match PyTorch's [] advanced indexing below.
-
- skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous()
-
- skip_word_flat = skip_word_flat.view(batch_size*max_lexicon_count,self.word_input_size)
- skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count)
-
-
- index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size,max_lexicon_count)
- index_1 = skip_source_flat
-
-
- if not self.skip_before_head:
- c_x = c_[[index_0, index_1+1]]
- h_x = h_[[index_0, index_1+1]]
- else:
- c_x = c_[[index_0,index_1]]
- h_x = h_[[index_0,index_1]]
-
- c_x_flat = c_x.view(batch_size*max_lexicon_count,self.hidden_size)
- h_x_flat = h_x.view(batch_size*max_lexicon_count,self.hidden_size)
-
-
-
- c_1_flat = self.word_cell(skip_word_flat,(h_x_flat,c_x_flat))
-
- c_1_skip = c_1_flat.view(batch_size,max_lexicon_count,self.hidden_size)
-
- h_1,c_1 = self.char_cell(inp[:,i,:],c_1_skip,skip_count[:,i],(h_0,c_0))
-
-
- h_ = torch.cat([h_,h_1.unsqueeze(1)],dim=1)
- c_ = torch.cat([c_, c_1.unsqueeze(1)], dim=1)
-
- return h_[:,1:],c_[:,1:]
- else:
- mask_for_seq_len = seq_len_to_mask(seq_len)
-
- max_seq_len = max(seq_len)
- batch_size = inp.size(0)
- c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)
- h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)
-
- for i in reversed(range(max_seq_len)):
- max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1)
-
-
- h_0, c_0 = h_[:, 0, :], c_[:, 0, :]
-
- skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous()
-
- skip_word_flat = skip_word_flat.view(batch_size*max_lexicon_count,self.word_input_size)
- skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count)
-
-
- index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size,max_lexicon_count)
- index_1 = skip_source_flat-i
-
- if not self.skip_before_head:
- c_x = c_[[index_0, index_1-1]]
- h_x = h_[[index_0, index_1-1]]
- else:
- c_x = c_[[index_0,index_1]]
- h_x = h_[[index_0,index_1]]
-
- c_x_flat = c_x.view(batch_size*max_lexicon_count,self.hidden_size)
- h_x_flat = h_x.view(batch_size*max_lexicon_count,self.hidden_size)
-
-
-
-
- c_1_flat = self.word_cell(skip_word_flat,(h_x_flat,c_x_flat))
-
-
-
- c_1_skip = c_1_flat.view(batch_size,max_lexicon_count,self.hidden_size)
-
- h_1,c_1 = self.char_cell(inp[:,i,:],c_1_skip,skip_count[:,i],(h_0,c_0))
-
-
- h_1_mask = h_1.masked_fill(~ mask_for_seq_len[:,i].unsqueeze(-1),0)
- c_1_mask = c_1.masked_fill(~ mask_for_seq_len[:, i].unsqueeze(-1), 0)
-
-
- h_ = torch.cat([h_1_mask.unsqueeze(1),h_],dim=1)
- c_ = torch.cat([c_1_mask.unsqueeze(1),c_], dim=1)
-
-
-
- return h_[:,:-1],c_[:,:-1]
-
-
-
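- 
- # Minimal illustrative smoke test for the cells above. The shapes and the "two lexicon words match for the
- # first example, none for the second" scenario are assumptions for demonstration, not values from the
- # original data pipeline.
- if __name__ == '__main__':
-     torch.manual_seed(0)
-     batch, char_dim, word_dim, hidden, max_words = 2, 50, 50, 100, 3
-     word_cell = WordLSTMCell_yangjie(word_dim, hidden)
-     char_cell = MultiInputLSTMCell_V1(char_dim, hidden)
- 
-     # previous character-level state
-     h_0 = torch.zeros(batch, hidden)
-     c_0 = torch.zeros(batch, hidden)
- 
-     # word-path cell states produced by the word cell
-     skip_count = torch.tensor([2, 0])
-     word_emb = torch.randn(batch * max_words, word_dim)
-     zeros = torch.zeros(batch * max_words, hidden)
-     c_words = word_cell(word_emb, (zeros, zeros)).view(batch, max_words, hidden)
- 
-     # one step of the character-level multi-input cell
-     char_emb = torch.randn(batch, char_dim)
-     h_1, c_1 = char_cell(char_emb, c_words, skip_count, (h_0, c_0))
-     print(h_1.size(), c_1.size())  # torch.Size([2, 100]) torch.Size([2, 100])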