|
@@ -57,6 +57,7 @@ class StarTransformer(nn.Module): |
|
|
nodes = embs |
|
|
nodes = embs |
|
|
relay = embs.mean(2, keepdim=True) |
|
|
relay = embs.mean(2, keepdim=True) |
|
|
ex_mask = mask[:, None, :, None].expand(B, H, L, 1) |
|
|
ex_mask = mask[:, None, :, None].expand(B, H, L, 1) |
|
|
|
|
|
ex_mask = ex_mask.ne(1) # reverse mask for next masked_fill. |
|
|
r_embs = embs.view(B, H, 1, L) |
|
|
r_embs = embs.view(B, H, 1, L) |
|
|
for i in range(self.iters): |
|
|
for i in range(self.iters): |
|
|
ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2) |
|
|
ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2) |
|
|