diff --git a/utils/masking.py b/utils/masking.py index 6d58992..e5d041f 100644 --- a/utils/masking.py +++ b/utils/masking.py @@ -1,7 +1,7 @@ import mindspore import mindspore.ops.operations as P import mindspore.numpy as mnp - +1 class TriangularCausalMask(): def __init__(self, B, L, device="cpu"): mask_shape = [B, 1, L, L]