@@ -7,7 +7,6 @@ import numpy as NP
 class StarTransformer(nn.Module):
     """Star-Transformer Encoder part.
     paper: https://arxiv.org/abs/1902.09113
-
     :param hidden_size: int, the size of the input dimension; this is also the size of the output dimension.
     :param num_layers: int, the number of star-transformer layers.
     :param num_head: int, the number of attention heads.
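
For orientation, a minimal usage sketch of the encoder these parameters configure. The `head_dim` argument and the exact return value are assumptions (neither is shown in this diff), so treat this as illustrative only:

```python
import torch

# Hypothetical smoke test, assuming StarTransformer (as defined above) is in
# scope and that head_dim (the per-head size) is also a constructor argument.
encoder = StarTransformer(hidden_size=64, num_layers=2, num_head=4, head_dim=16)

data = torch.rand(2, 5, 64)             # B=2, L=5, H=hidden_size=64
mask = torch.tensor([[1, 1, 1, 1, 0],   # nonzero = real token, 0 = padding,
                     [1, 1, 0, 0, 0]])  # matching the convention the flip below expects
out = encoder(data, mask)               # forward pass over the padded batch
```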

@@ -46,6 +45,7 @@ class StarTransformer(nn.Module):
             return f(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
 
         B, L, H = data.size()
+        mask = (mask == 0) # flip the mask for masked_fill_
         smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1)
 
         embs = data.permute(0, 2, 1)[:,:,:,None] # B H L 1
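
The added `mask = (mask == 0)` line converts the incoming convention (nonzero = real token) into the one `masked_fill_` expects (True at positions to suppress); the zero prepended in `smask` then marks the relay slot as always attendable. A standalone sketch of the flip, with made-up sizes:

```python
import torch

mask = torch.tensor([[1, 1, 1, 1, 0],
                     [1, 1, 0, 0, 0]])       # nonzero = real token, 0 = padding

flipped = (mask == 0)                        # True exactly at the padding positions
scores = torch.rand(2, 5)
scores.masked_fill_(flipped, -float('inf'))  # padding can no longer win the softmax
print(torch.softmax(scores, dim=1))          # padding columns come out as 0
```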

@@ -57,7 +57,6 @@
         nodes = embs
         relay = embs.mean(2, keepdim=True)
         ex_mask = mask[:, None, :, None].expand(B, H, L, 1)
-        ex_mask = ex_mask.ne(1) # reverse mask for next masked_fill.
         r_embs = embs.view(B, H, 1, L)
         for i in range(self.iters):
             ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2)
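
Because the mask is now flipped once at the top of `forward`, the broadcasted `ex_mask` is already True at padding, and the dropped `.ne(1)` reversal would invert it a second time, blanking out real tokens instead. A shape sketch under the flipped convention:

```python
import torch

B, H, L = 2, 4, 5
mask = torch.tensor([[1, 1, 1, 1, 0],
                     [1, 1, 0, 0, 0]]) == 0      # flipped: True = padding

ex_mask = mask[:, None, :, None].expand(B, H, L, 1)
print(ex_mask.shape)                             # torch.Size([2, 4, 5, 1])

# The removed line would undo the flip and mask the wrong positions now:
# ex_mask = ex_mask.ne(1)
```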

@@ -137,11 +136,10 @@ class MSA2(nn.Module):
 
         q = q.view(B, nhead, 1, head_dim)  # B, H, 1, 1 -> B, N, 1, h
         k = k.view(B, nhead, head_dim, L)  # B, H, L, 1 -> B, N, h, L
-        v = k.view(B, nhead, head_dim, L).permute(0, 1, 3, 2)  # B, H, L, 1 -> B, N, L, h
-
+        v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2)  # B, H, L, 1 -> B, N, L, h
         pre_a = torch.matmul(q, k) / NP.sqrt(head_dim)
         if mask is not None:
             pre_a = pre_a.masked_fill(mask[:, None, None, :], -float('inf'))
         alphas = self.drop(F.softmax(pre_a, 3))  # B, N, 1, L
         att = torch.matmul(alphas, v).view(B, -1, 1, 1)  # B, N, 1, h -> B, N*h, 1, 1
         return self.WO(att)
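
The substantive fix: `v` was previously reshaped out of `k`'s tensor, so the value projection `self.WV` was computed and then silently discarded, and attention averaged key vectors instead of value vectors. A self-contained shape walk-through of the corrected path, with hypothetical sizes standing in for the real projections:

```python
import torch
import numpy as NP

B, L, nhead, head_dim = 2, 5, 4, 16
H = nhead * head_dim

# Stand-ins for the projections self.WQ(x), self.WK(y), self.WV(y).
q = torch.rand(B, H, 1, 1)
k = torch.rand(B, H, L, 1)
v = torch.rand(B, H, L, 1)

q = q.view(B, nhead, 1, head_dim)                      # B, N, 1, h
k = k.view(B, nhead, head_dim, L)                      # B, N, h, L
v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2)  # B, N, L, h (from v, not k)

pre_a = torch.matmul(q, k) / NP.sqrt(head_dim)         # B, N, 1, L
alphas = torch.softmax(pre_a, 3)                       # attention over the L positions
att = torch.matmul(alphas, v).view(B, -1, 1, 1)        # B, N*h, 1, 1
print(att.shape)                                       # torch.Size([2, 64, 1, 1])
```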