__all__ = [
    "ESIM"
]

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from fastNLP.models import BaseModel
from fastNLP.modules.encoder.embedding import TokenEmbedding
from fastNLP.modules.encoder.lstm import LSTM
from fastNLP.core.const import Const
from fastNLP.core.utils import seq_len_to_mask


class ESIM(BaseModel):
    """
    Alias: :class:`fastNLP.models.ESIM`  :class:`fastNLP.models.snli.ESIM`

    A PyTorch implementation of the ESIM model.
    Paper: Enhanced LSTM for Natural Language Inference (arXiv: 1609.06038),
    https://arxiv.org/pdf/1609.06038.pdf

    :param fastNLP.TokenEmbedding init_embedding: the TokenEmbedding used to embed the input tokens
    :param int hidden_size: hidden size of the model; defaults to the embedding dimension
    :param int num_labels: number of target labels, default 3
    :param float dropout_rate: dropout rate, default 0.3
    :param float dropout_embed: dropout rate applied to the embeddings, default 0.1
    """
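    # A minimal usage sketch (not part of the original file; `vocab` and StaticEmbedding are
    # assumptions about the caller's setup, shown only for illustration):
    #
    #     embed = StaticEmbedding(vocab, embedding_dim=300)
    #     model = ESIM(embed, num_labels=3)
    #     out = model(words1, words2, seq_len1, seq_len2)   # out[Const.OUTPUT]: [B, num_labels]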

    def __init__(self, init_embedding: TokenEmbedding, hidden_size=None, num_labels=3, dropout_rate=0.3,
                 dropout_embed=0.1):
        super(ESIM, self).__init__()

        self.embedding = init_embedding
        self.dropout_embed = EmbedDropout(p=dropout_embed)
        if hidden_size is None:
            hidden_size = self.embedding.embed_size

        # input encoding: bidirectional LSTM over the (dropped-out) embeddings
        self.rnn = BiRNN(self.embedding.embed_size, hidden_size, dropout_rate=dropout_rate)
        # self.rnn = LSTM(self.embedding.embed_size, hidden_size, dropout=dropout_rate, bidirectional=True)

        # projection of the enhanced local inference features [a; ai; a - ai; a * ai]
        # (8 * hidden_size wide, since each of the four parts is a 2 * hidden_size BiLSTM output)
        self.interfere = nn.Sequential(nn.Dropout(p=dropout_rate),
                                       nn.Linear(8 * hidden_size, hidden_size),
                                       nn.ReLU())
        nn.init.xavier_uniform_(self.interfere[1].weight.data)
        self.bi_attention = SoftmaxAttention()

        # inference composition: its input is the hidden_size-dim output of self.interfere,
        # so the BiRNN input size is hidden_size (not the embedding size)
        self.rnn_high = BiRNN(hidden_size, hidden_size, dropout_rate=dropout_rate)
        # self.rnn_high = LSTM(hidden_size, hidden_size, dropout=dropout_rate, bidirectional=True)

        self.classifier = nn.Sequential(nn.Dropout(p=dropout_rate),
                                        nn.Linear(8 * hidden_size, hidden_size),
                                        nn.Tanh(),
                                        nn.Dropout(p=dropout_rate),
                                        nn.Linear(hidden_size, num_labels))

        self.dropout_rnn = nn.Dropout(p=dropout_rate)

        nn.init.xavier_uniform_(self.classifier[1].weight.data)
        nn.init.xavier_uniform_(self.classifier[4].weight.data)

    def forward(self, words1, words2, seq_len1, seq_len2, target=None):
        """
        :param torch.Tensor words1: [batch size(B), premise seq len(PL)] token ids of the premise
        :param torch.Tensor words2: [B, hypothesis seq len(HL)] token ids of the hypothesis
        :param torch.LongTensor seq_len1: [B] lengths of the premises
        :param torch.LongTensor seq_len2: [B] lengths of the hypotheses
        :param torch.LongTensor target: [B] gold labels, optional
        :return: dict with Const.OUTPUT ([B, num_labels(N)] logits) and, when target is given, Const.LOSS
        """
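        # High-level picture of the computation below (a summary of the ESIM paper, added for
        # readability; the names refer to the code that follows):
        #   1. input encoding:   a = BiRNN(a0), b = BiRNN(b0) over the dropped-out embeddings
        #   2. local inference:  ai, bi = soft alignment of a and b via masked softmax attention
        #   3. enhancement:      a_ = [a; ai; a - ai; a * ai] (likewise b_), projected by self.interfere
        #   4. composition:      BiRNN over the projected features, then mean/max pooling and the classifier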
        mask1 = seq_len_to_mask(seq_len1, words1.size(1))  # mask1: [B, PL]
        mask2 = seq_len_to_mask(seq_len2, words2.size(1))  # mask2: [B, HL]
        a0 = self.embedding(words1)  # a0: [B, PL, embed_size]
        b0 = self.embedding(words2)  # b0: [B, HL, embed_size]
        a0, b0 = self.dropout_embed(a0), self.dropout_embed(b0)

        # input encoding
        a = self.rnn(a0, mask1.byte())  # a: [B, PL, 2 * H]
        b = self.rnn(b0, mask2.byte())  # b: [B, HL, 2 * H]
        # a = self.dropout_rnn(self.rnn(a0, seq_len1)[0])  # a: [B, PL, 2 * H]
        # b = self.dropout_rnn(self.rnn(b0, seq_len2)[0])

        # local inference: soft-align the two sentences
        ai, bi = self.bi_attention(a, mask1, b, mask2)

        # enhancement of local inference information
        a_ = torch.cat((a, ai, a - ai, a * ai), dim=2)  # a_: [B, PL, 8 * H]
        b_ = torch.cat((b, bi, b - bi, b * bi), dim=2)  # b_: [B, HL, 8 * H]
        a_f = self.interfere(a_)
        b_f = self.interfere(b_)

        # inference composition
        a_h = self.rnn_high(a_f, mask1.byte())  # a_h: [B, PL, 2 * H]
        b_h = self.rnn_high(b_f, mask2.byte())  # b_h: [B, HL, 2 * H]
        # a_h = self.dropout_rnn(self.rnn_high(a_f, seq_len1)[0])
        # b_h = self.dropout_rnn(self.rnn_high(b_f, seq_len2)[0])

        # pooling over the sequence dimension
        a_avg = self.mean_pooling(a_h, mask1, dim=1)
        a_max, _ = self.max_pooling(a_h, mask1, dim=1)
        b_avg = self.mean_pooling(b_h, mask2, dim=1)
        b_max, _ = self.max_pooling(b_h, mask2, dim=1)

        out = torch.cat((a_avg, a_max, b_avg, b_max), dim=1)  # out: [B, 8 * H]
        logits = torch.tanh(self.classifier(out))

        if target is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits, target)
            return {Const.LOSS: loss, Const.OUTPUT: logits}
        else:
            return {Const.OUTPUT: logits}

    def predict(self, **kwargs):
        pred = self.forward(**kwargs)[Const.OUTPUT].argmax(-1)
        return {Const.OUTPUT: pred}

    # input: [batch_size, len, hidden]
    # mask:  [batch_size, len] (1s for real tokens followed by 0s for padding, i.e. 111...00)
    @staticmethod
    def mean_pooling(input, mask, dim=1):
        masks = mask.view(mask.size(0), mask.size(1), -1).float()
        return torch.sum(input * masks, dim=dim) / torch.sum(masks, dim=1)

    @staticmethod
    def max_pooling(input, mask, dim=1):
        my_inf = 10e12
        masks = mask.view(mask.size(0), mask.size(1), -1)
        masks = masks.expand(-1, -1, input.size(2)).float()
        # push padded positions towards -inf so they never win the max
        return torch.max(input + masks.le(0.5).float() * -my_inf, dim=dim)
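# Worked example for the pooling helpers above (illustrative numbers, given as tensors; not from
# the original file):
#   input = [[[1., 4.], [3., 2.], [9., 9.]]]    # [B=1, len=3, hidden=2]; the last step is padding
#   mask  = [[1, 1, 0]]
#   ESIM.mean_pooling(input, mask)   -> [[2., 3.]]   (average over the two real steps only)
#   ESIM.max_pooling(input, mask)[0] -> [[3., 4.]]   (padding is pushed to -inf before the max)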


class EmbedDropout(nn.Dropout):
    """Dropout that drops whole embedding dimensions, reusing the same mask at every position of a sequence."""

    def forward(self, sequences_batch):
        # sample one [B, embed_dim] mask and broadcast it over the sequence length
        ones = sequences_batch.data.new_ones(sequences_batch.shape[0], sequences_batch.shape[-1])
        dropout_mask = nn.functional.dropout(ones, self.p, self.training, inplace=False)
        return dropout_mask.unsqueeze(1) * sequences_batch
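# Shape note (added for clarity): for a [B, L, embed_dim] batch, EmbedDropout samples a single
# [B, embed_dim] keep/drop mask and reuses it at every position 0..L-1, unlike plain nn.Dropout,
# which would sample an independent mask per position.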


class BiRNN(nn.Module):
    def __init__(self, input_size, hidden_size, dropout_rate=0.3):
        super(BiRNN, self).__init__()
        self.dropout_rate = dropout_rate
        self.rnn = nn.LSTM(input_size, hidden_size,
                           num_layers=1,
                           bidirectional=True,
                           batch_first=True)

    def forward(self, x, x_mask):
        # Sort sequences by decreasing length so they can be packed
        lengths = x_mask.data.eq(1).long().sum(1)
        _, idx_sort = torch.sort(lengths, dim=0, descending=True)
        _, idx_unsort = torch.sort(idx_sort, dim=0)
        lengths = list(lengths[idx_sort])

        x = x.index_select(0, idx_sort)
        # Pack it up
        rnn_input = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True)
        # Apply dropout to the packed input
        if self.dropout_rate > 0:
            dropout_input = F.dropout(rnn_input.data, p=self.dropout_rate, training=self.training)
            rnn_input = nn.utils.rnn.PackedSequence(dropout_input, rnn_input.batch_sizes)
        output = self.rnn(rnn_input)[0]
        # Unpack everything and restore the original batch order
        output = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)[0]
        output = output.index_select(0, idx_unsort)
        # Pad back to the full mask length if the longest sequence was shorter
        if output.size(1) != x_mask.size(1):
            padding = torch.zeros(output.size(0),
                                  x_mask.size(1) - output.size(1),
                                  output.size(2)).type(output.data.type())
            output = torch.cat([output, padding], 1)
        return output


def masked_softmax(tensor, mask):
    tensor_shape = tensor.size()
    reshaped_tensor = tensor.view(-1, tensor_shape[-1])

    # Reshape the mask so it matches the size of the input tensor.
    while mask.dim() < tensor.dim():
        mask = mask.unsqueeze(1)
    mask = mask.expand_as(tensor).contiguous().float()
    reshaped_mask = mask.view(-1, mask.size()[-1])
    result = F.softmax(reshaped_tensor * reshaped_mask, dim=-1)
    result = result * reshaped_mask
    # 1e-13 is added to avoid divisions by zero.
    result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
    return result.view(*tensor_shape)
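# Worked example (illustrative): masked_softmax(torch.tensor([[1., 1., 1.]]), torch.tensor([[1., 1., 0.]]))
# returns [[0.5, 0.5, 0.0]]: the masked position gets zero probability and the rest is renormalised.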


def weighted_sum(tensor, weights, mask):
    # weights: [B, L1, L2], tensor: [B, L2, H] -> weighted sum of `tensor` for every position of L1
    w_sum = weights.bmm(tensor)
    while mask.dim() < w_sum.dim():
        mask = mask.unsqueeze(1)
    mask = mask.transpose(-1, -2)
    mask = mask.expand_as(w_sum).contiguous().float()
    return w_sum * mask


class SoftmaxAttention(nn.Module):
    """Soft alignment between premise and hypothesis via masked softmax attention."""

    def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask):
        # similarity_matrix: [B, PL, HL]
        similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1).contiguous())

        # attention of each premise token over the hypothesis, and vice versa
        prem_hyp_attn = masked_softmax(similarity_matrix, hypothesis_mask)
        hyp_prem_attn = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)

        # align each sentence against the other
        attended_premises = weighted_sum(hypothesis_batch, prem_hyp_attn, premise_mask)
        attended_hypotheses = weighted_sum(premise_batch, hyp_prem_attn, hypothesis_mask)

        return attended_premises, attended_hypotheses
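

if __name__ == "__main__":
    # Minimal smoke test, added as a sketch (not part of the original module). To stay self-contained
    # it uses a tiny stand-in embedding instead of a real fastNLP TokenEmbedding; `_ToyEmbedding` is a
    # hypothetical helper that only provides what ESIM actually needs: being callable on token ids and
    # exposing an `embed_size` attribute.
    class _ToyEmbedding(nn.Module):
        def __init__(self, vocab_size=100, embed_size=30):
            super().__init__()
            self.embed_size = embed_size
            self.embed = nn.Embedding(vocab_size, embed_size)

        def forward(self, words):
            return self.embed(words)

    model = ESIM(_ToyEmbedding(), num_labels=3)
    words1 = torch.randint(0, 100, (4, 7))      # premise token ids:    [B=4, PL=7]
    words2 = torch.randint(0, 100, (4, 5))      # hypothesis token ids: [B=4, HL=5]
    seq_len1 = torch.tensor([7, 6, 5, 3])
    seq_len2 = torch.tensor([5, 5, 4, 2])
    target = torch.tensor([0, 1, 2, 1])

    out = model(words1, words2, seq_len1, seq_len2, target=target)
    print(out[Const.OUTPUT].shape, out[Const.LOSS].item())     # torch.Size([4, 3]) and a scalar loss
    pred = model.predict(words1=words1, words2=words2, seq_len1=seq_len1, seq_len2=seq_len2)
    print(pred[Const.OUTPUT])                                  # predicted label index per example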