@@ -269,7 +269,7 @@ class ClassPreprocess(BasePreprocess):
         for word in sent:
             if word not in word2index:
-                word2index[word[0]] = len(word2index)
+                word2index[word] = len(word2index)
         return word2index, label2index

     def to_index(self, data):
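For clarity on the fix above: `word[0]` indexes the first character of the word string, so the dictionary was being keyed by characters while the `not in` check still looked up whole words. A standalone sketch (not fastNLP code) of the two behaviors:

    word2index = {}
    for word in ["cat", "car", "dog"]:
        if word not in word2index:                  # membership test uses whole words...
            word2index[word[0]] = len(word2index)   # ...but the old line stored first characters
    # buggy result: {'c': 1, 'd': 2} -- "cat" and "car" collide and index 0 is lost
    # with the fixed line: {'cat': 0, 'car': 1, 'dog': 2}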
@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 # import torch.nn.functional as F

-from fastNLP.modules.encoder.conv_maxpool import ConvMaxpool
+import fastNLP.modules.encoder as encoder


 class CNNText(torch.nn.Module):
@@ -18,22 +18,22 @@ class CNNText(torch.nn.Module):
     def __init__(self, args):
         super(CNNText, self).__init__()

-        class_num = args["num_classes"]
+        num_classes = args["num_classes"]
         kernel_nums = [100, 100, 100]
         kernel_sizes = [3, 4, 5]
-        embed_num = args["vocab_size"]
+        vocab_size = args["vocab_size"]
         embed_dim = 300
         pretrained_embed = None
         drop_prob = 0.5

         # no support for pre-trained embedding currently
-        self.embed = nn.Embedding(embed_num, embed_dim, padding_idx=0)
-        self.conv_pool = ConvMaxpool(
+        self.embed = encoder.embedding.Embedding(vocab_size, embed_dim)
+        self.conv_pool = encoder.conv_maxpool.ConvMaxpool(
             in_channels=embed_dim,
             out_channels=kernel_nums,
             kernel_sizes=kernel_sizes)
         self.dropout = nn.Dropout(drop_prob)
-        self.fc = nn.Linear(sum(kernel_nums), class_num)
+        self.fc = encoder.linear.Linear(sum(kernel_nums), num_classes)

     def forward(self, x):
         x = self.embed(x)  # [N,L] -> [N,L,C]
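As a sanity check on the refactored constructor, a hypothetical forward pass, assuming the `encoder` wrappers keep torch-like shapes (the import path below is an assumption, not confirmed by this diff):

    import torch
    from fastNLP.models.cnn_text_classification import CNNText  # path assumed

    model = CNNText({"num_classes": 5, "vocab_size": 1000})  # the two keys read in __init__
    x = torch.randint(0, 1000, (8, 20))  # [N, L] batch of token ids
    out = model(x)                       # expected shape: [8, 5]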
@@ -2,8 +2,10 @@ from .embedding import Embedding
 from .linear import Linear
 from .lstm import Lstm
 from .conv import Conv
+from .conv_maxpool import ConvMaxpool

 __all__ = ["Lstm",
            "Embedding",
            "Linear",
-           "Conv"]
+           "Conv",
+           "ConvMaxpool"]
@@ -4,6 +4,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn.init import xavier_uniform_


 class ConvMaxpool(nn.Module):
@@ -21,6 +22,7 @@ class ConvMaxpool(nn.Module):
             if isinstance(kernel_sizes, int):
                 out_channels = [out_channels]
                 kernel_sizes = [kernel_sizes]
+
             self.convs = nn.ModuleList([nn.Conv1d(
                 in_channels=in_channels,
                 out_channels=oc,
@@ -31,6 +33,9 @@ class ConvMaxpool(nn.Module):
                 groups=groups,
                 bias=bias)
                 for oc, ks in zip(out_channels, kernel_sizes)])
+
+            for conv in self.convs:
+                xavier_uniform_(conv.weight)  # weight initialization
         else:
             raise Exception(
                 'Incorrect kernel sizes: should be list, tuple or int')
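The added loop applies Xavier/Glorot uniform initialization to each convolution's weight tensor. A self-contained sketch of the same pattern in plain torch (the channel and kernel values mirror the defaults used in CNNText above):

    import torch.nn as nn
    from torch.nn.init import xavier_uniform_

    convs = nn.ModuleList([
        nn.Conv1d(in_channels=300, out_channels=oc, kernel_size=ks)
        for oc, ks in zip([100, 100, 100], [3, 4, 5])])
    for conv in convs:
        xavier_uniform_(conv.weight)  # gain defaults to 1.0; biases keep torch's default init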