|
|
@@ -25,3 +25,14 @@ class MaxPool(nn.Module): |
|
|
|
padding=self.padding, |
|
|
|
dilation=self.dilation) |
|
|
|
return x.squeeze(dim=-1) # [N,C,1] -> [N,C] |
|
|
|
|
|
|
|
|
|
|
|
class MaxPoolWithMask(nn.Module): |
|
|
|
def __init__(self): |
|
|
|
super(MaxPoolWithMask, self).__init__() |
|
|
|
self.inf = 10e12 |
|
|
|
|
|
|
|
def forward(self, tensor, mask, dim=0): |
|
|
|
masks = mask.view(mask.size(0), mask.size(1), -1) |
|
|
|
masks = masks.expand(-1, -1, tensor.size(2)).float() |
|
|
|
return torch.max(tensor + masks.le(0.5).float() * -self.inf, dim=dim) |