From ffbba0f895c413bf8c55dad196d2cd4d586f10ee Mon Sep 17 00:00:00 2001 From: lyhuang18 <42239874+lyhuang18@users.noreply.github.com> Date: Fri, 12 Jul 2019 06:47:04 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BB=A3=E7=A0=81=E4=BB=A5?= =?UTF-8?q?=E9=80=82=E9=85=8D=E6=96=B0embeddings=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- reproduction/text_classification/model/awd_lstm.py | 4 ++-- reproduction/text_classification/model/lstm.py | 4 ++-- .../text_classification/model/lstm_self_attention.py | 6 +++--- reproduction/text_classification/train_awdlstm.py | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/reproduction/text_classification/model/awd_lstm.py b/reproduction/text_classification/model/awd_lstm.py index 0d8f711a..c9c8a153 100644 --- a/reproduction/text_classification/model/awd_lstm.py +++ b/reproduction/text_classification/model/awd_lstm.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn from fastNLP.core.const import Const as C from .awdlstm_module import LSTM -from fastNLP.modules import encoder +from fastNLP.embeddings.utils import get_embeddings from fastNLP.modules.decoder.mlp import MLP @@ -14,7 +14,7 @@ class AWDLSTMSentiment(nn.Module): nfc=128, wdrop=0.5): super(AWDLSTMSentiment,self).__init__() - self.embed = encoder.Embedding(init_embed) + self.embed = get_embeddings(init_embed) self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, bidirectional=True, wdrop=wdrop) self.mlp = MLP(size_layer=[hidden_dim* 2, nfc, num_classes]) diff --git a/reproduction/text_classification/model/lstm.py b/reproduction/text_classification/model/lstm.py index 388f3f1c..fd1089dd 100644 --- a/reproduction/text_classification/model/lstm.py +++ b/reproduction/text_classification/model/lstm.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn from fastNLP.core.const import Const as C from fastNLP.modules.encoder.lstm import LSTM -from fastNLP.modules import encoder +from fastNLP.embeddings.utils import get_embeddings from fastNLP.modules.decoder.mlp import MLP @@ -13,7 +13,7 @@ class BiLSTMSentiment(nn.Module): num_layers=1, nfc=128): super(BiLSTMSentiment,self).__init__() - self.embed = encoder.Embedding(init_embed) + self.embed = get_embeddings(init_embed) self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, bidirectional=True) self.mlp = MLP(size_layer=[hidden_dim* 2, nfc, num_classes]) diff --git a/reproduction/text_classification/model/lstm_self_attention.py b/reproduction/text_classification/model/lstm_self_attention.py index 239635fe..9a39049d 100644 --- a/reproduction/text_classification/model/lstm_self_attention.py +++ b/reproduction/text_classification/model/lstm_self_attention.py @@ -2,8 +2,8 @@ import torch import torch.nn as nn from fastNLP.core.const import Const as C from fastNLP.modules.encoder.lstm import LSTM -from fastNLP.modules import encoder -from fastNLP.modules.aggregator.attention import SelfAttention +from fastNLP.embeddings.utils import get_embeddings +from fastNLP.modules.encoder.attention import SelfAttention from fastNLP.modules.decoder.mlp import MLP @@ -16,7 +16,7 @@ class BiLSTM_SELF_ATTENTION(nn.Module): attention_hops=1, nfc=128): super(BiLSTM_SELF_ATTENTION,self).__init__() - self.embed = encoder.Embedding(init_embed) + self.embed = get_embeddings(init_embed) self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, bidirectional=True) self.attention = SelfAttention(input_size=hidden_dim * 2 , attention_unit=attention_unit, attention_hops=attention_hops) self.mlp = MLP(size_layer=[hidden_dim* 2*attention_hops, nfc, num_classes]) diff --git a/reproduction/text_classification/train_awdlstm.py b/reproduction/text_classification/train_awdlstm.py index b0f2af49..b2a67fdb 100644 --- a/reproduction/text_classification/train_awdlstm.py +++ b/reproduction/text_classification/train_awdlstm.py @@ -5,7 +5,7 @@ import os os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' -from data.IMDBLoader import IMDBLoader +from fastNLP.io.data_loader import IMDBLoader from fastNLP.embeddings import StaticEmbedding from model.awd_lstm import AWDLSTMSentiment