|
|
@@ -5,38 +5,45 @@ from fastNLP.models.base_model import BaseModel
 from fastNLP.modules import decoder as Decoder
 from fastNLP.modules import encoder as Encoder
 from fastNLP.modules import aggregator as Aggregator
 from fastNLP.modules.utils import seq_mask


 my_inf = 10e12


 class ESIM(BaseModel):
-    """
-    PyTorch Network for SNLI task using ESIM model.
+    """A PyTorch implementation of the ESIM model.
+
+    ESIM paper: Enhanced LSTM for Natural Language Inference (arXiv: 1609.06038)
     """

-    def __init__(self, **kwargs):
+    def __init__(self, vocab_size, embed_dim, hidden_size, dropout=0.0, num_classes=3, init_embedding=None):
+        """
+        :param int vocab_size: vocabulary size
+        :param int embed_dim: word embedding dimension
+        :param int hidden_size: hidden size of the LSTM
+        :param float dropout: dropout rate, defaults to 0
+        :param int num_classes: number of labels, defaults to 3
+        :param numpy.array init_embedding: initial embedding matrix of shape (vocab_size, embed_dim); defaults to None, i.e. the embedding matrix is randomly initialized
+        """
         super(ESIM, self).__init__()
-        self.vocab_size = kwargs["vocab_size"]
-        self.embed_dim = kwargs["embed_dim"]
-        self.hidden_size = kwargs["hidden_size"]
-        self.batch_first = kwargs["batch_first"]
-        self.dropout = kwargs["dropout"]
-        self.n_labels = kwargs["num_classes"]
-        self.gpu = kwargs["gpu"] and torch.cuda.is_available()
+        self.vocab_size = vocab_size
+        self.embed_dim = embed_dim
+        self.hidden_size = hidden_size
+        self.dropout = dropout
+        self.n_labels = num_classes

         self.drop = nn.Dropout(self.dropout)

         self.embedding = Encoder.Embedding(
             self.vocab_size, self.embed_dim, dropout=self.dropout,
-            init_emb=kwargs["init_embedding"] if "inin_embedding" in kwargs.keys() else None,
+            init_emb=init_embedding,
         )

         self.embedding_layer = Encoder.Linear(self.embed_dim, self.hidden_size)

         self.encoder = Encoder.LSTM(
             input_size=self.embed_dim, hidden_size=self.hidden_size, num_layers=1, bias=True,
-            batch_first=self.batch_first, bidirectional=True
+            batch_first=True, bidirectional=True
         )

         self.bi_attention = Aggregator.BiAttention()
|
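For readers following the refactor, here is a minimal sketch of how construction changes, assuming ESIM is importable from fastNLP.models (the exact import path and the hyperparameter values below are illustrative, not prescribed by this diff):

```python
import numpy as np
from fastNLP.models import ESIM  # assumed import path

# Before: everything arrived through an untyped kwargs dict, and a typo
# ("inin_embedding") silently discarded any pre-trained embedding matrix.
# model = ESIM(vocab_size=10000, embed_dim=300, hidden_size=300,
#              batch_first=True, dropout=0.3, num_classes=3, gpu=False)

# After: explicit, documented parameters with defaults; batch_first and gpu
# are gone (batch_first is now always True, and the device follows the inputs).
init_emb = np.random.randn(10000, 300).astype("float32")  # optional pre-trained matrix
model = ESIM(vocab_size=10000, embed_dim=300, hidden_size=300,
             dropout=0.3, num_classes=3, init_embedding=init_emb)
```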
@@ -47,24 +54,34 @@ class ESIM(BaseModel):

         self.decoder = Encoder.LSTM(
             input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=1, bias=True,
-            batch_first=self.batch_first, bidirectional=True
+            batch_first=True, bidirectional=True
         )

         self.output = Decoder.MLP([4 * self.hidden_size, self.hidden_size, self.n_labels], 'tanh', dropout=self.dropout)

-    def forward(self, words1, words2, seq_len1, seq_len2):
+    def forward(self, words1, words2, seq_len1=None, seq_len2=None):
         """ Forward function

-        :param words1: A Tensor represents premise: [batch size(B), premise seq len(PL)].
-        :param words2: A Tensor represents hypothesis: [B, hypothesis seq len(HL)].
-        :param seq_len1: A Tensor record which is a real word and which is a padding word in premise: [B].
-        :param seq_len2: A Tensor record which is a real word and which is a padding word in hypothesis: [B].
-        :return: prediction: A Dict with Tensor of classification result: [B, n_labels(N)].
+        :param torch.Tensor words1: [batch size(B), premise seq len(PL)] token representation of the premise
+        :param torch.Tensor words2: [B, hypothesis seq len(HL)] token representation of the hypothesis
+        :param torch.LongTensor seq_len1: [B] lengths of the premises
+        :param torch.LongTensor seq_len2: [B] lengths of the hypotheses
+        :return: dict prediction: [B, n_labels(N)] prediction results
         """

         premise0 = self.embedding_layer(self.embedding(words1))
         hypothesis0 = self.embedding_layer(self.embedding(words2))

+        if seq_len1 is not None:
+            seq_len1 = seq_mask(seq_len1, premise0.size(1))
+        else:
+            seq_len1 = torch.ones(premise0.size(0), premise0.size(1))
+        seq_len1 = (seq_len1.long()).to(device=premise0.device)
+        if seq_len2 is not None:
+            seq_len2 = seq_mask(seq_len2, hypothesis0.size(1))
+        else:
+            seq_len2 = torch.ones(hypothesis0.size(0), hypothesis0.size(1))
+        seq_len2 = (seq_len2.long()).to(device=hypothesis0.device)

         _BP, _PSL, _HP = premise0.size()
         _BH, _HSL, _HH = hypothesis0.size()
         _BPL, _PLL = seq_len1.size()
|
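As a quick check of the new optional-length behaviour, a hypothetical forward pass might look like this (batch, vocabulary, and sequence sizes are made up; `model` is the instance sketched above):

```python
import torch

B, PL, HL = 4, 12, 10                      # illustrative batch and padded lengths
words1 = torch.randint(0, 10000, (B, PL))  # premise token ids, padded to PL
words2 = torch.randint(0, 10000, (B, HL))  # hypothesis token ids, padded to HL
seq_len1 = torch.tensor([12, 9, 7, 5])     # true premise lengths
seq_len2 = torch.tensor([10, 10, 6, 3])    # true hypothesis lengths

out = model(words1, words2, seq_len1, seq_len2)
print(out['pred'].shape)                   # torch.Size([4, 3]): [B, n_labels]

# Lengths are now optional: omitting them builds an all-ones mask,
# so padding positions are attended to like real tokens.
out = model(words1, words2)
```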
@@ -109,6 +126,14 @@ class ESIM(BaseModel):
         return {'pred': prediction}

     def predict(self, words1, words2, seq_len1, seq_len2):
+        """ Predict function
+
+        :param torch.Tensor words1: [batch size(B), premise seq len(PL)] token representation of the premise
+        :param torch.Tensor words2: [B, hypothesis seq len(HL)] token representation of the hypothesis
+        :param torch.LongTensor seq_len1: [B] lengths of the premises
+        :param torch.LongTensor seq_len2: [B] lengths of the hypotheses
+        :return: dict prediction: [B, n_labels(N)] prediction results
+        """
         prediction = self.forward(words1, words2, seq_len1, seq_len2)['pred']
         return {'pred': torch.argmax(prediction, dim=-1)}
|
|
|
|
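The newly documented `predict` mirrors `forward`; the method itself just argmaxes the class scores. A small usage sketch, reusing the hypothetical tensors above, with the model in eval mode so dropout does not make the two calls disagree:

```python
model.eval()  # disable dropout so both calls below are deterministic

labels = model.predict(words1, words2, seq_len1, seq_len2)['pred']
print(labels.shape)  # torch.Size([4]): one label id per sentence pair

# Equivalent computation from the raw forward output:
scores = model(words1, words2, seq_len1, seq_len2)['pred']
assert torch.equal(labels, scores.argmax(dim=-1))
```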