Browse Source

add kmax pooling module

tags/v0.1.0
Ke Zhen 6 years ago
parent
commit
7b7826544e
3 changed files with 20 additions and 0 deletions
  1. +0
    -0
      fastNLP/modules/convolution/avg_pool.py
  2. +20
    -0
      fastNLP/modules/convolution/kmax_pool.py
  3. +0
    -0
      fastNLP/modules/convolution/max_pool.py

fastNLP/modules/convolution/AvgPool.py → fastNLP/modules/convolution/avg_pool.py View File


+ 20
- 0
fastNLP/modules/convolution/kmax_pool.py View File

@@ -0,0 +1,20 @@
# python: 3.6
# encoding: utf-8

import torch
import torch.nn as nn
# import torch.nn.functional as F


class KMaxPool(nn.Module):
"""K max-pooling module."""

def __init__(self, k):
super(KMaxPool, self).__init__()
self.k = k

def forward(self, x):
# [N,C,L] -> [N,C*k]
x, index = torch.topk(x, self.k, dim=-1, sorted=False)
x = torch.reshape(x, (x.size(0), -1))
return x

fastNLP/modules/convolution/MaxPool.py → fastNLP/modules/convolution/max_pool.py View File


Loading…
Cancel
Save