|
|
@@ -3,6 +3,7 @@ |
|
|
|
__all__ = [ |
|
|
|
"MaxPool", |
|
|
|
"MaxPoolWithMask", |
|
|
|
"KMaxPool", |
|
|
|
"AvgPool", |
|
|
|
"AvgPoolWithMask" |
|
|
|
] |
|
|
@@ -27,7 +28,7 @@ class MaxPool(nn.Module): |
|
|
|
:param ceil_mode: |
|
|
|
""" |
|
|
|
super(MaxPool, self).__init__() |
|
|
|
assert (1 <= dimension) and (dimension <= 3) |
|
|
|
assert dimension in [1, 2, 3], f'Now we only support 1d, 2d, or 3d Pooling' |
|
|
|
self.dimension = dimension |
|
|
|
self.stride = stride |
|
|
|
self.padding = padding |
|
|
@@ -37,12 +38,12 @@ class MaxPool(nn.Module): |
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
if self.dimension == 1: |
|
|
|
x = torch.transpose(x, 1, 2) # [N,L,C] -> [N,C,L] |
|
|
|
pooling = nn.MaxPool1d( |
|
|
|
stride=self.stride, padding=self.padding, dilation=self.dilation, |
|
|
|
kernel_size=self.kernel_size if self.kernel_size is not None else x.size(-1), |
|
|
|
return_indices=False, ceil_mode=self.ceil_mode |
|
|
|
) |
|
|
|
x = torch.transpose(x, 1, 2) # [N,L,C] -> [N,C,L] |
|
|
|
elif self.dimension == 2: |
|
|
|
pooling = nn.MaxPool2d( |
|
|
|
stride=self.stride, padding=self.padding, dilation=self.dilation, |
|
|
@@ -50,7 +51,7 @@ class MaxPool(nn.Module): |
|
|
|
return_indices=False, ceil_mode=self.ceil_mode |
|
|
|
) |
|
|
|
else: |
|
|
|
pooling = nn.MaxPool2d( |
|
|
|
pooling = nn.MaxPool3d( |
|
|
|
stride=self.stride, padding=self.padding, dilation=self.dilation, |
|
|
|
kernel_size=self.kernel_size if self.kernel_size is not None else (x.size(-3), x.size(-2), x.size(-1)), |
|
|
|
return_indices=False, ceil_mode=self.ceil_mode |
|
|
|