
Modify the code to adapt to the new embeddings module

tags/v0.4.10
lyhuang18 6 years ago
commit ffbba0f895
4 changed files with 8 additions and 8 deletions
  1. reproduction/text_classification/model/awd_lstm.py (+2, -2)
  2. reproduction/text_classification/model/lstm.py (+2, -2)
  3. reproduction/text_classification/model/lstm_self_attention.py (+3, -3)
  4. reproduction/text_classification/train_awdlstm.py (+1, -1)

reproduction/text_classification/model/awd_lstm.py (+2, -2)

@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 from fastNLP.core.const import Const as C
 from .awdlstm_module import LSTM
-from fastNLP.modules import encoder
+from fastNLP.embeddings.utils import get_embeddings
 from fastNLP.modules.decoder.mlp import MLP


@@ -14,7 +14,7 @@ class AWDLSTMSentiment(nn.Module):
                  nfc=128,
                  wdrop=0.5):
         super(AWDLSTMSentiment,self).__init__()
-        self.embed = encoder.Embedding(init_embed)
+        self.embed = get_embeddings(init_embed)
         self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, bidirectional=True, wdrop=wdrop)
         self.mlp = MLP(size_layer=[hidden_dim* 2, nfc, num_classes])
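
The substantive change, repeated in all three model files, is replacing the old encoder.Embedding(init_embed) wrapper with get_embeddings(init_embed) from the new fastNLP.embeddings package. Below is a minimal sketch of how such a factory typically dispatches on the type of init_embed; it is an illustrative stand-in, not fastNLP's actual implementation, and the vocabulary size and dimension are invented.

import numpy as np
import torch
import torch.nn as nn

def get_embeddings_sketch(init_embed):
    """Illustrative stand-in for fastNLP.embeddings.utils.get_embeddings."""
    if isinstance(init_embed, tuple):
        # (num_embeddings, embedding_dim): build a fresh, randomly initialized table.
        return nn.Embedding(num_embeddings=init_embed[0], embedding_dim=init_embed[1])
    if isinstance(init_embed, nn.Module):
        # Already an embedding module (e.g. StaticEmbedding): pass through unchanged.
        return init_embed
    if isinstance(init_embed, np.ndarray):
        init_embed = torch.tensor(init_embed, dtype=torch.float32)
    if isinstance(init_embed, torch.Tensor):
        # Pretrained weights: wrap them in a trainable embedding layer.
        return nn.Embedding.from_pretrained(init_embed, freeze=False)
    raise TypeError(f"unsupported init_embed type: {type(init_embed)}")

# Both call styles the patched models rely on yield a module with .embedding_dim:
embed_a = get_embeddings_sketch((1000, 50))                 # vocab 1000, dim 50 (made up)
embed_b = get_embeddings_sketch(np.random.rand(1000, 50))   # pretrained-style weights
print(embed_a.embedding_dim, embed_b.embedding_dim)         # 50 50

Either way the models can read self.embed.embedding_dim to size the LSTM input, which is why the one-line swap suffices.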



reproduction/text_classification/model/lstm.py (+2, -2)

@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 from fastNLP.core.const import Const as C
 from fastNLP.modules.encoder.lstm import LSTM
-from fastNLP.modules import encoder
+from fastNLP.embeddings.utils import get_embeddings
 from fastNLP.modules.decoder.mlp import MLP


@@ -13,7 +13,7 @@ class BiLSTMSentiment(nn.Module):
                  num_layers=1,
                  nfc=128):
         super(BiLSTMSentiment,self).__init__()
-        self.embed = encoder.Embedding(init_embed)
+        self.embed = get_embeddings(init_embed)
         self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, bidirectional=True)
         self.mlp = MLP(size_layer=[hidden_dim* 2, nfc, num_classes])



reproduction/text_classification/model/lstm_self_attention.py (+3, -3)

@@ -2,8 +2,8 @@ import torch
 import torch.nn as nn
 from fastNLP.core.const import Const as C
 from fastNLP.modules.encoder.lstm import LSTM
-from fastNLP.modules import encoder
-from fastNLP.modules.aggregator.attention import SelfAttention
+from fastNLP.embeddings.utils import get_embeddings
+from fastNLP.modules.encoder.attention import SelfAttention
 from fastNLP.modules.decoder.mlp import MLP


@@ -16,7 +16,7 @@ class BiLSTM_SELF_ATTENTION(nn.Module):
                  attention_hops=1,
                  nfc=128):
         super(BiLSTM_SELF_ATTENTION,self).__init__()
-        self.embed = encoder.Embedding(init_embed)
+        self.embed = get_embeddings(init_embed)
         self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, bidirectional=True)
         self.attention = SelfAttention(input_size=hidden_dim * 2 , attention_unit=attention_unit, attention_hops=attention_hops)
         self.mlp = MLP(size_layer=[hidden_dim* 2*attention_hops, nfc, num_classes])
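
In this file the import path for SelfAttention also changes (it moved from fastNLP.modules.aggregator.attention to fastNLP.modules.encoder.attention); the constructor call itself is untouched. For context on the attention_unit/attention_hops arguments, below is a compact multi-hop self-attention sketch in the spirit of Lin et al. (2017), which fastNLP's SelfAttention appears to follow; it is not the library's code, and the tensor shapes are invented.

import torch
import torch.nn as nn

class SelfAttentionSketch(nn.Module):
    """Multi-hop self-attention over LSTM states, after 'A Structured
    Self-attentive Sentence Embedding' (Lin et al., 2017)."""
    def __init__(self, input_size, attention_unit=300, attention_hops=10):
        super().__init__()
        self.ws1 = nn.Linear(input_size, attention_unit, bias=False)
        self.ws2 = nn.Linear(attention_unit, attention_hops, bias=False)

    def forward(self, h):                                   # h: [batch, seq_len, input_size]
        # One attention distribution over the sequence per hop.
        a = torch.softmax(self.ws2(torch.tanh(self.ws1(h))), dim=1)  # [batch, seq_len, hops]
        m = torch.einsum('blh,bld->bhd', a, h)              # one weighted sum per hop
        return m.reshape(m.size(0), -1)                     # flatten: [batch, hops * input_size]

x = torch.randn(4, 12, 256)                                 # invented: batch 4, seq 12, dim 256
att = SelfAttentionSketch(input_size=256, attention_unit=64, attention_hops=2)
print(att(x).shape)                                         # torch.Size([4, 512])

The flattened [batch, hops * input_size] output is what makes hidden_dim * 2 * attention_hops the right input width for the MLP above.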


reproduction/text_classification/train_awdlstm.py (+1, -1)

@@ -5,7 +5,7 @@ import os
 os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
 os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
 
-from data.IMDBLoader import IMDBLoader
+from fastNLP.io.data_loader import IMDBLoader
 from fastNLP.embeddings import StaticEmbedding
 from model.awd_lstm import AWDLSTMSentiment
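
The training script's only change is importing IMDBLoader from fastNLP's bundled fastNLP.io.data_loader package rather than the local data module. A rough sketch of how these imports fit together follows; the file paths, vocabulary key, and embedding name are assumptions, and the loader/data-bundle API differs between fastNLP versions.

from fastNLP.io.data_loader import IMDBLoader
from fastNLP.embeddings import StaticEmbedding
from model.awd_lstm import AWDLSTMSentiment

# Build the train/test DataSets plus vocabularies (paths are illustrative).
data_info = IMDBLoader().process({'train': 'imdb/train.txt',
                                  'test': 'imdb/test.txt'})
vocab = data_info.vocabs['words']   # the 'words' key is an assumption

# StaticEmbedding is itself an nn.Module, so get_embeddings() in the models
# passes it through unchanged; the name string may vary by fastNLP version.
embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-840b-300d')
model = AWDLSTMSentiment(init_embed=embed, num_classes=2)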


