|
|
@@ -46,6 +46,7 @@ class StarTransformer(nn.Module): |
|
|
|
return f(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) |
|
|
|
|
|
|
|
B, L, H = data.size() |
|
|
|
mask = (mask == 0) # flip the mask for masked_fill_ |
|
|
|
smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1) |
|
|
|
|
|
|
|
embs = data.permute(0, 2, 1)[:,:,:,None] # B H L 1 |
|
|
|