|
|
@@ -2,7 +2,8 @@ import torch
 import torch.nn as nn
 from torch.autograd import Variable
 
-class Selfattention(nn.Module):
+
+class SelfAttention(nn.Module):
     """
     Self Attention Module.
 
@@ -12,7 +13,7 @@ class Selfattention(nn.Module):
     r : the number of encoded vectors
     """
     def __init__(self, input_size, d_a, r):
-        super(Selfattention, self).__init__()
+        super(SelfAttention, self).__init__()
         self.W_s1 = nn.Parameter(torch.randn(d_a, input_size), requires_grad=True)
         self.W_s2 = nn.Parameter(torch.randn(r, d_a), requires_grad=True)
         self.softmax = nn.Softmax(dim=2)
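
For reference, a minimal sketch of the forward pass these parameters support, assuming the usual structured self-attention formulation A = softmax(W_s2 · tanh(W_s1 · Hᵀ)) from Lin et al. (2017). The `forward` method and the input shape below are assumptions for illustration, not part of this diff:

```python
import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, input_size, d_a, r):
        super(SelfAttention, self).__init__()
        # Shapes as in the diff above: W_s1 is (d_a, input_size), W_s2 is (r, d_a).
        self.W_s1 = nn.Parameter(torch.randn(d_a, input_size), requires_grad=True)
        self.W_s2 = nn.Parameter(torch.randn(r, d_a), requires_grad=True)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, H):
        # Assumed input: H is (batch, seq_len, input_size) encoder outputs.
        # Broadcasted matmuls give scores of shape (batch, r, seq_len).
        scores = self.W_s2 @ torch.tanh(self.W_s1 @ H.transpose(1, 2))
        A = self.softmax(scores)  # normalize over seq_len, hence dim=2
        M = A @ H                 # (batch, r, input_size): the r encoded vectors
        return M, A
```

With, say, `input_size=256`, `d_a=64`, `r=4`, an input of shape `(8, 20, 256)` yields `M` of shape `(8, 4, 256)` and attention weights `A` of shape `(8, 4, 20)`.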