Modify the code to adapt to the new fastNLP.embeddings module (tags/v0.4.10)

@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 from fastNLP.core.const import Const as C
 from .awdlstm_module import LSTM
-from fastNLP.modules import encoder
+from fastNLP.embeddings.utils import get_embeddings
 from fastNLP.modules.decoder.mlp import MLP
@@ -14,7 +14,7 @@ class AWDLSTMSentiment(nn.Module):
                  nfc=128,
                  wdrop=0.5):
         super(AWDLSTMSentiment,self).__init__()
-        self.embed = encoder.Embedding(init_embed)
+        self.embed = get_embeddings(init_embed)
         self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, bidirectional=True, wdrop=wdrop)
         self.mlp = MLP(size_layer=[hidden_dim* 2, nfc, num_classes])
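
Across all three model files, encoder.Embedding(init_embed) is replaced with get_embeddings(init_embed) from the new fastNLP.embeddings module. A minimal usage sketch of the new call, assuming get_embeddings builds a fresh nn.Embedding from a (num_embeddings, embedding_dim) tuple and passes an existing embedding module through unchanged:

    # Illustrative only; the tuple-handling behaviour is an assumption, not
    # something the diff itself confirms.
    from fastNLP.embeddings.utils import get_embeddings

    embed = get_embeddings((3000, 100))  # vocab of 3000 words, 100-dim vectors
    print(embed.embedding_dim)           # 100 -- the attribute the LSTM constructors read
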
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 from fastNLP.core.const import Const as C
 from fastNLP.modules.encoder.lstm import LSTM
-from fastNLP.modules import encoder
+from fastNLP.embeddings.utils import get_embeddings
 from fastNLP.modules.decoder.mlp import MLP
@@ -13,7 +13,7 @@ class BiLSTMSentiment(nn.Module):
                  num_layers=1,
                  nfc=128):
         super(BiLSTMSentiment,self).__init__()
-        self.embed = encoder.Embedding(init_embed)
+        self.embed = get_embeddings(init_embed)
         self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, bidirectional=True)
         self.mlp = MLP(size_layer=[hidden_dim* 2, nfc, num_classes])

@@ -2,8 +2,8 @@ import torch
 import torch.nn as nn
 from fastNLP.core.const import Const as C
 from fastNLP.modules.encoder.lstm import LSTM
-from fastNLP.modules import encoder
-from fastNLP.modules.aggregator.attention import SelfAttention
+from fastNLP.embeddings.utils import get_embeddings
+from fastNLP.modules.encoder.attention import SelfAttention
 from fastNLP.modules.decoder.mlp import MLP
@@ -16,7 +16,7 @@ class BiLSTM_SELF_ATTENTION(nn.Module):
                  attention_hops=1,
                  nfc=128):
         super(BiLSTM_SELF_ATTENTION,self).__init__()
-        self.embed = encoder.Embedding(init_embed)
+        self.embed = get_embeddings(init_embed)
         self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, bidirectional=True)
         self.attention = SelfAttention(input_size=hidden_dim * 2, attention_unit=attention_unit, attention_hops=attention_hops)
         self.mlp = MLP(size_layer=[hidden_dim* 2*attention_hops, nfc, num_classes])
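
SelfAttention now lives under fastNLP.modules.encoder.attention rather than fastNLP.modules.aggregator.attention; only the import path changes, and the constructor call in the model is untouched. A construction-only sketch consistent with the hunk above, with hidden_dim assumed to be 256 for illustration (a bidirectional LSTM then gives input_size=512):

    # Illustrative construction of the relocated SelfAttention module;
    # the argument values here are assumptions, not defaults read from the diff.
    from fastNLP.modules.encoder.attention import SelfAttention

    attn = SelfAttention(input_size=512, attention_unit=256, attention_hops=1)
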
@@ -5,7 +5,7 @@ import os
 os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
 os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
-from data.IMDBLoader import IMDBLoader
+from fastNLP.io.data_loader import IMDBLoader
 from fastNLP.embeddings import StaticEmbedding
 from model.awd_lstm import AWDLSTMSentiment
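
The training script switches from the repository-local data.IMDBLoader to the loader shipped with the library. A hedged sketch of the new import in use (the dataset path is a placeholder, and the load() call follows the DataSetLoader convention assumed here, not something the diff confirms):

    # Placeholder path; substitute the real IMDB training file.
    from fastNLP.io.data_loader import IMDBLoader

    loader = IMDBLoader()
    train_data = loader.load('path/to/imdb/train.txt')  # assumed to return a fastNLP DataSet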