From d910ae3c7725737f49e2de5b251ae76af80ffc8c Mon Sep 17 00:00:00 2001 From: Ke Zhen Date: Sun, 2 Sep 2018 16:39:36 +0800 Subject: [PATCH] Rewrite classification model, add intialization for conv_maxpool --- fastNLP/core/preprocess.py | 2 +- fastNLP/models/cnn_text_classification.py | 12 ++++++------ fastNLP/modules/encoder/__init__.py | 4 +++- fastNLP/modules/encoder/conv_maxpool.py | 5 +++++ 4 files changed, 15 insertions(+), 8 deletions(-) diff --git a/fastNLP/core/preprocess.py b/fastNLP/core/preprocess.py index 77df6b51..0cfe526e 100644 --- a/fastNLP/core/preprocess.py +++ b/fastNLP/core/preprocess.py @@ -269,7 +269,7 @@ class ClassPreprocess(BasePreprocess): for word in sent: if word not in word2index: - word2index[word[0]] = len(word2index) + word2index[word] = len(word2index) return word2index, label2index def to_index(self, data): diff --git a/fastNLP/models/cnn_text_classification.py b/fastNLP/models/cnn_text_classification.py index b6dcafb3..fc7388a5 100644 --- a/fastNLP/models/cnn_text_classification.py +++ b/fastNLP/models/cnn_text_classification.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn # import torch.nn.functional as F -from fastNLP.modules.encoder.conv_maxpool import ConvMaxpool +import fastNLP.modules.encoder as encoder class CNNText(torch.nn.Module): @@ -18,22 +18,22 @@ class CNNText(torch.nn.Module): def __init__(self, args): super(CNNText, self).__init__() - class_num = args["num_classes"] + num_classes = args["num_classes"] kernel_nums = [100, 100, 100] kernel_sizes = [3, 4, 5] - embed_num = args["vocab_size"] + vocab_size = args["vocab_size"] embed_dim = 300 pretrained_embed = None drop_prob = 0.5 # no support for pre-trained embedding currently - self.embed = nn.Embedding(embed_num, embed_dim, padding_idx=0) - self.conv_pool = ConvMaxpool( + self.embed = encoder.embedding.Embedding(vocab_size, embed_dim) + self.conv_pool = encoder.conv_maxpool.ConvMaxpool( in_channels=embed_dim, out_channels=kernel_nums, kernel_sizes=kernel_sizes) self.dropout = nn.Dropout(drop_prob) - self.fc = nn.Linear(sum(kernel_nums), class_num) + self.fc = encoder.linear.Linear(sum(kernel_nums), num_classes) def forward(self, x): x = self.embed(x) # [N,L] -> [N,L,C] diff --git a/fastNLP/modules/encoder/__init__.py b/fastNLP/modules/encoder/__init__.py index b4e689a7..71b786b9 100644 --- a/fastNLP/modules/encoder/__init__.py +++ b/fastNLP/modules/encoder/__init__.py @@ -2,8 +2,10 @@ from .embedding import Embedding from .linear import Linear from .lstm import Lstm from .conv import Conv +from .conv_maxpool import ConvMaxpool __all__ = ["Lstm", "Embedding", "Linear", - "Conv"] + "Conv", + "ConvMaxpool"] diff --git a/fastNLP/modules/encoder/conv_maxpool.py b/fastNLP/modules/encoder/conv_maxpool.py index 0012dce7..f666e7f9 100644 --- a/fastNLP/modules/encoder/conv_maxpool.py +++ b/fastNLP/modules/encoder/conv_maxpool.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch.nn.init import xavier_uniform_ class ConvMaxpool(nn.Module): @@ -21,6 +22,7 @@ class ConvMaxpool(nn.Module): if isinstance(kernel_sizes, int): out_channels = [out_channels] kernel_sizes = [kernel_sizes] + self.convs = nn.ModuleList([nn.Conv1d( in_channels=in_channels, out_channels=oc, @@ -31,6 +33,9 @@ class ConvMaxpool(nn.Module): groups=groups, bias=bias) for oc, ks in zip(out_channels, kernel_sizes)]) + + for conv in self.convs: + xavier_uniform_(conv.weight) # weight initialization else: raise Exception( 'Incorrect kernel sizes: should be list, tuple or int')