| @@ -37,5 +37,7 @@ class Loss(object): | |||||
| """ | """ | ||||
| if loss_name == "cross_entropy": | if loss_name == "cross_entropy": | ||||
| return torch.nn.CrossEntropyLoss() | return torch.nn.CrossEntropyLoss() | ||||
| elif loss_name == 'nll': | |||||
| return torch.nn.NLLLoss() | |||||
| else: | else: | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
| @@ -291,9 +291,11 @@ class BasePreprocess(object): | |||||
| class SeqLabelPreprocess(BasePreprocess): | class SeqLabelPreprocess(BasePreprocess): | ||||
| def __init__(self): | def __init__(self): | ||||
| super(SeqLabelPreprocess, self).__init__() | super(SeqLabelPreprocess, self).__init__() | ||||
| class ClassPreprocess(BasePreprocess): | class ClassPreprocess(BasePreprocess): | ||||
| def __init__(self): | def __init__(self): | ||||
| super(ClassPreprocess, self).__init__() | super(ClassPreprocess, self).__init__() | ||||
| @@ -1,8 +1,10 @@ | |||||
| import torch | import torch | ||||
| import torch.nn as nn | import torch.nn as nn | ||||
| from torch.autograd import Variable | from torch.autograd import Variable | ||||
| import torch.nn.functional as F | |||||
| from fastNLP.modules.utils import initial_parameter | |||||
| class SelfAttention(nn.Module): | class SelfAttention(nn.Module): | ||||
| """ | """ | ||||
| Self Attention Module. | Self Attention Module. | ||||
| @@ -13,13 +15,18 @@ class SelfAttention(nn.Module): | |||||
| num_vec: int, the number of encoded vectors | num_vec: int, the number of encoded vectors | ||||
| """ | """ | ||||
| def __init__(self, input_size, dim=10, num_vec=10): | |||||
| def __init__(self, input_size, dim=10, num_vec=10 ,drop = 0.5 ,initial_method =None): | |||||
| super(SelfAttention, self).__init__() | super(SelfAttention, self).__init__() | ||||
| self.W_s1 = nn.Parameter(torch.randn(dim, input_size), requires_grad=True) | |||||
| self.W_s2 = nn.Parameter(torch.randn(num_vec, dim), requires_grad=True) | |||||
| # self.W_s1 = nn.Parameter(torch.randn(dim, input_size), requires_grad=True) | |||||
| # self.W_s2 = nn.Parameter(torch.randn(num_vec, dim), requires_grad=True) | |||||
| self.attention_hops = num_vec | |||||
| self.ws1 = nn.Linear(input_size, dim, bias=False) | |||||
| self.ws2 = nn.Linear(dim, num_vec, bias=False) | |||||
| self.drop = nn.Dropout(drop) | |||||
| self.softmax = nn.Softmax(dim=2) | self.softmax = nn.Softmax(dim=2) | ||||
| self.tanh = nn.Tanh() | self.tanh = nn.Tanh() | ||||
| initial_parameter(self, initial_method) | |||||
| def penalization(self, A): | def penalization(self, A): | ||||
| """ | """ | ||||
| compute the penalization term for attention module | compute the penalization term for attention module | ||||
| @@ -32,11 +39,33 @@ class SelfAttention(nn.Module): | |||||
| M = M.view(M.size(0), -1) | M = M.view(M.size(0), -1) | ||||
| return torch.sum(M ** 2, dim=1) | return torch.sum(M ** 2, dim=1) | ||||
| def forward(self, x): | |||||
| inter = self.tanh(torch.matmul(self.W_s1, torch.transpose(x, 1, 2))) | |||||
| A = self.softmax(torch.matmul(self.W_s2, inter)) | |||||
| out = torch.matmul(A, x) | |||||
| out = out.view(out.size(0), -1) | |||||
| penalty = self.penalization(A) | |||||
| return out, penalty | |||||
| def forward(self, outp ,inp): | |||||
| # the following code can not be use because some word are padding ,these is not such module! | |||||
| # inter = self.tanh(torch.matmul(self.W_s1, torch.transpose(x, 1, 2))) # [] | |||||
| # A = self.softmax(torch.matmul(self.W_s2, inter)) | |||||
| # out = torch.matmul(A, x) | |||||
| # out = out.view(out.size(0), -1) | |||||
| # penalty = self.penalization(A) | |||||
| # return out, penalty | |||||
| outp = outp.contiguous() | |||||
| size = outp.size() # [bsz, len, nhid] | |||||
| compressed_embeddings = outp.view(-1, size[2]) # [bsz*len, nhid*2] | |||||
| transformed_inp = torch.transpose(inp, 0, 1).contiguous() # [bsz, len] | |||||
| transformed_inp = transformed_inp.view(size[0], 1, size[1]) # [bsz, 1, len] | |||||
| concatenated_inp = [transformed_inp for i in range(self.attention_hops)] | |||||
| concatenated_inp = torch.cat(concatenated_inp, 1) # [bsz, hop, len] | |||||
| hbar = self.tanh(self.ws1(self.drop(compressed_embeddings))) # [bsz*len, attention-unit] | |||||
| attention = self.ws2(hbar).view(size[0], size[1], -1) # [bsz, len, hop] | |||||
| attention = torch.transpose(attention, 1, 2).contiguous() # [bsz, hop, len] | |||||
| penalized_alphas = attention + ( | |||||
| -10000 * (concatenated_inp == 0).float()) | |||||
| # [bsz, hop, len] + [bsz, hop, len] | |||||
| attention = self.softmax(penalized_alphas.view(-1, size[1])) # [bsz*hop, len] | |||||
| attention = attention.view(size[0], self.attention_hops, size[1]) # [bsz, hop, len] | |||||
| return torch.bmm(attention, outp), attention # output --> [baz ,hop ,nhid] | |||||
| @@ -1,6 +1,7 @@ | |||||
| import torch | import torch | ||||
| from torch import nn | from torch import nn | ||||
| from fastNLP.modules.utils import initial_parameter | |||||
| def log_sum_exp(x, dim=-1): | def log_sum_exp(x, dim=-1): | ||||
| max_value, _ = x.max(dim=dim, keepdim=True) | max_value, _ = x.max(dim=dim, keepdim=True) | ||||
| @@ -19,7 +20,7 @@ def seq_len_to_byte_mask(seq_lens): | |||||
| class ConditionalRandomField(nn.Module): | class ConditionalRandomField(nn.Module): | ||||
| def __init__(self, tag_size, include_start_end_trans=True): | |||||
| def __init__(self, tag_size, include_start_end_trans=True ,initial_method = None): | |||||
| """ | """ | ||||
| :param tag_size: int, num of tags | :param tag_size: int, num of tags | ||||
| :param include_start_end_trans: bool, whether to include start/end tag | :param include_start_end_trans: bool, whether to include start/end tag | ||||
| @@ -35,8 +36,8 @@ class ConditionalRandomField(nn.Module): | |||||
| self.start_scores = nn.Parameter(torch.randn(tag_size)) | self.start_scores = nn.Parameter(torch.randn(tag_size)) | ||||
| self.end_scores = nn.Parameter(torch.randn(tag_size)) | self.end_scores = nn.Parameter(torch.randn(tag_size)) | ||||
| self.reset_parameter() | |||||
| # self.reset_parameter() | |||||
| initial_parameter(self, initial_method) | |||||
| def reset_parameter(self): | def reset_parameter(self): | ||||
| nn.init.xavier_normal_(self.transition_m) | nn.init.xavier_normal_(self.transition_m) | ||||
| if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
| @@ -1,8 +1,8 @@ | |||||
| import torch | import torch | ||||
| import torch.nn as nn | import torch.nn as nn | ||||
| from fastNLP.modules.utils import initial_parameter | |||||
| class MLP(nn.Module): | class MLP(nn.Module): | ||||
| def __init__(self, size_layer, num_class=2, activation='relu'): | |||||
| def __init__(self, size_layer, num_class=2, activation='relu' , initial_method = None): | |||||
| """Multilayer Perceptrons as a decoder | """Multilayer Perceptrons as a decoder | ||||
| Args: | Args: | ||||
| @@ -36,7 +36,7 @@ class MLP(nn.Module): | |||||
| self.hidden_active = activation | self.hidden_active = activation | ||||
| else: | else: | ||||
| raise ValueError("should set activation correctly: {}".format(activation)) | raise ValueError("should set activation correctly: {}".format(activation)) | ||||
| initial_parameter(self, initial_method ) | |||||
| def forward(self, x): | def forward(self, x): | ||||
| for layer in self.hiddens: | for layer in self.hiddens: | ||||
| x = self.hidden_active(layer(x)) | x = self.hidden_active(layer(x)) | ||||
| @@ -1,11 +1,12 @@ | |||||
| import torch | import torch | ||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||
| from torch import nn | from torch import nn | ||||
| # from torch.nn.init import xavier_uniform | |||||
| from fastNLP.modules.utils import initial_parameter | |||||
| class ConvCharEmbedding(nn.Module): | class ConvCharEmbedding(nn.Module): | ||||
| def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5)): | |||||
| def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5),initial_method = None): | |||||
| """ | """ | ||||
| Character Level Word Embedding | Character Level Word Embedding | ||||
| :param char_emb_size: the size of character level embedding. Default: 50 | :param char_emb_size: the size of character level embedding. Default: 50 | ||||
| @@ -20,6 +21,8 @@ class ConvCharEmbedding(nn.Module): | |||||
| nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True, padding=(0, 4)) | nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True, padding=(0, 4)) | ||||
| for i in range(len(kernels))]) | for i in range(len(kernels))]) | ||||
| initial_parameter(self,initial_method) | |||||
| def forward(self, x): | def forward(self, x): | ||||
| """ | """ | ||||
| :param x: [batch_size * sent_length, word_length, char_emb_size] | :param x: [batch_size * sent_length, word_length, char_emb_size] | ||||
| @@ -53,7 +56,7 @@ class LSTMCharEmbedding(nn.Module): | |||||
| :param hidden_size: int, the number of hidden units. Default: equal to char_emb_size. | :param hidden_size: int, the number of hidden units. Default: equal to char_emb_size. | ||||
| """ | """ | ||||
| def __init__(self, char_emb_size=50, hidden_size=None): | |||||
| def __init__(self, char_emb_size=50, hidden_size=None , initial_method= None): | |||||
| super(LSTMCharEmbedding, self).__init__() | super(LSTMCharEmbedding, self).__init__() | ||||
| self.hidden_size = char_emb_size if hidden_size is None else hidden_size | self.hidden_size = char_emb_size if hidden_size is None else hidden_size | ||||
| @@ -62,7 +65,7 @@ class LSTMCharEmbedding(nn.Module): | |||||
| num_layers=1, | num_layers=1, | ||||
| bias=True, | bias=True, | ||||
| batch_first=True) | batch_first=True) | ||||
| initial_parameter(self, initial_method) | |||||
| def forward(self, x): | def forward(self, x): | ||||
| """ | """ | ||||
| :param x:[ n_batch*n_word, word_length, char_emb_size] | :param x:[ n_batch*n_word, word_length, char_emb_size] | ||||
| @@ -6,6 +6,7 @@ import torch.nn as nn | |||||
| from torch.nn.init import xavier_uniform_ | from torch.nn.init import xavier_uniform_ | ||||
| # import torch.nn.functional as F | # import torch.nn.functional as F | ||||
| from fastNLP.modules.utils import initial_parameter | |||||
| class Conv(nn.Module): | class Conv(nn.Module): | ||||
| """ | """ | ||||
| @@ -15,7 +16,7 @@ class Conv(nn.Module): | |||||
| def __init__(self, in_channels, out_channels, kernel_size, | def __init__(self, in_channels, out_channels, kernel_size, | ||||
| stride=1, padding=0, dilation=1, | stride=1, padding=0, dilation=1, | ||||
| groups=1, bias=True, activation='relu'): | |||||
| groups=1, bias=True, activation='relu',initial_method = None ): | |||||
| super(Conv, self).__init__() | super(Conv, self).__init__() | ||||
| self.conv = nn.Conv1d( | self.conv = nn.Conv1d( | ||||
| in_channels=in_channels, | in_channels=in_channels, | ||||
| @@ -26,7 +27,7 @@ class Conv(nn.Module): | |||||
| dilation=dilation, | dilation=dilation, | ||||
| groups=groups, | groups=groups, | ||||
| bias=bias) | bias=bias) | ||||
| xavier_uniform_(self.conv.weight) | |||||
| # xavier_uniform_(self.conv.weight) | |||||
| activations = { | activations = { | ||||
| 'relu': nn.ReLU(), | 'relu': nn.ReLU(), | ||||
| @@ -37,6 +38,7 @@ class Conv(nn.Module): | |||||
| raise Exception( | raise Exception( | ||||
| 'Should choose activation function from: ' + | 'Should choose activation function from: ' + | ||||
| ', '.join([x for x in activations])) | ', '.join([x for x in activations])) | ||||
| initial_parameter(self, initial_method) | |||||
| def forward(self, x): | def forward(self, x): | ||||
| x = torch.transpose(x, 1, 2) # [N,L,C] -> [N,C,L] | x = torch.transpose(x, 1, 2) # [N,L,C] -> [N,C,L] | ||||
| @@ -5,7 +5,7 @@ import torch | |||||
| import torch.nn as nn | import torch.nn as nn | ||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||
| from torch.nn.init import xavier_uniform_ | from torch.nn.init import xavier_uniform_ | ||||
| from fastNLP.modules.utils import initial_parameter | |||||
| class ConvMaxpool(nn.Module): | class ConvMaxpool(nn.Module): | ||||
| """ | """ | ||||
| @@ -14,7 +14,7 @@ class ConvMaxpool(nn.Module): | |||||
| def __init__(self, in_channels, out_channels, kernel_sizes, | def __init__(self, in_channels, out_channels, kernel_sizes, | ||||
| stride=1, padding=0, dilation=1, | stride=1, padding=0, dilation=1, | ||||
| groups=1, bias=True, activation='relu'): | |||||
| groups=1, bias=True, activation='relu',initial_method = None ): | |||||
| super(ConvMaxpool, self).__init__() | super(ConvMaxpool, self).__init__() | ||||
| # convolution | # convolution | ||||
| @@ -47,6 +47,8 @@ class ConvMaxpool(nn.Module): | |||||
| raise Exception( | raise Exception( | ||||
| "Undefined activation function: choose from: relu") | "Undefined activation function: choose from: relu") | ||||
| initial_parameter(self, initial_method) | |||||
| def forward(self, x): | def forward(self, x): | ||||
| # [N,L,C] -> [N,C,L] | # [N,L,C] -> [N,C,L] | ||||
| x = torch.transpose(x, 1, 2) | x = torch.transpose(x, 1, 2) | ||||
| @@ -1,6 +1,6 @@ | |||||
| import torch.nn as nn | import torch.nn as nn | ||||
| from fastNLP.modules.utils import initial_parameter | |||||
| class Linear(nn.Module): | class Linear(nn.Module): | ||||
| """ | """ | ||||
| Linear module | Linear module | ||||
| @@ -12,10 +12,10 @@ class Linear(nn.Module): | |||||
| bidirectional : If True, becomes a bidirectional RNN | bidirectional : If True, becomes a bidirectional RNN | ||||
| """ | """ | ||||
| def __init__(self, input_size, output_size, bias=True): | |||||
| def __init__(self, input_size, output_size, bias=True,initial_method = None ): | |||||
| super(Linear, self).__init__() | super(Linear, self).__init__() | ||||
| self.linear = nn.Linear(input_size, output_size, bias) | self.linear = nn.Linear(input_size, output_size, bias) | ||||
| initial_parameter(self, initial_method) | |||||
| def forward(self, x): | def forward(self, x): | ||||
| x = self.linear(x) | x = self.linear(x) | ||||
| return x | return x | ||||
| @@ -1,6 +1,6 @@ | |||||
| import torch.nn as nn | import torch.nn as nn | ||||
| from fastNLP.modules.utils import initial_parameter | |||||
| class Lstm(nn.Module): | class Lstm(nn.Module): | ||||
| """ | """ | ||||
| LSTM module | LSTM module | ||||
| @@ -13,11 +13,13 @@ class Lstm(nn.Module): | |||||
| bidirectional : If True, becomes a bidirectional RNN. Default: False. | bidirectional : If True, becomes a bidirectional RNN. Default: False. | ||||
| """ | """ | ||||
| def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0, bidirectional=False): | |||||
| def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0, bidirectional=False , initial_method = None): | |||||
| super(Lstm, self).__init__() | super(Lstm, self).__init__() | ||||
| self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True, | self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True, | ||||
| dropout=dropout, bidirectional=bidirectional) | dropout=dropout, bidirectional=bidirectional) | ||||
| initial_parameter(self, initial_method) | |||||
| def forward(self, x): | def forward(self, x): | ||||
| x, _ = self.lstm(x) | x, _ = self.lstm(x) | ||||
| return x | return x | ||||
| if __name__ == "__main__": | |||||
| lstm = Lstm(10) | |||||
| @@ -4,7 +4,7 @@ import torch | |||||
| import torch.nn as nn | import torch.nn as nn | ||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||
| from fastNLP.modules.utils import initial_parameter | |||||
| def MaskedRecurrent(reverse=False): | def MaskedRecurrent(reverse=False): | ||||
| def forward(input, hidden, cell, mask, train=True, dropout=0): | def forward(input, hidden, cell, mask, train=True, dropout=0): | ||||
| """ | """ | ||||
| @@ -192,7 +192,7 @@ def AutogradMaskedStep(num_layers=1, dropout=0, train=True, lstm=False): | |||||
| class MaskedRNNBase(nn.Module): | class MaskedRNNBase(nn.Module): | ||||
| def __init__(self, Cell, input_size, hidden_size, | def __init__(self, Cell, input_size, hidden_size, | ||||
| num_layers=1, bias=True, batch_first=False, | num_layers=1, bias=True, batch_first=False, | ||||
| layer_dropout=0, step_dropout=0, bidirectional=False, **kwargs): | |||||
| layer_dropout=0, step_dropout=0, bidirectional=False, initial_method = None , **kwargs): | |||||
| """ | """ | ||||
| :param Cell: | :param Cell: | ||||
| :param input_size: | :param input_size: | ||||
| @@ -226,7 +226,7 @@ class MaskedRNNBase(nn.Module): | |||||
| cell = self.Cell(layer_input_size, hidden_size, self.bias, **kwargs) | cell = self.Cell(layer_input_size, hidden_size, self.bias, **kwargs) | ||||
| self.all_cells.append(cell) | self.all_cells.append(cell) | ||||
| self.add_module('cell%d' % (layer * num_directions + direction), cell) # Max的代码写得真好看 | self.add_module('cell%d' % (layer * num_directions + direction), cell) # Max的代码写得真好看 | ||||
| initial_parameter(self, initial_method) | |||||
| def reset_parameters(self): | def reset_parameters(self): | ||||
| for cell in self.all_cells: | for cell in self.all_cells: | ||||
| cell.reset_parameters() | cell.reset_parameters() | ||||
| @@ -6,6 +6,7 @@ import torch.nn.functional as F | |||||
| from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend | from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend | ||||
| from torch.nn.parameter import Parameter | from torch.nn.parameter import Parameter | ||||
| from fastNLP.modules.utils import initial_parameter | |||||
| def default_initializer(hidden_size): | def default_initializer(hidden_size): | ||||
| stdv = 1.0 / math.sqrt(hidden_size) | stdv = 1.0 / math.sqrt(hidden_size) | ||||
| @@ -172,7 +173,7 @@ def AutogradVarMaskedStep(num_layers=1, lstm=False): | |||||
| class VarMaskedRNNBase(nn.Module): | class VarMaskedRNNBase(nn.Module): | ||||
| def __init__(self, Cell, input_size, hidden_size, | def __init__(self, Cell, input_size, hidden_size, | ||||
| num_layers=1, bias=True, batch_first=False, | num_layers=1, bias=True, batch_first=False, | ||||
| dropout=(0, 0), bidirectional=False, initializer=None, **kwargs): | |||||
| dropout=(0, 0), bidirectional=False, initializer=None,initial_method = None, **kwargs): | |||||
| super(VarMaskedRNNBase, self).__init__() | super(VarMaskedRNNBase, self).__init__() | ||||
| self.Cell = Cell | self.Cell = Cell | ||||
| @@ -193,7 +194,7 @@ class VarMaskedRNNBase(nn.Module): | |||||
| cell = self.Cell(layer_input_size, hidden_size, self.bias, p=dropout, initializer=initializer, **kwargs) | cell = self.Cell(layer_input_size, hidden_size, self.bias, p=dropout, initializer=initializer, **kwargs) | ||||
| self.all_cells.append(cell) | self.all_cells.append(cell) | ||||
| self.add_module('cell%d' % (layer * num_directions + direction), cell) | self.add_module('cell%d' % (layer * num_directions + direction), cell) | ||||
| initial_parameter(self, initial_method) | |||||
| def reset_parameters(self): | def reset_parameters(self): | ||||
| for cell in self.all_cells: | for cell in self.all_cells: | ||||
| cell.reset_parameters() | cell.reset_parameters() | ||||
| @@ -284,7 +285,7 @@ class VarFastLSTMCell(VarRNNCellBase): | |||||
| \end{array} | \end{array} | ||||
| """ | """ | ||||
| def __init__(self, input_size, hidden_size, bias=True, p=(0.5, 0.5), initializer=None): | |||||
| def __init__(self, input_size, hidden_size, bias=True, p=(0.5, 0.5), initializer=None,initial_method =None): | |||||
| super(VarFastLSTMCell, self).__init__() | super(VarFastLSTMCell, self).__init__() | ||||
| self.input_size = input_size | self.input_size = input_size | ||||
| self.hidden_size = hidden_size | self.hidden_size = hidden_size | ||||
| @@ -311,7 +312,7 @@ class VarFastLSTMCell(VarRNNCellBase): | |||||
| self.p_hidden = p_hidden | self.p_hidden = p_hidden | ||||
| self.noise_in = None | self.noise_in = None | ||||
| self.noise_hidden = None | self.noise_hidden = None | ||||
| initial_parameter(self, initial_method) | |||||
| def reset_parameters(self): | def reset_parameters(self): | ||||
| for weight in self.parameters(): | for weight in self.parameters(): | ||||
| if weight.dim() == 1: | if weight.dim() == 1: | ||||
| @@ -2,8 +2,8 @@ from collections import defaultdict | |||||
| import numpy as np | import numpy as np | ||||
| import torch | import torch | ||||
| import torch.nn.init as init | |||||
| import torch.nn as nn | |||||
| def mask_softmax(matrix, mask): | def mask_softmax(matrix, mask): | ||||
| if mask is None: | if mask is None: | ||||
| result = torch.nn.functional.softmax(matrix, dim=-1) | result = torch.nn.functional.softmax(matrix, dim=-1) | ||||
| @@ -11,6 +11,51 @@ def mask_softmax(matrix, mask): | |||||
| raise NotImplementedError | raise NotImplementedError | ||||
| return result | return result | ||||
| def initial_parameter(net ,initial_method =None): | |||||
| if initial_method == 'xavier_uniform': | |||||
| init_method = init.xavier_uniform_ | |||||
| elif initial_method=='xavier_normal': | |||||
| init_method = init.xavier_normal_ | |||||
| elif initial_method == 'kaiming_normal' or initial_method =='msra': | |||||
| init_method = init.kaiming_normal | |||||
| elif initial_method == 'kaiming_uniform': | |||||
| init_method = init.kaiming_normal | |||||
| elif initial_method == 'orthogonal': | |||||
| init_method = init.orthogonal_ | |||||
| elif initial_method == 'sparse': | |||||
| init_method = init.sparse_ | |||||
| elif initial_method =='normal': | |||||
| init_method = init.normal_ | |||||
| elif initial_method =='uniform': | |||||
| initial_method = init.uniform_ | |||||
| else: | |||||
| init_method = init.xavier_normal_ | |||||
| def weights_init(m): | |||||
| # classname = m.__class__.__name__ | |||||
| if isinstance(m, nn.Conv2d) or isinstance(m,nn.Conv1d) or isinstance(m,nn.Conv3d): # for all the cnn | |||||
| if initial_method != None: | |||||
| init_method(m.weight.data) | |||||
| else: | |||||
| init.xavier_normal_(m.weight.data) | |||||
| init.normal_(m.bias.data) | |||||
| elif isinstance(m, nn.LSTM): | |||||
| for w in m.parameters(): | |||||
| if len(w.data.size())>1: | |||||
| init_method(w.data) # weight | |||||
| else: | |||||
| init.normal_(w.data) # bias | |||||
| elif hasattr(m, 'weight') and m.weight.requires_grad: | |||||
| init_method(m.weight.data) | |||||
| else: | |||||
| for w in m.parameters() : | |||||
| if w.requires_grad: | |||||
| if len(w.data.size())>1: | |||||
| init_method(w.data) # weight | |||||
| else: | |||||
| init.normal_(w.data) # bias | |||||
| # print("init else") | |||||
| net.apply(weights_init) | |||||
| def seq_mask(seq_len, max_len): | def seq_mask(seq_len, max_len): | ||||
| mask = [torch.ge(torch.LongTensor(seq_len), i + 1) for i in range(max_len)] | mask = [torch.ge(torch.LongTensor(seq_len), i + 1) for i in range(max_len)] | ||||
| @@ -0,0 +1,13 @@ | |||||
| [train] | |||||
| epochs = 30 | |||||
| batch_size = 32 | |||||
| pickle_path = "./save/" | |||||
| validate = true | |||||
| save_best_dev = true | |||||
| model_saved_path = "./save/" | |||||
| rnn_hidden_units = 300 | |||||
| word_emb_dim = 300 | |||||
| use_crf = true | |||||
| use_cuda = false | |||||
| loss_func = "cross_entropy" | |||||
| num_classes = 5 | |||||
| @@ -0,0 +1,80 @@ | |||||
| import os | |||||
| import torch.nn.functional as F | |||||
| from fastNLP.loader.dataset_loader import ClassDatasetLoader as Dataset_loader | |||||
| from fastNLP.loader.embed_loader import EmbedLoader as EmbedLoader | |||||
| from fastNLP.loader.config_loader import ConfigSection | |||||
| from fastNLP.loader.config_loader import ConfigLoader | |||||
| from fastNLP.models.base_model import BaseModel | |||||
| from fastNLP.core.preprocess import ClassPreprocess as Preprocess | |||||
| from fastNLP.core.trainer import ClassificationTrainer | |||||
| from fastNLP.modules.encoder.embedding import Embedding as Embedding | |||||
| from fastNLP.modules.encoder.lstm import Lstm | |||||
| from fastNLP.modules.aggregation.self_attention import SelfAttention | |||||
| from fastNLP.modules.decoder.MLP import MLP | |||||
| train_data_path = 'small_train_data.txt' | |||||
| dev_data_path = 'small_dev_data.txt' | |||||
| # emb_path = 'glove.txt' | |||||
| lstm_hidden_size = 300 | |||||
| embeding_size = 300 | |||||
| attention_unit = 350 | |||||
| attention_hops = 10 | |||||
| class_num = 5 | |||||
| nfc = 3000 | |||||
| ### data load ### | |||||
| train_dataset = Dataset_loader(train_data_path) | |||||
| train_data = train_dataset.load() | |||||
| dev_args = Dataset_loader(dev_data_path) | |||||
| dev_data = dev_args.load() | |||||
| ###### preprocess #### | |||||
| preprocess = Preprocess() | |||||
| word2index, label2index = preprocess.build_dict(train_data) | |||||
| train_data, dev_data = preprocess.run(train_data, dev_data) | |||||
| # emb = EmbedLoader(emb_path) | |||||
| # embedding = emb.load_embedding(emb_dim= embeding_size , emb_file= emb_path ,word_dict= word2index) | |||||
| ### construct vocab ### | |||||
| class SELF_ATTENTION_YELP_CLASSIFICATION(BaseModel): | |||||
| def __init__(self, args=None): | |||||
| super(SELF_ATTENTION_YELP_CLASSIFICATION,self).__init__() | |||||
| self.embedding = Embedding(len(word2index) ,embeding_size , init_emb= None ) | |||||
| self.lstm = Lstm(input_size = embeding_size,hidden_size = lstm_hidden_size ,bidirectional = True) | |||||
| self.attention = SelfAttention(lstm_hidden_size * 2 ,dim =attention_unit ,num_vec=attention_hops) | |||||
| self.mlp = MLP(size_layer=[lstm_hidden_size * 2*attention_hops ,nfc ,class_num ] ,num_class=class_num ,) | |||||
| def forward(self,x): | |||||
| x_emb = self.embedding(x) | |||||
| output = self.lstm(x_emb) | |||||
| after_attention, penalty = self.attention(output,x) | |||||
| after_attention =after_attention.view(after_attention.size(0),-1) | |||||
| output = self.mlp(after_attention) | |||||
| return output | |||||
| def loss(self, predict, ground_truth): | |||||
| print("predict:%s; g:%s" % (str(predict.size()), str(ground_truth.size()))) | |||||
| print(ground_truth) | |||||
| return F.cross_entropy(predict, ground_truth) | |||||
| train_args = ConfigSection() | |||||
| ConfigLoader("good path").load_config('config.cfg',{"train": train_args}) | |||||
| train_args['vocab'] = len(word2index) | |||||
| trainer = ClassificationTrainer(**train_args.data) | |||||
| # for k in train_args.__dict__.keys(): | |||||
| # print(k, train_args[k]) | |||||
| model = SELF_ATTENTION_YELP_CLASSIFICATION(train_args) | |||||
| trainer.train(model,train_data , dev_data) | |||||
| @@ -2,18 +2,18 @@ | |||||
| # coding=utf-8 | # coding=utf-8 | ||||
| from setuptools import setup, find_packages | from setuptools import setup, find_packages | ||||
| with open('README.md') as f: | |||||
| with open('README.md', encoding='utf-8') as f: | |||||
| readme = f.read() | readme = f.read() | ||||
| with open('LICENSE') as f: | |||||
| with open('LICENSE', encoding='utf-8') as f: | |||||
| license = f.read() | license = f.read() | ||||
| with open('requirements.txt') as f: | |||||
| with open('requirements.txt', encoding='utf-8') as f: | |||||
| reqs = f.read() | reqs = f.read() | ||||
| setup( | setup( | ||||
| name='fastNLP', | name='fastNLP', | ||||
| version='0.0.1', | |||||
| version='0.0.3', | |||||
| description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', | description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', | ||||
| long_description=readme, | long_description=readme, | ||||
| license=license, | license=license, | ||||