diff --git a/models/attn.py b/models/attn.py index 5057c36..9569477 100644 --- a/models/attn.py +++ b/models/attn.py @@ -4,6 +4,7 @@ import mindspore.ops as ops from mindspore import Tensor from mindspore.common import dtype as mstype + class FullAttention(nn.Module): def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): super(FullAttention, self).__init__()