Browse Source

add test code for testing pooling

tags/v0.4.10
Yige Xu 5 years ago
parent
commit
4e2ca6c95a
4 changed files with 48 additions and 4 deletions
  1. +1
    -0
      fastNLP/modules/__init__.py
  2. +2
    -1
      fastNLP/modules/encoder/__init__.py
  3. +4
    -3
      fastNLP/modules/encoder/pooling.py
  4. +41
    -0
      test/modules/encoder/test_pooling.py

+ 1
- 0
fastNLP/modules/__init__.py View File

@@ -36,6 +36,7 @@ __all__ = [


"MaxPool", "MaxPool",
"MaxPoolWithMask", "MaxPoolWithMask",
"KMaxPool",
"AvgPool", "AvgPool",
"AvgPoolWithMask", "AvgPoolWithMask",




+ 2
- 1
fastNLP/modules/encoder/__init__.py View File

@@ -23,6 +23,7 @@ __all__ = [


"MaxPool", "MaxPool",
"MaxPoolWithMask", "MaxPoolWithMask",
"KMaxPool",
"AvgPool", "AvgPool",
"AvgPoolWithMask", "AvgPoolWithMask",


@@ -36,7 +37,7 @@ from .bert import BertModel
from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
from .conv_maxpool import ConvMaxpool from .conv_maxpool import ConvMaxpool
from .lstm import LSTM from .lstm import LSTM
from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask
from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, KMaxPool
from .star_transformer import StarTransformer from .star_transformer import StarTransformer
from .transformer import TransformerEncoder from .transformer import TransformerEncoder
from .variational_rnn import VarRNN, VarLSTM, VarGRU from .variational_rnn import VarRNN, VarLSTM, VarGRU

+ 4
- 3
fastNLP/modules/encoder/pooling.py View File

@@ -3,6 +3,7 @@
__all__ = [ __all__ = [
"MaxPool", "MaxPool",
"MaxPoolWithMask", "MaxPoolWithMask",
"KMaxPool",
"AvgPool", "AvgPool",
"AvgPoolWithMask" "AvgPoolWithMask"
] ]
@@ -27,7 +28,7 @@ class MaxPool(nn.Module):
:param ceil_mode: :param ceil_mode:
""" """
super(MaxPool, self).__init__() super(MaxPool, self).__init__()
assert (1 <= dimension) and (dimension <= 3)
assert dimension in [1, 2, 3], f'Now we only support 1d, 2d, or 3d Pooling'
self.dimension = dimension self.dimension = dimension
self.stride = stride self.stride = stride
self.padding = padding self.padding = padding
@@ -37,12 +38,12 @@ class MaxPool(nn.Module):


def forward(self, x): def forward(self, x):
if self.dimension == 1: if self.dimension == 1:
x = torch.transpose(x, 1, 2) # [N,L,C] -> [N,C,L]
pooling = nn.MaxPool1d( pooling = nn.MaxPool1d(
stride=self.stride, padding=self.padding, dilation=self.dilation, stride=self.stride, padding=self.padding, dilation=self.dilation,
kernel_size=self.kernel_size if self.kernel_size is not None else x.size(-1), kernel_size=self.kernel_size if self.kernel_size is not None else x.size(-1),
return_indices=False, ceil_mode=self.ceil_mode return_indices=False, ceil_mode=self.ceil_mode
) )
x = torch.transpose(x, 1, 2) # [N,L,C] -> [N,C,L]
elif self.dimension == 2: elif self.dimension == 2:
pooling = nn.MaxPool2d( pooling = nn.MaxPool2d(
stride=self.stride, padding=self.padding, dilation=self.dilation, stride=self.stride, padding=self.padding, dilation=self.dilation,
@@ -50,7 +51,7 @@ class MaxPool(nn.Module):
return_indices=False, ceil_mode=self.ceil_mode return_indices=False, ceil_mode=self.ceil_mode
) )
else: else:
pooling = nn.MaxPool2d(
pooling = nn.MaxPool3d(
stride=self.stride, padding=self.padding, dilation=self.dilation, stride=self.stride, padding=self.padding, dilation=self.dilation,
kernel_size=self.kernel_size if self.kernel_size is not None else (x.size(-3), x.size(-2), x.size(-1)), kernel_size=self.kernel_size if self.kernel_size is not None else (x.size(-3), x.size(-2), x.size(-1)),
return_indices=False, ceil_mode=self.ceil_mode return_indices=False, ceil_mode=self.ceil_mode


+ 41
- 0
test/modules/encoder/test_pooling.py View File

@@ -0,0 +1,41 @@
import unittest

import torch

from fastNLP.modules.encoder.pooling import MaxPool, MaxPoolWithMask, KMaxPool, AvgPool, AvgPoolWithMask


class TestPooling(unittest.TestCase):
def test_MaxPool(self):
max_pool_1d = MaxPool(dimension=1)
x = torch.randn(5, 6, 7)
self.assertEqual(max_pool_1d(x).size(), (5, 7))

max_pool_2d = MaxPool(dimension=2)
self.assertEqual(max_pool_2d(x).size(), (5, 1))

max_pool_3d = MaxPool(dimension=3)
x = torch.randn(4, 5, 6, 7)
self.assertEqual(max_pool_3d(x).size(), (4, 1, 1))

def test_MaxPoolWithMask(self):
pool = MaxPoolWithMask()
x = torch.randn(5, 6, 7)
mask = (torch.randn(5, 6) > 0).long()
self.assertEqual(pool(x, mask).size(), (5, 7))

def test_KMaxPool(self):
k_pool = KMaxPool(k=3)
x = torch.randn(4, 5, 6)
self.assertEqual(k_pool(x).size(), (4, 15))

def test_AvgPool(self):
pool = AvgPool()
x = torch.randn(4, 5, 6)
self.assertEqual(pool(x).size(), (4, 5))

def test_AvgPoolWithMask(self):
pool = AvgPoolWithMask()
x = torch.randn(5, 6, 7)
mask = (torch.randn(5, 6) > 0).long()
self.assertEqual(pool(x, mask).size(), (5, 7))

Loading…
Cancel
Save