|
|
def seq_len_to_byte_mask(seq_lens):
    """Build a padding mask from per-example sequence lengths.

    Args:
        seq_lens: 1-D tensor of sequence lengths, shape (batch_size,).
            May live on CPU or GPU.

    Returns:
        Tensor of shape (batch_size, max_len) on ``seq_lens``' device where
        ``mask[i, j]`` is truthy iff ``j < seq_lens[i]``.
        NOTE(review): the original comment says ByteTensor; on modern PyTorch
        comparison ops return BoolTensor — confirm which downstream expects.
    """
    batch_size = seq_lens.size(0)
    max_len = seq_lens.max()
    # Row of positions 0..max_len-1, broadcast to every batch row.  Moved to
    # seq_lens' device so the comparison below never crosses devices (the
    # pre-fix version built this on CPU and crashed with GPU inputs).
    broadcast_arange = (
        torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device)
    )
    mask = broadcast_arange.lt(seq_lens.float().view(-1, 1))
    return mask
|
|
|
|
|
|
|