@@ -269,7 +269,7 @@ class ClassPreprocess(BasePreprocess): | |||||
for word in sent: | for word in sent: | ||||
if word not in word2index: | if word not in word2index: | ||||
word2index[word[0]] = len(word2index) | |||||
word2index[word] = len(word2index) | |||||
return word2index, label2index | return word2index, label2index | ||||
def to_index(self, data): | def to_index(self, data): | ||||
@@ -5,7 +5,7 @@ import torch | |||||
import torch.nn as nn | import torch.nn as nn | ||||
# import torch.nn.functional as F | # import torch.nn.functional as F | ||||
from fastNLP.modules.encoder.conv_maxpool import ConvMaxpool | |||||
import fastNLP.modules.encoder as encoder | |||||
class CNNText(torch.nn.Module): | class CNNText(torch.nn.Module): | ||||
@@ -18,22 +18,22 @@ class CNNText(torch.nn.Module): | |||||
def __init__(self, args):
    """Build the embedding, convolution/max-pool, dropout and output layers.

    Args:
        args: mapping that is indexed with the keys ``"num_classes"``
            (output size of the final linear layer) and ``"vocab_size"``
            (number of rows in the embedding table).

    Raises:
        KeyError: if ``args`` is a dict lacking either key.
    """
    super(CNNText, self).__init__()

    num_classes = args["num_classes"]
    vocab_size = args["vocab_size"]

    # Fixed hyper-parameters: one bank of 100 filters per kernel width.
    kernel_nums = [100, 100, 100]
    kernel_sizes = [3, 4, 5]
    embed_dim = 300
    drop_prob = 0.5

    # no support for pre-trained embedding currently
    # NOTE(review): `encoder` is the package-level alias imported at the top
    # of the file (`import fastNLP.modules.encoder as encoder`); the wrapper
    # signatures are assumed to match the calls below -- confirm.
    self.embed = encoder.embedding.Embedding(vocab_size, embed_dim)
    self.conv_pool = encoder.conv_maxpool.ConvMaxpool(
        in_channels=embed_dim,
        out_channels=kernel_nums,
        kernel_sizes=kernel_sizes)
    self.dropout = nn.Dropout(drop_prob)
    # The classifier input width is the total number of filters across all
    # kernel sizes.
    self.fc = encoder.linear.Linear(sum(kernel_nums), num_classes)
def forward(self, x): | def forward(self, x): | ||||
x = self.embed(x) # [N,L] -> [N,L,C] | x = self.embed(x) # [N,L] -> [N,L,C] | ||||
@@ -2,8 +2,10 @@ from .embedding import Embedding | |||||
from .linear import Linear | from .linear import Linear | ||||
from .lstm import Lstm | from .lstm import Lstm | ||||
from .conv import Conv | from .conv import Conv | ||||
from .conv_maxpool import ConvMaxpool | |||||
__all__ = ["Lstm", | __all__ = ["Lstm", | ||||
"Embedding", | "Embedding", | ||||
"Linear", | "Linear", | ||||
"Conv"] | |||||
"Conv", | |||||
"ConvMaxpool"] |
@@ -4,6 +4,7 @@ | |||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
import torch.nn.functional as F | import torch.nn.functional as F | ||||
from torch.nn.init import xavier_uniform_ | |||||
class ConvMaxpool(nn.Module): | class ConvMaxpool(nn.Module): | ||||
@@ -21,6 +22,7 @@ class ConvMaxpool(nn.Module): | |||||
if isinstance(kernel_sizes, int): | if isinstance(kernel_sizes, int): | ||||
out_channels = [out_channels] | out_channels = [out_channels] | ||||
kernel_sizes = [kernel_sizes] | kernel_sizes = [kernel_sizes] | ||||
self.convs = nn.ModuleList([nn.Conv1d( | self.convs = nn.ModuleList([nn.Conv1d( | ||||
in_channels=in_channels, | in_channels=in_channels, | ||||
out_channels=oc, | out_channels=oc, | ||||
@@ -31,6 +33,9 @@ class ConvMaxpool(nn.Module): | |||||
groups=groups, | groups=groups, | ||||
bias=bias) | bias=bias) | ||||
for oc, ks in zip(out_channels, kernel_sizes)]) | for oc, ks in zip(out_channels, kernel_sizes)]) | ||||
for conv in self.convs: | |||||
xavier_uniform_(conv.weight) # weight initialization | |||||
else: | else: | ||||
raise Exception( | raise Exception( | ||||
'Incorrect kernel sizes: should be list, tuple or int') | 'Incorrect kernel sizes: should be list, tuple or int') | ||||