|
|
@@ -1,7 +1,7 @@ |
|
|
import mindspore |
|
|
import mindspore |
|
|
import mindspore.ops.operations as P |
|
|
import mindspore.ops.operations as P |
|
|
import mindspore.numpy as mnp |
|
|
import mindspore.numpy as mnp |
|
|
|
|
|
|
|
|
|
|
|
1 |
|
|
class TriangularCausalMask(): |
|
|
class TriangularCausalMask(): |
|
|
def __init__(self, B, L, device="cpu"): |
|
|
def __init__(self, B, L, device="cpu"): |
|
|
mask_shape = [B, 1, L, L] |
|
|
mask_shape = [B, 1, L, L] |
|
|
|