@@ -36,6 +36,7 @@ __all__ = [ | |||||
"MaxPool", | "MaxPool", | ||||
"MaxPoolWithMask", | "MaxPoolWithMask", | ||||
"KMaxPool", | |||||
"AvgPool", | "AvgPool", | ||||
"AvgPoolWithMask", | "AvgPoolWithMask", | ||||
@@ -23,6 +23,7 @@ __all__ = [ | |||||
"MaxPool", | "MaxPool", | ||||
"MaxPoolWithMask", | "MaxPoolWithMask", | ||||
"KMaxPool", | |||||
"AvgPool", | "AvgPool", | ||||
"AvgPoolWithMask", | "AvgPoolWithMask", | ||||
@@ -36,7 +37,7 @@ from .bert import BertModel | |||||
from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder | from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder | ||||
from .conv_maxpool import ConvMaxpool | from .conv_maxpool import ConvMaxpool | ||||
from .lstm import LSTM | from .lstm import LSTM | ||||
from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask | |||||
from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, KMaxPool | |||||
from .star_transformer import StarTransformer | from .star_transformer import StarTransformer | ||||
from .transformer import TransformerEncoder | from .transformer import TransformerEncoder | ||||
from .variational_rnn import VarRNN, VarLSTM, VarGRU | from .variational_rnn import VarRNN, VarLSTM, VarGRU |
@@ -3,6 +3,7 @@ | |||||
__all__ = [ | __all__ = [ | ||||
"MaxPool", | "MaxPool", | ||||
"MaxPoolWithMask", | "MaxPoolWithMask", | ||||
"KMaxPool", | |||||
"AvgPool", | "AvgPool", | ||||
"AvgPoolWithMask" | "AvgPoolWithMask" | ||||
] | ] | ||||
@@ -27,7 +28,7 @@ class MaxPool(nn.Module): | |||||
:param ceil_mode: | :param ceil_mode: | ||||
""" | """ | ||||
super(MaxPool, self).__init__() | super(MaxPool, self).__init__() | ||||
assert (1 <= dimension) and (dimension <= 3) | |||||
assert dimension in [1, 2, 3], f'Now we only support 1d, 2d, or 3d Pooling' | |||||
self.dimension = dimension | self.dimension = dimension | ||||
self.stride = stride | self.stride = stride | ||||
self.padding = padding | self.padding = padding | ||||
@@ -37,12 +38,12 @@ class MaxPool(nn.Module): | |||||
def forward(self, x): | def forward(self, x): | ||||
if self.dimension == 1: | if self.dimension == 1: | ||||
x = torch.transpose(x, 1, 2) # [N,L,C] -> [N,C,L] | |||||
pooling = nn.MaxPool1d( | pooling = nn.MaxPool1d( | ||||
stride=self.stride, padding=self.padding, dilation=self.dilation, | stride=self.stride, padding=self.padding, dilation=self.dilation, | ||||
kernel_size=self.kernel_size if self.kernel_size is not None else x.size(-1), | kernel_size=self.kernel_size if self.kernel_size is not None else x.size(-1), | ||||
return_indices=False, ceil_mode=self.ceil_mode | return_indices=False, ceil_mode=self.ceil_mode | ||||
) | ) | ||||
x = torch.transpose(x, 1, 2) # [N,L,C] -> [N,C,L] | |||||
elif self.dimension == 2: | elif self.dimension == 2: | ||||
pooling = nn.MaxPool2d( | pooling = nn.MaxPool2d( | ||||
stride=self.stride, padding=self.padding, dilation=self.dilation, | stride=self.stride, padding=self.padding, dilation=self.dilation, | ||||
@@ -50,7 +51,7 @@ class MaxPool(nn.Module): | |||||
return_indices=False, ceil_mode=self.ceil_mode | return_indices=False, ceil_mode=self.ceil_mode | ||||
) | ) | ||||
else: | else: | ||||
pooling = nn.MaxPool2d( | |||||
pooling = nn.MaxPool3d( | |||||
stride=self.stride, padding=self.padding, dilation=self.dilation, | stride=self.stride, padding=self.padding, dilation=self.dilation, | ||||
kernel_size=self.kernel_size if self.kernel_size is not None else (x.size(-3), x.size(-2), x.size(-1)), | kernel_size=self.kernel_size if self.kernel_size is not None else (x.size(-3), x.size(-2), x.size(-1)), | ||||
return_indices=False, ceil_mode=self.ceil_mode | return_indices=False, ceil_mode=self.ceil_mode | ||||
@@ -0,0 +1,41 @@ | |||||
import unittest | |||||
import torch | |||||
from fastNLP.modules.encoder.pooling import MaxPool, MaxPoolWithMask, KMaxPool, AvgPool, AvgPoolWithMask | |||||
class TestPooling(unittest.TestCase): | |||||
def test_MaxPool(self): | |||||
max_pool_1d = MaxPool(dimension=1) | |||||
x = torch.randn(5, 6, 7) | |||||
self.assertEqual(max_pool_1d(x).size(), (5, 7)) | |||||
max_pool_2d = MaxPool(dimension=2) | |||||
self.assertEqual(max_pool_2d(x).size(), (5, 1)) | |||||
max_pool_3d = MaxPool(dimension=3) | |||||
x = torch.randn(4, 5, 6, 7) | |||||
self.assertEqual(max_pool_3d(x).size(), (4, 1, 1)) | |||||
def test_MaxPoolWithMask(self): | |||||
pool = MaxPoolWithMask() | |||||
x = torch.randn(5, 6, 7) | |||||
mask = (torch.randn(5, 6) > 0).long() | |||||
self.assertEqual(pool(x, mask).size(), (5, 7)) | |||||
def test_KMaxPool(self): | |||||
k_pool = KMaxPool(k=3) | |||||
x = torch.randn(4, 5, 6) | |||||
self.assertEqual(k_pool(x).size(), (4, 15)) | |||||
def test_AvgPool(self): | |||||
pool = AvgPool() | |||||
x = torch.randn(4, 5, 6) | |||||
self.assertEqual(pool(x).size(), (4, 5)) | |||||
def test_AvgPoolWithMask(self): | |||||
pool = AvgPoolWithMask() | |||||
x = torch.randn(5, 6, 7) | |||||
mask = (torch.randn(5, 6) > 0).long() | |||||
self.assertEqual(pool(x, mask).size(), (5, 7)) |