import mindspore
import mindspore.ops.operations as P
import mindspore.numpy as mnp


class TriangularCausalMask:
    """Boolean strictly-upper-triangular mask of shape ``[B, 1, L, L]``.

    Entry ``[b, 0, i, j]`` is ``True`` exactly when ``j > i``; the diagonal
    and everything below it are ``False``.  Presumably callers treat ``True``
    as "do not attend" -- confirm against the attention code that consumes it.

    Args:
        B: Size of the leading (batch) axis.
        L: Size of each of the two trailing axes.
        device: Accepted for backward compatibility only and ignored.
            NOTE(review): MindSpore selects the execution device through its
            global context rather than per tensor -- confirm for the
            MindSpore version in use.
    """

    def __init__(self, B, L, device="cpu"):
        mask_shape = [B, 1, L, L]
        # k=1 keeps only the entries strictly above the main diagonal of the
        # trailing two axes.
        self._mask = mnp.triu(mnp.ones(mask_shape, dtype=mnp.bool_), 1)

    @property
    def mask(self):
        """Return the mask tensor built in ``__init__``."""
        return self._mask


class ProbMask:
    """Rows of a strictly-upper-triangular mask gathered at ``index``.

    A boolean ``(L, scores.shape[-1])`` mask with ``True`` strictly above the
    diagonal is built once, its rows are picked by ``index``, and the result
    is viewed as ``scores.shape``.

    Args:
        B: Batch size.  Kept for interface compatibility; the base mask is
            identical for every batch entry, so it is not needed to build it.
        H: Number of heads.  Kept for interface compatibility, same reason.
        L: Number of rows of the base mask.
        index: Integer tensor of row indices into the base mask.
            Assumed to have shape ``scores.shape[:-1]`` (e.g. ``(B, H, u)``)
            so the final ``view`` is valid -- TODO confirm against caller.
        scores: Tensor whose last axis gives the mask width and whose full
            shape is the shape of the resulting mask.
        device: Accepted for backward compatibility only and ignored (see
            the note in the body).
    """

    def __init__(self, B, H, L, index, scores, device="cpu"):
        # NOTE(review): `device` is not used; MindSpore places tensors via
        # its global context, not a per-tensor `.to(device)` call.
        # k=1 excludes the diagonal, matching TriangularCausalMask, so a
        # position is never masked from itself.
        causal = mnp.triu(
            mnp.ones((L, scores.shape[-1]), dtype=mnp.bool_), 1
        )
        # The base mask does not depend on batch or head, so gathering rows
        # directly is equivalent to expanding it to (B, H, L, S) and indexing
        # with (arange(B), arange(H), index); it also avoids indexing a
        # size-1 axis with values up to B-1 / H-1.
        indicator = causal[index]
        self._mask = indicator.view(scores.shape)

    @property
    def mask(self):
        """Return the mask tensor built in ``__init__``."""
        return self._mask