|
|
@@ -74,7 +74,7 @@ def seq_len_to_mask(seq_len, max_len: Optional[int]=None): |
|
|
|
if isinstance(seq_len, jittor.Var): |
|
|
|
assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim == 1}." |
|
|
|
batch_size = seq_len.shape[0] |
|
|
|
broad_cast_seq_len = jittor.arange(max_len).expand(batch_size, -1) |
|
|
|
broad_cast_seq_len = jittor.arange(max_len).reshape(1, max_len).expand(batch_size, -1) |
|
|
|
mask = broad_cast_seq_len < seq_len.unsqueeze(1) |
|
|
|
return mask |
|
|
|
except NameError as e: |
|
|
|