|
|
@@ -1,6 +1,5 @@ |
|
|
|
import torch |
|
|
|
import torch.nn as nn |
|
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
from fastNLP.models.base_model import BaseModel |
|
|
|
from fastNLP.modules import decoder as Decoder |
|
|
@@ -40,7 +39,7 @@ class ESIM(BaseModel): |
|
|
|
batch_first=self.batch_first, bidirectional=True |
|
|
|
) |
|
|
|
|
|
|
|
self.bi_attention = Aggregator.Bi_Attention() |
|
|
|
self.bi_attention = Aggregator.BiAttention() |
|
|
|
self.mean_pooling = Aggregator.MeanPoolWithMask() |
|
|
|
self.max_pooling = Aggregator.MaxPoolWithMask() |
|
|
|
|
|
|
@@ -53,23 +52,23 @@ class ESIM(BaseModel): |
|
|
|
|
|
|
|
self.output = Decoder.MLP([4 * self.hidden_size, self.hidden_size, self.n_labels], 'tanh', dropout=self.dropout) |
|
|
|
|
|
|
|
def forward(self, premise, hypothesis, premise_len, hypothesis_len): |
|
|
|
def forward(self, words1, words2, seq_len1, seq_len2): |
|
|
|
""" Forward function |
|
|
|
|
|
|
|
:param premise: A Tensor represents premise: [batch size(B), premise seq len(PL)]. |
|
|
|
:param hypothesis: A Tensor represents hypothesis: [B, hypothesis seq len(HL)]. |
|
|
|
:param premise_len: A Tensor record which is a real word and which is a padding word in premise: [B, PL]. |
|
|
|
:param hypothesis_len: A Tensor record which is a real word and which is a padding word in hypothesis: [B, HL]. |
|
|
|
:param words1: A Tensor representing the premise: [batch size(B), premise seq len(PL)]. |
|
|
|
:param words2: A Tensor representing the hypothesis: [B, hypothesis seq len(HL)]. |
|
|
|
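:param seq_len1: A Tensor recording which positions in the premise are real words and which are padding: [B] (NOTE: forward unpacks seq_len1.size() into two values, which implies a 2-D tensor such as [B, PL]; confirm the intended shape). |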
|
|
|
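:param seq_len2: A Tensor recording which positions in the hypothesis are real words and which are padding: [B] (NOTE: forward unpacks seq_len2.size() into two values, which implies a 2-D tensor such as [B, HL]; confirm the intended shape). |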
|
|
|
:return: prediction: A Dict with Tensor of classification result: [B, n_labels(N)]. |
|
|
|
""" |
|
|
|
|
|
|
|
premise0 = self.embedding_layer(self.embedding(premise)) |
|
|
|
hypothesis0 = self.embedding_layer(self.embedding(hypothesis)) |
|
|
|
premise0 = self.embedding_layer(self.embedding(words1)) |
|
|
|
hypothesis0 = self.embedding_layer(self.embedding(words2)) |
|
|
|
|
|
|
|
_BP, _PSL, _HP = premise0.size() |
|
|
|
_BH, _HSL, _HH = hypothesis0.size() |
|
|
|
_BPL, _PLL = premise_len.size() |
|
|
|
_HPL, _HLL = hypothesis_len.size() |
|
|
|
_BPL, _PLL = seq_len1.size() |
|
|
|
_HPL, _HLL = seq_len2.size() |
|
|
|
|
|
|
|
assert _BP == _BH and _BPL == _HPL and _BP == _BPL |
|
|
|
assert _HP == _HH |
|
|
@@ -84,7 +83,7 @@ class ESIM(BaseModel): |
|
|
|
a = torch.mean(a0.view(B, PL, -1, H), dim=2) # a: [B, PL, H] |
|
|
|
b = torch.mean(b0.view(B, HL, -1, H), dim=2) # b: [B, HL, H] |
|
|
|
|
|
|
|
ai, bi = self.bi_attention(a, b, premise_len, hypothesis_len) |
|
|
|
ai, bi = self.bi_attention(a, b, seq_len1, seq_len2) |
|
|
|
|
|
|
|
ma = torch.cat((a, ai, a - ai, a * ai), dim=2) # ma: [B, PL, 4 * H] |
|
|
|
mb = torch.cat((b, bi, b - bi, b * bi), dim=2) # mb: [B, HL, 4 * H] |
|
|
@@ -98,17 +97,18 @@ class ESIM(BaseModel): |
|
|
|
va = torch.mean(vat.view(B, PL, -1, H), dim=2) # va: [B, PL, H] |
|
|
|
vb = torch.mean(vbt.view(B, HL, -1, H), dim=2) # vb: [B, HL, H] |
|
|
|
|
|
|
|
va_ave = self.mean_pooling(va, premise_len, dim=1) # va_ave: [B, H] |
|
|
|
va_max, va_arg_max = self.max_pooling(va, premise_len, dim=1) # va_max: [B, H] |
|
|
|
vb_ave = self.mean_pooling(vb, hypothesis_len, dim=1) # vb_ave: [B, H] |
|
|
|
vb_max, vb_arg_max = self.max_pooling(vb, hypothesis_len, dim=1) # vb_max: [B, H] |
|
|
|
va_ave = self.mean_pooling(va, seq_len1, dim=1) # va_ave: [B, H] |
|
|
|
va_max, va_arg_max = self.max_pooling(va, seq_len1, dim=1) # va_max: [B, H] |
|
|
|
vb_ave = self.mean_pooling(vb, seq_len2, dim=1) # vb_ave: [B, H] |
|
|
|
vb_max, vb_arg_max = self.max_pooling(vb, seq_len2, dim=1) # vb_max: [B, H] |
|
|
|
|
|
|
|
v = torch.cat((va_ave, va_max, vb_ave, vb_max), dim=1) # v: [B, 4 * H] |
|
|
|
|
|
|
|
prediction = F.tanh(self.output(v)) # prediction: [B, N] |
|
|
|
prediction = torch.tanh(self.output(v)) # prediction: [B, N] |
|
|
|
|
|
|
|
return {'pred': prediction} |
|
|
|
|
|
|
|
def predict(self, premise, hypothesis, premise_len, hypothesis_len):
    """Delegate to ``forward`` and hand back its result unchanged.

    :param premise: passed through to ``forward`` as its first argument.
    :param hypothesis: passed through to ``forward`` as its second argument.
    :param premise_len: passed through to ``forward`` as its third argument.
    :param hypothesis_len: passed through to ``forward`` as its fourth argument.
    :return: whatever ``forward`` returns for these inputs.
    """
    forward_output = self.forward(premise, hypothesis, premise_len, hypothesis_len)
    return forward_output
|
|
|
def predict(self, words1, words2, seq_len1, seq_len2):
    """Return the arg-max index over the last dimension of the model scores.

    Calls ``forward`` with the given inputs, reads the ``'pred'`` entry of
    the dict it returns, and applies ``torch.argmax`` along ``dim=-1``.

    :param words1: passed through to ``forward`` as its first argument.
    :param words2: passed through to ``forward`` as its second argument.
    :param seq_len1: passed through to ``forward`` as its third argument.
    :param seq_len2: passed through to ``forward`` as its fourth argument.
    :return: the result of ``torch.argmax`` over the last dimension of the
        ``'pred'`` scores.
    """
    outputs = self.forward(words1, words2, seq_len1, seq_len2)
    scores = outputs['pred']
    return torch.argmax(scores, dim=-1)
|
|
|
|