diff --git a/fastNLP/modules/__init__.py b/fastNLP/modules/__init__.py
index 769dc42a..d72d2022 100644
--- a/fastNLP/modules/__init__.py
+++ b/fastNLP/modules/__init__.py
@@ -36,6 +36,7 @@ __all__ = [
     "MaxPool",
     "MaxPoolWithMask",
+    "KMaxPool",
     "AvgPool",
     "AvgPoolWithMask",
diff --git a/fastNLP/modules/encoder/__init__.py b/fastNLP/modules/encoder/__init__.py
index 7fbc4b71..cbb42d7e 100644
--- a/fastNLP/modules/encoder/__init__.py
+++ b/fastNLP/modules/encoder/__init__.py
@@ -23,6 +23,7 @@ __all__ = [
     "MaxPool",
     "MaxPoolWithMask",
+    "KMaxPool",
     "AvgPool",
     "AvgPoolWithMask",
@@ -36,7 +37,7 @@ from .bert import BertModel
 from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
 from .conv_maxpool import ConvMaxpool
 from .lstm import LSTM
-from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask
+from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, KMaxPool
 from .star_transformer import StarTransformer
 from .transformer import TransformerEncoder
 from .variational_rnn import VarRNN, VarLSTM, VarGRU
diff --git a/fastNLP/modules/encoder/pooling.py b/fastNLP/modules/encoder/pooling.py
index 789b6d26..80ff419d 100644
--- a/fastNLP/modules/encoder/pooling.py
+++ b/fastNLP/modules/encoder/pooling.py
@@ -3,6 +3,7 @@ __all__ = [
     "MaxPool",
     "MaxPoolWithMask",
+    "KMaxPool",
     "AvgPool",
     "AvgPoolWithMask"
 ]
@@ -27,7 +28,7 @@ class MaxPool(nn.Module):
         :param ceil_mode:
         """
         super(MaxPool, self).__init__()
-        assert (1 <= dimension) and (dimension <= 3)
+        assert dimension in [1, 2, 3], f'Now we only support 1d, 2d, or 3d Pooling'
         self.dimension = dimension
         self.stride = stride
         self.padding = padding
@@ -37,12 +38,12 @@ class MaxPool(nn.Module):
 
     def forward(self, x):
         if self.dimension == 1:
+            x = torch.transpose(x, 1, 2)  # [N,L,C] -> [N,C,L]
             pooling = nn.MaxPool1d(
                 stride=self.stride, padding=self.padding, dilation=self.dilation,
                 kernel_size=self.kernel_size if self.kernel_size is not None else x.size(-1),
                 return_indices=False, ceil_mode=self.ceil_mode
             )
-            x = torch.transpose(x, 1, 2)  # [N,L,C] -> [N,C,L]
         elif self.dimension == 2:
             pooling = nn.MaxPool2d(
                 stride=self.stride, padding=self.padding, dilation=self.dilation,
@@ -50,7 +51,7 @@ class MaxPool(nn.Module):
                 return_indices=False, ceil_mode=self.ceil_mode
             )
         else:
-            pooling = nn.MaxPool2d(
+            pooling = nn.MaxPool3d(
                 stride=self.stride, padding=self.padding, dilation=self.dilation,
                 kernel_size=self.kernel_size if self.kernel_size is not None else (x.size(-3), x.size(-2), x.size(-1)),
                 return_indices=False, ceil_mode=self.ceil_mode
diff --git a/test/modules/encoder/test_pooling.py b/test/modules/encoder/test_pooling.py
new file mode 100644
index 00000000..5adca4ff
--- /dev/null
+++ b/test/modules/encoder/test_pooling.py
@@ -0,0 +1,41 @@
+import unittest
+
+import torch
+
+from fastNLP.modules.encoder.pooling import MaxPool, MaxPoolWithMask, KMaxPool, AvgPool, AvgPoolWithMask
+
+
+class TestPooling(unittest.TestCase):
+    def test_MaxPool(self):
+        max_pool_1d = MaxPool(dimension=1)
+        x = torch.randn(5, 6, 7)
+        self.assertEqual(max_pool_1d(x).size(), (5, 7))
+
+        max_pool_2d = MaxPool(dimension=2)
+        self.assertEqual(max_pool_2d(x).size(), (5, 1))
+
+        max_pool_3d = MaxPool(dimension=3)
+        x = torch.randn(4, 5, 6, 7)
+        self.assertEqual(max_pool_3d(x).size(), (4, 1, 1))
+
+    def test_MaxPoolWithMask(self):
+        pool = MaxPoolWithMask()
+        x = torch.randn(5, 6, 7)
+        mask = (torch.randn(5, 6) > 0).long()
+        self.assertEqual(pool(x, mask).size(), (5, 7))
+
+    def test_KMaxPool(self):
+        k_pool = KMaxPool(k=3)
+        x = torch.randn(4, 5, 6)
+        self.assertEqual(k_pool(x).size(), (4, 15))
+
+    def test_AvgPool(self):
+        pool = AvgPool()
+        x = torch.randn(4, 5, 6)
+        self.assertEqual(pool(x).size(), (4, 5))
+
+    def test_AvgPoolWithMask(self):
+        pool = AvgPoolWithMask()
+        x = torch.randn(5, 6, 7)
+        mask = (torch.randn(5, 6) > 0).long()
+        self.assertEqual(pool(x, mask).size(), (5, 7))
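
Note on KMaxPool: the hunks above export and test KMaxPool, but the hunk that adds the class itself to fastNLP/modules/encoder/pooling.py is not included in this diff. As a rough sketch only, an implementation consistent with the test expectation (an input of shape [N, C, L] reduced to its k largest values along the last dimension and flattened to [N, C*k], so KMaxPool(k=3) maps a (4, 5, 6) tensor to (4, 15)) could look like the following; the class actually added by the commit may differ in details such as docstrings or argument names.

    import torch
    import torch.nn as nn

    class KMaxPool(nn.Module):
        """k-max pooling: keep the k largest values along the last dimension."""

        def __init__(self, k=1):
            super(KMaxPool, self).__init__()
            self.k = k

        def forward(self, x):
            # x: [N, C, L] -> take the k largest values over L -> [N, C, k]
            x, _ = torch.topk(x, self.k, dim=-1, sorted=False)
            # flatten the channel and k dimensions -> [N, C*k]
            return x.reshape(x.size(0), -1)

With sorted=False the k values are returned in arbitrary order, which is sufficient for the shape check in test_KMaxPool above.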