@@ -1,7 +1,10 @@
from .max_pool import MaxPool
from .max_pool import MaxPoolWithMask
from .avg_pool import AvgPool
from .avg_pool import MeanPoolWithMask
from .kmax_pool import KMaxPool
from .attention import Attention
from .attention import Bi_Attention
from .self_attention import SelfAttention
@@ -1,6 +1,7 @@
# python: 3.6
# encoding: utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -22,3 +23,14 @@ class AvgPool(nn.Module): | |||||
stride=self.stride, | stride=self.stride, | ||||
padding=self.padding) | padding=self.padding) | ||||
return x.squeeze(dim=-1) | return x.squeeze(dim=-1) | ||||
class MeanPoolWithMask(nn.Module):
    """Mean pooling that only averages over unmasked positions.

    Given a feature tensor and a 0/1 mask over its sequence positions,
    computes sum(tensor * mask, dim) / sum(mask, dim), so padded positions
    contribute nothing to the mean.
    """

    def __init__(self):
        super(MeanPoolWithMask, self).__init__()
        # Unused here; kept for interface parity with MaxPoolWithMask.
        self.inf = 10e12

    def forward(self, tensor, mask, dim=0):
        """Masked mean over `dim`.

        :param tensor: feature tensor — assumes (batch, seq, hidden); TODO confirm
        :param mask: 0/1 mask, assumes shape (batch, seq) — TODO confirm
        :param dim: axis to pool over (callers typically pass dim=1)
        :return: masked mean of `tensor` along `dim`
        """
        # Reshape mask to (batch, seq, 1) so it broadcasts over the feature axis.
        masks = mask.view(mask.size(0), mask.size(1), -1).float()
        # BUGFIX: the denominator previously hard-coded dim=1 while the
        # numerator reduced over `dim`; both must reduce over the same axis.
        # Behavior is unchanged for the standard dim=1 call.
        return torch.sum(tensor * masks, dim=dim) / torch.sum(masks, dim=dim)
@@ -25,3 +25,14 @@ class MaxPool(nn.Module): | |||||
padding=self.padding, | padding=self.padding, | ||||
dilation=self.dilation) | dilation=self.dilation) | ||||
return x.squeeze(dim=-1) # [N,C,1] -> [N,C] | return x.squeeze(dim=-1) # [N,C,1] -> [N,C] | ||||
class MaxPoolWithMask(nn.Module):
    """Max pooling that ignores masked-out positions.

    Positions where the mask is 0 are pushed down by a large negative
    penalty before taking the max, so they can never win the pooling.
    Note: returns the full `torch.max(..., dim=...)` result, i.e. a
    (values, indices) pair.
    """

    def __init__(self):
        super(MaxPoolWithMask, self).__init__()
        # Large constant used as the "minus infinity" penalty (10e12 == 1e13).
        self.inf = 10e12

    def forward(self, tensor, mask, dim=0):
        """Masked max over `dim`.

        :param tensor: feature tensor — assumes (batch, seq, hidden); TODO confirm
        :param mask: 0/1 mask, assumes shape (batch, seq) — TODO confirm
        :param dim: axis to pool over (callers typically pass dim=1)
        :return: (values, indices) as produced by torch.max along `dim`
        """
        # (batch, seq) -> (batch, seq, 1), then tile across the feature axis.
        expanded = mask.view(mask.size(0), mask.size(1), -1)
        expanded = expanded.expand(-1, -1, tensor.size(2)).float()
        # le(0.5) selects masked-out entries; they get a -1e13 penalty.
        penalty = expanded.le(0.5).float() * -self.inf
        return torch.max(tensor + penalty, dim=dim)