|
|
import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DecoderLayer(nn.Module): |
|
|
class DecoderLayer(nn.Module): |
|
|
def __init__(self, self_attention, cross_attention, d_model, d_ff=None, |
|
|
def __init__(self, self_attention, cross_attention, d_model, d_ff=None, |
|
|
dropout=0.1, activation="relu"): |
|
|
dropout=0.1, activation="relu"): |
|
|
|