rename and add kmax pooling moduletags/v0.1.0
@@ -1,22 +0,0 @@ | |||||
# python: 3.6 | |||||
# encoding: utf-8 | |||||
import torch.nn as nn | |||||
# import torch.nn.functional as F | |||||
class AvgPool1d(nn.Module): | |||||
"""1-d average pooling module.""" | |||||
def __init__(self, kernel_size, stride=None, padding=0, | |||||
ceil_mode=False, count_include_pad=True): | |||||
super(AvgPool1d, self).__init__() | |||||
self.pool = nn.AvgPool1d( | |||||
kernel_size=kernel_size, | |||||
stride=stride, | |||||
padding=padding, | |||||
ceil_mode=ceil_mode, | |||||
count_include_pad=count_include_pad) | |||||
def forward(self, x): | |||||
return self.pool(x) |
@@ -1,23 +0,0 @@ | |||||
# python: 3.6 | |||||
# encoding: utf-8 | |||||
import torch.nn as nn | |||||
# import torch.nn.functional as F | |||||
class MaxPool1d(nn.Module): | |||||
"""1-d max-pooling module.""" | |||||
def __init__(self, kernel_size, stride=None, padding=0, | |||||
dilation=1, return_indices=False, ceil_mode=False): | |||||
super(MaxPool1d, self).__init__() | |||||
self.maxpool = nn.MaxPool1d( | |||||
kernel_size=kernel_size, | |||||
stride=stride, | |||||
padding=padding, | |||||
dilation=dilation, | |||||
return_indices=return_indices, | |||||
ceil_mode=ceil_mode) | |||||
def forward(self, x): | |||||
return self.maxpool(x) |
@@ -0,0 +1,24 @@ | |||||
# python: 3.6 | |||||
# encoding: utf-8 | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
class AvgPool(nn.Module): | |||||
"""1-d average pooling module.""" | |||||
def __init__(self, stride=None, padding=0): | |||||
super(AvgPool, self).__init__() | |||||
self.stride = stride | |||||
self.padding = padding | |||||
def forward(self, x): | |||||
# [N,C,L] -> [N,C] | |||||
kernel_size = x.size(2) | |||||
x = F.max_pool1d( | |||||
input=x, | |||||
kernel_size=kernel_size, | |||||
stride=self.stride, | |||||
padding=self.padding) | |||||
return x.squeeze(dim=-1) |
@@ -5,7 +5,7 @@ import torch.nn as nn | |||||
# import torch.nn.functional as F | # import torch.nn.functional as F | ||||
class Conv1d(nn.Module): | |||||
class Conv(nn.Module): | |||||
""" | """ | ||||
Basic 1-d convolution module. | Basic 1-d convolution module. | ||||
""" | """ | ||||
@@ -13,7 +13,7 @@ class Conv1d(nn.Module): | |||||
def __init__(self, in_channels, out_channels, kernel_size, | def __init__(self, in_channels, out_channels, kernel_size, | ||||
stride=1, padding=0, dilation=1, | stride=1, padding=0, dilation=1, | ||||
groups=1, bias=True): | groups=1, bias=True): | ||||
super(Conv1d, self).__init__() | |||||
super(Conv, self).__init__() | |||||
self.conv = nn.Conv1d( | self.conv = nn.Conv1d( | ||||
in_channels=in_channels, | in_channels=in_channels, | ||||
out_channels=out_channels, | out_channels=out_channels, | ||||
@@ -25,4 +25,4 @@ class Conv1d(nn.Module): | |||||
bias=bias) | bias=bias) | ||||
def forward(self, x): | def forward(self, x): | ||||
return self.conv(x) | |||||
return self.conv(x) # [N,C,L] |
@@ -0,0 +1,20 @@ | |||||
# python: 3.6 | |||||
# encoding: utf-8 | |||||
import torch | |||||
import torch.nn as nn | |||||
# import torch.nn.functional as F | |||||
class KMaxPool(nn.Module): | |||||
"""K max-pooling module.""" | |||||
def __init__(self, k): | |||||
super(KMaxPool, self).__init__() | |||||
self.k = k | |||||
def forward(self, x): | |||||
# [N,C,L] -> [N,C*k] | |||||
x, index = torch.topk(x, self.k, dim=-1, sorted=False) | |||||
x = torch.reshape(x, (x.size(0), -1)) | |||||
return x |
@@ -0,0 +1,26 @@ | |||||
# python: 3.6 | |||||
# encoding: utf-8 | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
class MaxPool(nn.Module): | |||||
"""1-d max-pooling module.""" | |||||
def __init__(self, stride=None, padding=0, dilation=1): | |||||
super(MaxPool, self).__init__() | |||||
self.stride = stride | |||||
self.padding = padding | |||||
self.dilation = dilation | |||||
def forward(self, x): | |||||
# [N,C,L] -> [N,C] | |||||
kernel_size = x.size(2) | |||||
x = F.max_pool1d( | |||||
input=x, | |||||
kernel_size=kernel_size, | |||||
stride=self.stride, | |||||
padding=self.padding, | |||||
dilation=self.dilation) | |||||
return x.squeeze(dim=-1) |