Browse Source

Merge pull request #44 from choosewhatulike/master

add MLP decoder
tags/v0.1.0
Coet GitHub 6 years ago
parent
commit
9b7ad27616
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 56 additions and 0 deletions
  1. +56
    -0
      fastNLP/modules/decoder/MLP.py

+ 56
- 0
fastNLP/modules/decoder/MLP.py View File

@@ -0,0 +1,56 @@
import torch
import torch.nn as nn

class MLP(nn.Module):
def __init__(self, size_layer, num_class=2, activation='relu'):
"""Multilayer Perceptrons as a decoder

Args:
size_layer: list of int, define the size of MLP layers
num_class: int, num of class in output, should be 2 or the last layer's size
activation: str or function, the activation function for hidden layers
"""
super(MLP, self).__init__()
self.hiddens = nn.ModuleList()
self.output = None
for i in range(1, len(size_layer)):
if i + 1 == len(size_layer):
self.output = nn.Linear(size_layer[i-1], size_layer[i])
else:
self.hiddens.append(nn.Linear(size_layer[i-1], size_layer[i]))

if num_class == 2:
self.out_active = nn.LogSigmoid()
elif num_class == size_layer[-1]:
self.out_active = nn.LogSoftmax(dim=1)
else:
raise ValueError("should set output num_class correctly: {}".format(num_class))
actives = {
'relu': nn.ReLU(),
'tanh': nn.Tanh()
}
if activation in actives:
self.hidden_active = actives[activation]
elif isinstance(activation, callable):
self.hidden_active = activation
else:
raise ValueError("should set activation correctly: {}".format(activation))

def forward(self, x):
for layer in self.hiddens:
x = self.hidden_active(layer(x))
x = self.out_active(self.output(x))
return x



if __name__ == '__main__':
net1 = MLP([5,10,5])
net2 = MLP([5,10,5], 5)
for net in [net1, net2]:
x = torch.randn(5, 5)
y = net(x)
print(x)
print(y)

Loading…
Cancel
Save