@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn
 
 from .base_model import BaseModel
+from ..core.const import Const
 from ..modules import decoder as Decoder
 from ..modules import encoder as Encoder
 from ..modules import aggregator as Aggregator
@@ -40,7 +41,7 @@ class ESIM(BaseModel):
             (self.vocab_size, self.embed_dim), dropout=self.dropout,
         )
 
-        self.embedding_layer = Encoder.Linear(self.embed_dim, self.hidden_size)
+        self.embedding_layer = nn.Linear(self.embed_dim, self.hidden_size)
 
         self.encoder = Encoder.LSTM(
             input_size=self.embed_dim, hidden_size=self.hidden_size, num_layers=1, bias=True,
@@ -51,7 +52,7 @@ class ESIM(BaseModel):
         self.mean_pooling = Aggregator.MeanPoolWithMask()
         self.max_pooling = Aggregator.MaxPoolWithMask()
 
-        self.inference_layer = Encoder.Linear(self.hidden_size * 4, self.hidden_size)
+        self.inference_layer = nn.Linear(self.hidden_size * 4, self.hidden_size)
 
         self.decoder = Encoder.LSTM(
             input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=1, bias=True,
@@ -60,12 +61,13 @@ class ESIM(BaseModel):
 
         self.output = Decoder.MLP([4 * self.hidden_size, self.hidden_size, self.n_labels], 'tanh', dropout=self.dropout)
 
-    def forward(self, words1, words2, seq_len1=None, seq_len2=None):
+    def forward(self, words1, words2, seq_len1=None, seq_len2=None, target=None):
         """ Forward function
 
         :param torch.Tensor words1: [batch size(B), premise seq len(PL)] token representation of the premise
         :param torch.Tensor words2: [B, hypothesis seq len(HL)] token representation of the hypothesis
         :param torch.LongTensor seq_len1: [B] lengths of the premises
         :param torch.LongTensor seq_len2: [B] lengths of the hypotheses
+        :param torch.LongTensor target: [B] ground-truth labels
         :return: dict prediction: [B, n_labels(N)] prediction results
         """
@@ -124,17 +126,23 @@ class ESIM(BaseModel):
 
         prediction = torch.tanh(self.output(v))  # prediction: [B, N]
 
-        return {'pred': prediction}
+        if target is not None:
+            func = nn.CrossEntropyLoss()
+            loss = func(prediction, target)
+            return {Const.OUTPUT: prediction, Const.LOSS: loss}
 
-    def predict(self, words1, words2, seq_len1, seq_len2):
+        return {Const.OUTPUT: prediction}
+
+    def predict(self, words1, words2, seq_len1=None, seq_len2=None, target=None):
         """ Predict function
 
         :param torch.Tensor words1: [batch size(B), premise seq len(PL)] token representation of the premise
         :param torch.Tensor words2: [B, hypothesis seq len(HL)] token representation of the hypothesis
         :param torch.LongTensor seq_len1: [B] lengths of the premises
         :param torch.LongTensor seq_len2: [B] lengths of the hypotheses
+        :param torch.LongTensor target: [B] ground-truth labels
         :return: dict prediction: [B, n_labels(N)] prediction results
         """
-        prediction = self.forward(words1, words2, seq_len1, seq_len2)['pred']
-        return {'pred': torch.argmax(prediction, dim=-1)}
+        prediction = self.forward(words1, words2, seq_len1, seq_len2)[Const.OUTPUT]
+        return {Const.OUTPUT: torch.argmax(prediction, dim=-1)}
 
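For reference, a minimal sketch of the return contract this patch introduces. The forward_tail helper below is hypothetical; it only mirrors the tail of the patched ESIM.forward rather than the full model, and it assumes fastNLP's Const.OUTPUT and Const.LOSS resolve to the strings 'pred' and 'loss'. All tensor shapes are toy values for illustration.

import torch
import torch.nn as nn

# Stand-ins for the keys provided by `from ..core.const import Const`;
# in fastNLP, Const.OUTPUT == 'pred' and Const.LOSS == 'loss'.
OUTPUT, LOSS = 'pred', 'loss'

def forward_tail(prediction, target=None):
    # Hypothetical helper mirroring the patched tail of ESIM.forward:
    # when a target is given, also compute and return the training loss.
    if target is not None:
        func = nn.CrossEntropyLoss()
        loss = func(prediction, target)
        return {OUTPUT: prediction, LOSS: loss}
    return {OUTPUT: prediction}

prediction = torch.randn(4, 3)        # [B=4, N=3] toy class scores
target = torch.tensor([0, 2, 1, 1])   # [B] toy gold labels

train_out = forward_tail(prediction, target)
print(train_out[LOSS])                   # scalar loss tensor, usable for backward()

eval_out = forward_tail(prediction)
print(eval_out[OUTPUT].argmax(dim=-1))   # [B] label ids, as ESIM.predict now returns

With target=None the dict carries only the scores, which is what the reworked predict relies on when it indexes the forward output with Const.OUTPUT and argmaxes over the last dimension.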