Browse Source

Merge pull request #7 from keezen/master

add conv and pooling module
tags/v0.1.0
Coet GitHub 6 years ago
parent
commit
ca598e161c
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 73 additions and 0 deletions
  1. +22
    -0
      fastNLP/modules/convolution/AvgPool1d.py
  2. +28
    -0
      fastNLP/modules/convolution/Conv1d.py
  3. +23
    -0
      fastNLP/modules/convolution/MaxPool1d.py

+ 22
- 0
fastNLP/modules/convolution/AvgPool1d.py View File

@@ -0,0 +1,22 @@
# python: 3.6
# encoding: utf-8

import torch.nn as nn
# import torch.nn.functional as F


class AvgPool1d(nn.Module):
    """Thin wrapper exposing :class:`torch.nn.AvgPool1d` as a project module.

    Applies 1-d average pooling over the last dimension of a
    ``(batch, channels, length)`` input; all constructor arguments are
    forwarded unchanged to the underlying PyTorch layer.
    """

    def __init__(self, kernel_size, stride=None, padding=0,
                 ceil_mode=False, count_include_pad=True):
        super(AvgPool1d, self).__init__()
        # Delegate everything to the stock PyTorch pooling layer.
        self.pool = nn.AvgPool1d(
            kernel_size, stride, padding, ceil_mode, count_include_pad)

    def forward(self, x):
        """Return the average-pooled tensor for input ``x``."""
        return self.pool(x)

+ 28
- 0
fastNLP/modules/convolution/Conv1d.py View File

@@ -0,0 +1,28 @@
# python: 3.6
# encoding: utf-8

import torch.nn as nn
# import torch.nn.functional as F


class Conv1d(nn.Module):
    """Thin wrapper exposing :class:`torch.nn.Conv1d` as a project module.

    Performs a basic 1-d convolution over a ``(batch, in_channels, length)``
    input. Every constructor argument is passed straight through to the
    underlying PyTorch convolution layer.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1,
                 groups=1, bias=True):
        super(Conv1d, self).__init__()
        # Keep the stock layer as an attribute so its parameters are
        # registered with this module.
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)

    def forward(self, x):
        """Return the convolved tensor for input ``x``."""
        return self.conv(x)

+ 23
- 0
fastNLP/modules/convolution/MaxPool1d.py View File

@@ -0,0 +1,23 @@
# python: 3.6
# encoding: utf-8

import torch.nn as nn
# import torch.nn.functional as F


class MaxPool1d(nn.Module):
    """Thin wrapper exposing :class:`torch.nn.MaxPool1d` as a project module.

    Applies 1-d max pooling over the last dimension of a
    ``(batch, channels, length)`` input; constructor arguments are
    forwarded unchanged to the underlying PyTorch layer.
    """

    def __init__(self, kernel_size, stride=None, padding=0,
                 dilation=1, return_indices=False, ceil_mode=False):
        super(MaxPool1d, self).__init__()
        # Delegate to the stock PyTorch max-pooling layer.
        self.maxpool = nn.MaxPool1d(
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            return_indices=return_indices,
            ceil_mode=ceil_mode)

    def forward(self, x):
        """Return the max-pooled tensor for input ``x``."""
        return self.maxpool(x)

Loading…
Cancel
Save