|
|
@@ -2,7 +2,6 @@ import mindspore |
|
|
import mindspore.ops.operations as P |
|
|
import mindspore.ops.operations as P |
|
|
import mindspore.numpy as mnp |
|
|
import mindspore.numpy as mnp |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TriangularCausalMask(): |
|
|
class TriangularCausalMask(): |
|
|
def __init__(self, B, L, device="cpu"): |
|
|
def __init__(self, B, L, device="cpu"): |
|
|
mask_shape = [B, 1, L, L] |
|
|
mask_shape = [B, 1, L, L] |
|
|
|