# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn.functional as F
from torch import nn

from utils import get_length, INF

class Mask(nn.Module):
    def forward(self, seq, mask):
        # seq: (N, C, L)
        # mask: (N, L), 1 for valid positions and 0 for padding
        seq_mask = torch.unsqueeze(mask, 2)
        seq_mask = torch.transpose(seq_mask.repeat(1, 1, seq.size()[1]), 1, 2)
        return seq.where(torch.eq(seq_mask, 1), torch.zeros_like(seq))

    def __str__(self):
        return 'Mask'
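

# Illustrative usage sketch (not part of the original module; the concrete
# sizes are assumptions): Mask zeroes every channel at padded time steps,
# following the (N, C, L) / (N, L) layout documented in Mask.forward.
def _mask_example():
    seq = torch.randn(2, 4, 5)                # (N=2, C=4, L=5)
    mask = torch.tensor([[1, 1, 1, 0, 0],
                         [1, 1, 1, 1, 1]])    # (N, L), 1 = valid token
    out = Mask()(seq, mask)
    assert out.shape == seq.shape
    assert torch.all(out[0, :, 3:] == 0)      # padded steps are zeroed
    return out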


class BatchNorm(nn.Module):
    def __init__(self, num_features, pre_mask, post_mask, eps=1e-5, decay=0.9, affine=True):
        super(BatchNorm, self).__init__()
        self.mask_opt = Mask()
        self.pre_mask = pre_mask
        self.post_mask = post_mask
        self.bn = nn.BatchNorm1d(num_features, eps=eps, momentum=1.0 - decay, affine=affine)

    def forward(self, seq, mask):
        if self.pre_mask:
            seq = self.mask_opt(seq, mask)
        seq = self.bn(seq)
        if self.post_mask:
            seq = self.mask_opt(seq, mask)
        return seq

    def __str__(self):
        return 'BatchNorm'
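

# Illustrative sketch (concrete sizes are assumptions): BatchNorm wraps
# nn.BatchNorm1d over (N, C, L) inputs; `decay` acts as a running-average decay,
# so PyTorch's momentum is 1 - decay. pre_mask/post_mask control whether padded
# positions are zeroed before and/or after normalization.
def _batch_norm_example():
    bn = BatchNorm(num_features=4, pre_mask=True, post_mask=True, decay=0.9)
    seq = torch.randn(2, 4, 5)                # (N, C, L)
    mask = torch.tensor([[1, 1, 1, 0, 0],
                         [1, 1, 1, 1, 1]])
    out = bn(seq, mask)                       # (N, C, L), padding re-zeroed
    assert out.shape == (2, 4, 5)
    return out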


class ConvBN(nn.Module):
    def __init__(self, kernal_size, in_channels, out_channels, cnn_keep_prob,
                 pre_mask, post_mask, with_bn=True, with_relu=True):
        super(ConvBN, self).__init__()
        self.mask_opt = Mask()
        self.pre_mask = pre_mask
        self.post_mask = post_mask
        self.with_bn = with_bn
        self.with_relu = with_relu
        self.kernal_size = kernal_size
        self.conv = nn.Conv1d(in_channels, out_channels, kernal_size, 1, bias=True,
                              padding=(kernal_size - 1) // 2)
        self.dropout = nn.Dropout(p=(1 - cnn_keep_prob))

        if with_bn:
            self.bn = BatchNorm(out_channels, not post_mask, True)

        if with_relu:
            self.relu = nn.ReLU()

    def forward(self, seq, mask):
        if self.pre_mask:
            seq = self.mask_opt(seq, mask)
        seq = self.conv(seq)
        if self.post_mask:
            seq = self.mask_opt(seq, mask)
        if self.with_bn:
            seq = self.bn(seq, mask)
        if self.with_relu:
            seq = self.relu(seq)
        seq = self.dropout(seq)
        return seq

    def __str__(self):
        return 'ConvBN_{}'.format(self.kernal_size)
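

# Illustrative sketch (example values are assumptions): with stride 1 and
# padding=(kernal_size - 1) // 2, an odd kernal_size keeps the sequence length
# unchanged, so the op can be stacked without reshaping the mask.
def _conv_bn_example():
    op = ConvBN(kernal_size=3, in_channels=4, out_channels=8, cnn_keep_prob=0.9,
                pre_mask=True, post_mask=False)
    seq = torch.randn(2, 4, 5)                # (N, C_in, L)
    mask = torch.tensor([[1, 1, 1, 0, 0],
                         [1, 1, 1, 1, 1]])
    out = op(seq, mask)                       # (N, C_out, L)
    assert out.shape == (2, 8, 5)
    return out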


class AvgPool(nn.Module):
    def __init__(self, kernal_size, pre_mask, post_mask):
        super(AvgPool, self).__init__()
        self.avg_pool = nn.AvgPool1d(kernal_size, 1, padding=(kernal_size - 1) // 2)
        self.pre_mask = pre_mask
        self.post_mask = post_mask
        self.mask_opt = Mask()
        self.kernal_size = kernal_size

    def forward(self, seq, mask):
        if self.pre_mask:
            seq = self.mask_opt(seq, mask)
        seq = self.avg_pool(seq)
        if self.post_mask:
            seq = self.mask_opt(seq, mask)
        return seq

    def __str__(self):
        return 'AvgPool{}'.format(self.kernal_size)


class MaxPool(nn.Module):
    def __init__(self, kernal_size, pre_mask, post_mask):
        super(MaxPool, self).__init__()
        self.max_pool = nn.MaxPool1d(kernal_size, 1, padding=(kernal_size - 1) // 2)
        self.pre_mask = pre_mask
        self.post_mask = post_mask
        self.mask_opt = Mask()
        self.kernel_size = kernal_size

    def forward(self, seq, mask):
        if self.pre_mask:
            seq = self.mask_opt(seq, mask)
        seq = self.max_pool(seq)
        if self.post_mask:
            seq = self.mask_opt(seq, mask)
        return seq

    def __str__(self):
        return 'MaxPool{}'.format(self.kernel_size)
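

# Illustrative sketch (example values are assumptions): both pooling ops use
# stride 1 with 'same'-style padding for an odd kernal_size, so the length is
# preserved. With pre_mask=True, padded positions contribute zeros to windows
# that overlap the padding instead of arbitrary activations.
def _pooling_example():
    seq = torch.randn(2, 4, 5)                # (N, C, L)
    mask = torch.tensor([[1, 1, 1, 0, 0],
                         [1, 1, 1, 1, 1]])
    avg_out = AvgPool(3, pre_mask=True, post_mask=True)(seq, mask)
    max_out = MaxPool(3, pre_mask=True, post_mask=True)(seq, mask)
    assert avg_out.shape == max_out.shape == (2, 4, 5)
    return avg_out, max_out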


class Attention(nn.Module):
    def __init__(self, num_units, num_heads, keep_prob, is_mask):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.keep_prob = keep_prob

        self.linear_q = nn.Linear(num_units, num_units)
        self.linear_k = nn.Linear(num_units, num_units)
        self.linear_v = nn.Linear(num_units, num_units)

        self.bn = BatchNorm(num_units, True, is_mask)
        self.dropout = nn.Dropout(p=1 - self.keep_prob)

    def forward(self, seq, mask):
        in_c = seq.size()[1]
        seq = torch.transpose(seq, 1, 2)  # (N, L, C)
        queries = seq
        keys = seq
        num_heads = self.num_heads

        # T_q = T_k = L
        Q = F.relu(self.linear_q(seq))  # (N, T_q, C)
        K = F.relu(self.linear_k(seq))  # (N, T_k, C)
        V = F.relu(self.linear_v(seq))  # (N, T_k, C)

        # Split the channels into heads and stack the heads along the batch dim
        Q_ = torch.cat(torch.split(Q, in_c // num_heads, dim=2), dim=0)  # (h*N, T_q, C/h)
        K_ = torch.cat(torch.split(K, in_c // num_heads, dim=2), dim=0)  # (h*N, T_k, C/h)
        V_ = torch.cat(torch.split(V, in_c // num_heads, dim=2), dim=0)  # (h*N, T_k, C/h)

        # Scaled dot-product scores
        outputs = torch.matmul(Q_, K_.transpose(1, 2))  # (h*N, T_q, T_k)
        outputs = outputs / (K_.size()[-1] ** 0.5)

        # Key masking: padded keys get a large negative score (-INF) so their
        # softmax weight is effectively zero
        key_masks = mask.repeat(num_heads, 1)  # (h*N, T_k)
        key_masks = torch.unsqueeze(key_masks, 1)  # (h*N, 1, T_k)
        key_masks = key_masks.repeat(1, queries.size()[1], 1)  # (h*N, T_q, T_k)

        paddings = torch.ones_like(outputs) * (-INF)
        outputs = torch.where(torch.eq(key_masks, 0), paddings, outputs)

        # Query masking: rows belonging to padded queries are zeroed out
        query_masks = mask.repeat(num_heads, 1)  # (h*N, T_q)
        query_masks = torch.unsqueeze(query_masks, -1)  # (h*N, T_q, 1)
        query_masks = query_masks.repeat(1, 1, keys.size()[1]).float()  # (h*N, T_q, T_k)

        att_scores = F.softmax(outputs, dim=-1) * query_masks  # (h*N, T_q, T_k)
        att_scores = self.dropout(att_scores)

        # Weighted sum of the values
        x_outputs = torch.matmul(att_scores, V_)  # (h*N, T_q, C/h)
        # Restore shape: concatenate the heads back along the channel dim
        x_outputs = torch.cat(
            torch.split(x_outputs, x_outputs.size()[0] // num_heads, dim=0),
            dim=2)  # (N, T_q, C)

        x = torch.transpose(x_outputs, 1, 2)  # (N, C, L)
        x = self.bn(x, mask)

        return x

    def __str__(self):
        return 'Attention'
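

# Illustrative sketch (concrete sizes are assumptions): num_units needs to be
# divisible by num_heads, because the C channels are split into C/h chunks that
# are stacked along the batch dimension (h*N) before the scaled dot-product.
# The output keeps the (N, C, L) shape of the input.
def _attention_example():
    attn = Attention(num_units=8, num_heads=2, keep_prob=0.9, is_mask=True)
    seq = torch.randn(2, 8, 5)                # (N, C=num_units, L)
    mask = torch.tensor([[1, 1, 1, 0, 0],
                         [1, 1, 1, 1, 1]])
    out = attn(seq, mask)                     # (N, C, L)
    assert out.shape == (2, 8, 5)
    return out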


class RNN(nn.Module):
    def __init__(self, hidden_size, output_keep_prob):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.bid_rnn = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
        self.output_keep_prob = output_keep_prob

        self.out_dropout = nn.Dropout(p=(1 - self.output_keep_prob))

    def forward(self, seq, mask):
        # seq: (N, C, L)
        # mask: (N, L)
        max_len = seq.size()[2]
        length = get_length(mask)
        seq = torch.transpose(seq, 1, 2)  # to (N, L, C)
        # Pack so that padded steps never enter the recurrence; recent PyTorch
        # releases expect the lengths on the CPU.
        packed_seq = nn.utils.rnn.pack_padded_sequence(seq, length, batch_first=True,
                                                       enforce_sorted=False)
        outputs, _ = self.bid_rnn(packed_seq)
        outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True,
                                                   total_length=max_len)[0]
        # Sum the forward and backward directions so the channel count stays hidden_size
        outputs = outputs.view(-1, max_len, 2, self.hidden_size).sum(2)  # (N, L, C)
        outputs = self.out_dropout(outputs)  # output dropout
        return torch.transpose(outputs, 1, 2)  # back to (N, C, L)

    def __str__(self):
        return 'RNN'
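

# Illustrative sketch (assumes utils.get_length(mask) returns the number of
# valid positions per sequence, which is how RNN.forward consumes it): the
# bidirectional GRU keeps the (N, C, L) interface of the other ops, with
# C equal to hidden_size because the two directions are summed.
def _rnn_example():
    rnn = RNN(hidden_size=4, output_keep_prob=0.9)
    seq = torch.randn(2, 4, 5)                # (N, C=hidden_size, L)
    mask = torch.tensor([[1, 1, 1, 0, 0],
                         [1, 1, 1, 1, 1]])
    out = rnn(seq, mask)                      # (N, C, L)
    assert out.shape == (2, 4, 5)
    return out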


class LinearCombine(nn.Module):
    def __init__(self, layers_num, trainable=True, input_aware=False, word_level=False):
        super(LinearCombine, self).__init__()
        self.input_aware = input_aware
        self.word_level = word_level

        if input_aware:
            raise NotImplementedError("Input aware is not supported.")
        self.w = nn.Parameter(torch.full((layers_num, 1, 1, 1), 1.0 / layers_num),
                              requires_grad=trainable)

    def forward(self, seq):
        nw = F.softmax(self.w, dim=0)
        seq = torch.mul(seq, nw)
        seq = torch.sum(seq, dim=0)
        return seq
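

# Illustrative sketch (the 4-D stacking convention is inferred from the weight
# shape (layers_num, 1, 1, 1)): LinearCombine mixes the outputs of several
# layers with softmax-normalized learnable weights, initialized uniformly to
# 1 / layers_num.
def _linear_combine_example():
    layer_outputs = torch.stack([torch.randn(2, 4, 5) for _ in range(3)], dim=0)
    combine = LinearCombine(layers_num=3)
    out = combine(layer_outputs)              # (N, C, L)
    assert out.shape == (2, 4, 5)
    return out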