|
|
@@ -152,8 +152,7 @@ class BiAttention(nn.Module):
         :param torch.Tensor premise_mask: [batch_size, a_seq_len]
         :param torch.Tensor hypothesis_batch: [batch_size, b_seq_len, hidden_size]
         :param torch.Tensor hypothesis_mask: [batch_size, b_seq_len]
-        :return: torch.Tensor attended_premises: [batch_size, a_seq_len, hidden_size]
-                 torch.Tensor attended_hypotheses: [batch_size, b_seq_len, hidden_size]
+        :return: torch.Tensor attended_premises: [batch_size, a_seq_len, hidden_size] torch.Tensor attended_hypotheses: [batch_size, b_seq_len, hidden_size]
         """
         similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1)
                                                               .contiguous())
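
For reference, a minimal sketch of how a similarity matrix computed this way is typically turned into the `attended_premises` and `attended_hypotheses` the docstring describes: mask padded positions, softmax over each direction, then take the weighted sums. The class name `BiAttentionSketch`, the mask convention (1 for real tokens, 0 for padding), and the `masked_fill` handling are illustrative assumptions, not the repository's actual implementation:

```python
import torch
import torch.nn as nn


class BiAttentionSketch(nn.Module):
    # Illustrative bi-directional attention matching the shapes in the
    # docstring above; a sketch, not the repository's implementation.

    def forward(self, premise_batch, premise_mask,
                hypothesis_batch, hypothesis_mask):
        # [batch_size, a_seq_len, b_seq_len]
        similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1)
                                                               .contiguous())

        # Masks are assumed to be 1 for real tokens and 0 for padding.
        prem_mask = premise_mask.unsqueeze(2)     # [batch, a_seq_len, 1]
        hyp_mask = hypothesis_mask.unsqueeze(1)   # [batch, 1, b_seq_len]

        # Attention of each premise token over the hypothesis tokens.
        prem_attn = torch.softmax(
            similarity_matrix.masked_fill(hyp_mask == 0, float("-inf")), dim=2)
        # Attention of each hypothesis token over the premise tokens.
        hyp_attn = torch.softmax(
            similarity_matrix.masked_fill(prem_mask == 0, float("-inf")), dim=1)

        # Weighted sums give the attended representations from the docstring.
        # (Rows that are entirely padding would yield NaNs here; real code
        # needs an extra guard for that case.)
        attended_premises = prem_attn.bmm(hypothesis_batch)
        attended_hypotheses = hyp_attn.transpose(1, 2).bmm(premise_batch)
        return attended_premises, attended_hypotheses
```

Masking before each softmax keeps padded positions from receiving any attention weight, which is why the docstring's masks are passed in alongside the two batches.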
|
|
|