@@ -2,8 +2,8 @@ import torch
 import torch.nn as nn
 
 from fastNLP.core.const import Const as C
 from fastNLP.modules.encoder.lstm import LSTM
-from fastNLP.modules import encoder
-from fastNLP.modules.aggregator.attention import SelfAttention
+from fastNLP.embeddings.utils import get_embeddings
+from fastNLP.modules.encoder.attention import SelfAttention
 from fastNLP.modules.decoder.mlp import MLP
 
|
@@ -16,7 +16,7 @@ class BiLSTM_SELF_ATTENTION(nn.Module):
                  attention_hops=1,
                  nfc=128):
         super(BiLSTM_SELF_ATTENTION,self).__init__()
-        self.embed = encoder.Embedding(init_embed)
+        self.embed = get_embeddings(init_embed)
         self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, bidirectional=True)
         self.attention = SelfAttention(input_size=hidden_dim * 2 , attention_unit=attention_unit, attention_hops=attention_hops)
         self.mlp = MLP(size_layer=[hidden_dim* 2*attention_hops, nfc, num_classes])
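
Note on the change above: the diff swaps the embedding construction from `encoder.Embedding(init_embed)` to `get_embeddings(init_embed)` and moves the `SelfAttention` import from `fastNLP.modules.aggregator.attention` to `fastNLP.modules.encoder.attention`, leaving the rest of the model (BiLSTM encoder, multi-hop self-attention, MLP classifier) unchanged. Below is a minimal usage sketch, not part of the PR: it assumes a fastNLP version that exposes the new import paths, that `BiLSTM_SELF_ATTENTION` is imported from the module this diff edits (its path is not shown here), and that constructor arguments other than `init_embed` and `num_classes` keep the defaults suggested by the context lines; the vocabulary size, embedding dimension, class count and batch shape are placeholders.

```python
import torch

# Assumption: BiLSTM_SELF_ATTENTION is imported from the module edited in this
# diff (module path not shown above).

# Placeholder hyper-parameters for illustration only (not taken from the PR).
vocab_size, embed_dim, num_classes = 10000, 128, 5

# get_embeddings accepts a (num_embeddings, embedding_dim) tuple and builds a
# randomly initialised embedding layer, so init_embed no longer has to be an
# encoder.Embedding instance.
model = BiLSTM_SELF_ATTENTION(init_embed=(vocab_size, embed_dim),
                              num_classes=num_classes)

# A batch of 4 padded sequences of word ids, 30 tokens each.
words = torch.randint(low=1, high=vocab_size, size=(4, 30))

# Forward pass; fastNLP models conventionally return a dict keyed by
# Const.OUTPUT holding the per-class scores for each sequence.
output = model(words)
```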