|
|
@@ -5,7 +5,7 @@ import torch.nn as nn |
|
|
|
# import torch.nn.functional as F |
|
|
|
|
|
|
|
|
|
|
|
class Conv1d(nn.Module): |
|
|
|
class Conv(nn.Module): |
|
|
|
""" |
|
|
|
Basic 1-d convolution module. |
|
|
|
""" |
|
|
@@ -13,7 +13,7 @@ class Conv1d(nn.Module): |
|
|
|
def __init__(self, in_channels, out_channels, kernel_size, |
|
|
|
stride=1, padding=0, dilation=1, |
|
|
|
groups=1, bias=True): |
|
|
|
super(Conv1d, self).__init__() |
|
|
|
super(Conv, self).__init__() |
|
|
|
self.conv = nn.Conv1d( |
|
|
|
in_channels=in_channels, |
|
|
|
out_channels=out_channels, |
|
|
@@ -25,4 +25,4 @@ class Conv1d(nn.Module): |
|
|
|
bias=bias) |
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
return self.conv(x) |
|
|
|
return self.conv(x) # [N,C,L] |