From 8c13c28f0c23a5d8877145e02fc2aa2ff45bd66d Mon Sep 17 00:00:00 2001 From: 2017alan <17210240044@fudan.edu.cn> Date: Sat, 15 Sep 2018 17:14:55 +0800 Subject: [PATCH 1/7] add nll loss --- fastNLP/core/loss.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fastNLP/core/loss.py b/fastNLP/core/loss.py index 8d866bbf..8a0eedd7 100644 --- a/fastNLP/core/loss.py +++ b/fastNLP/core/loss.py @@ -37,5 +37,7 @@ class Loss(object): """ if loss_name == "cross_entropy": return torch.nn.CrossEntropyLoss() + elif loss_name == 'nll': + return torch.nn.NLLLoss() else: raise NotImplementedError From c24d01d50f78c96f1cf91a74d9cc0a8195899705 Mon Sep 17 00:00:00 2001 From: 2017alan <17210240044@fudan.edu.cn> Date: Sat, 15 Sep 2018 17:15:25 +0800 Subject: [PATCH 2/7] fix a bug in label2index dict. --- fastNLP/core/preprocess.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fastNLP/core/preprocess.py b/fastNLP/core/preprocess.py index f8142c36..e049c762 100644 --- a/fastNLP/core/preprocess.py +++ b/fastNLP/core/preprocess.py @@ -239,7 +239,7 @@ class SeqLabelPreprocess(BasePreprocess): label2index: dict of {str, int} """ # In seq labeling, both word seq and label seq need to be padded to the same length in a mini-batch. - label2index = DEFAULT_WORD_TO_INDEX.copy() + label2index = {} # DEFAULT_WORD_TO_INDEX.copy() word2index = DEFAULT_WORD_TO_INDEX.copy() for example in data: for word, label in zip(example[0], example[1]): @@ -297,7 +297,7 @@ class ClassPreprocess(BasePreprocess): # build vocabulary from scratch if nothing exists word2index = DEFAULT_WORD_TO_INDEX.copy() - label2index = DEFAULT_WORD_TO_INDEX.copy() + label2index = {} # DEFAULT_WORD_TO_INDEX.copy() # collect every word and label for sent, label in data: From 5960aba9cb1fc3106c51200b965cd6579e04d2ab Mon Sep 17 00:00:00 2001 From: 2017alan <17210240044@fudan.edu.cn> Date: Sat, 15 Sep 2018 17:16:36 +0800 Subject: [PATCH 3/7] change the code to handle sentences with padding tokens. --- fastNLP/modules/aggregation/self_attention.py | 51 +++++++++++++++---- 1 file changed, 40 insertions(+), 11 deletions(-) diff --git a/fastNLP/modules/aggregation/self_attention.py b/fastNLP/modules/aggregation/self_attention.py index aeaef4db..4155d708 100644 --- a/fastNLP/modules/aggregation/self_attention.py +++ b/fastNLP/modules/aggregation/self_attention.py @@ -1,8 +1,10 @@ import torch import torch.nn as nn from torch.autograd import Variable +import torch.nn.functional as F +from fastNLP.modules.utils import initial_parameter class SelfAttention(nn.Module): """ Self Attention Module.
@@ -13,13 +15,18 @@ class SelfAttention(nn.Module): num_vec: int, the number of encoded vectors """ - def __init__(self, input_size, dim=10, num_vec=10): + def __init__(self, input_size, dim=10, num_vec=10 ,drop = 0.5 ,initial_method =None): super(SelfAttention, self).__init__() - self.W_s1 = nn.Parameter(torch.randn(dim, input_size), requires_grad=True) - self.W_s2 = nn.Parameter(torch.randn(num_vec, dim), requires_grad=True) + # self.W_s1 = nn.Parameter(torch.randn(dim, input_size), requires_grad=True) + # self.W_s2 = nn.Parameter(torch.randn(num_vec, dim), requires_grad=True) + self.attention_hops = num_vec + + self.ws1 = nn.Linear(input_size, dim, bias=False) + self.ws2 = nn.Linear(dim, num_vec, bias=False) + self.drop = nn.Dropout(drop) self.softmax = nn.Softmax(dim=2) self.tanh = nn.Tanh() - + initial_parameter(self, initial_method) def penalization(self, A): """ compute the penalization term for attention module @@ -32,11 +39,33 @@ class SelfAttention(nn.Module): M = M.view(M.size(0), -1) return torch.sum(M ** 2, dim=1) - def forward(self, x): - inter = self.tanh(torch.matmul(self.W_s1, torch.transpose(x, 1, 2))) - A = self.softmax(torch.matmul(self.W_s2, inter)) - out = torch.matmul(A, x) - out = out.view(out.size(0), -1) - penalty = self.penalization(A) - return out, penalty + def forward(self, outp ,inp): + # The original code below cannot be used as-is: some tokens are padding, and it has no mechanism to mask them out. + + # inter = self.tanh(torch.matmul(self.W_s1, torch.transpose(x, 1, 2))) # [] + # A = self.softmax(torch.matmul(self.W_s2, inter)) + # out = torch.matmul(A, x) + # out = out.view(out.size(0), -1) + # penalty = self.penalization(A) + # return out, penalty + outp = outp.contiguous() + size = outp.size() # [bsz, len, nhid] + + compressed_embeddings = outp.view(-1, size[2]) # [bsz*len, nhid] + transformed_inp = torch.transpose(inp, 0, 1).contiguous() # [bsz, len] + transformed_inp = transformed_inp.view(size[0], 1, size[1]) # [bsz, 1, len] + concatenated_inp = [transformed_inp for i in range(self.attention_hops)] + concatenated_inp = torch.cat(concatenated_inp, 1) # [bsz, hop, len] + + hbar = self.tanh(self.ws1(self.drop(compressed_embeddings))) # [bsz*len, attention-unit] + attention = self.ws2(hbar).view(size[0], size[1], -1) # [bsz, len, hop] + attention = torch.transpose(attention, 1, 2).contiguous() # [bsz, hop, len] + penalized_alphas = attention + ( + -10000 * (concatenated_inp == 0).float()) + # [bsz, hop, len] + [bsz, hop, len] + attention = self.softmax(penalized_alphas) # softmax over len: [bsz, hop, len] + attention = attention.view(size[0], self.attention_hops, size[1]) # [bsz, hop, len] + return torch.bmm(attention, outp), attention # output --> [bsz, hop, nhid] + + From a89875df1ee5fdb6034d8177e68795182d41e6d2 Mon Sep 17 00:00:00 2001 From: 2017alan <17210240044@fudan.edu.cn> Date: Sat, 15 Sep 2018 17:17:22 +0800 Subject: [PATCH 4/7] add initial parameters --- fastNLP/modules/decoder/CRF.py | 7 ++-- fastNLP/modules/decoder/MLP.py | 6 +-- fastNLP/modules/encoder/char_embedding.py | 11 +++-- fastNLP/modules/encoder/conv.py | 6 ++- fastNLP/modules/encoder/conv_maxpool.py | 6 ++- fastNLP/modules/encoder/linear.py | 6 +-- fastNLP/modules/encoder/lstm.py | 8 ++-- fastNLP/modules/encoder/masked_rnn.py | 6 +-- fastNLP/modules/encoder/variational_rnn.py | 9 ++-- fastNLP/modules/utils.py | 49 +++++++++++++++++++++- 10 files changed, 85 insertions(+), 29 deletions(-) diff --git a/fastNLP/modules/decoder/CRF.py b/fastNLP/modules/decoder/CRF.py index e6327ec0..991927da 100644 ---
a/fastNLP/modules/decoder/CRF.py +++ b/fastNLP/modules/decoder/CRF.py @@ -1,6 +1,7 @@ import torch from torch import nn +from fastNLP.modules.utils import initial_parameter def log_sum_exp(x, dim=-1): max_value, _ = x.max(dim=dim, keepdim=True) @@ -19,7 +20,7 @@ def seq_len_to_byte_mask(seq_lens): class ConditionalRandomField(nn.Module): - def __init__(self, tag_size, include_start_end_trans=True): + def __init__(self, tag_size, include_start_end_trans=True ,initial_method = None): """ :param tag_size: int, num of tags :param include_start_end_trans: bool, whether to include start/end tag @@ -35,8 +36,8 @@ class ConditionalRandomField(nn.Module): self.start_scores = nn.Parameter(torch.randn(tag_size)) self.end_scores = nn.Parameter(torch.randn(tag_size)) - self.reset_parameter() - + # self.reset_parameter() + initial_parameter(self, initial_method) def reset_parameter(self): nn.init.xavier_normal_(self.transition_m) if self.include_start_end_trans: diff --git a/fastNLP/modules/decoder/MLP.py b/fastNLP/modules/decoder/MLP.py index c70aa0e9..b8fb95f0 100644 --- a/fastNLP/modules/decoder/MLP.py +++ b/fastNLP/modules/decoder/MLP.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn - +from fastNLP.modules.utils import initial_parameter class MLP(nn.Module): - def __init__(self, size_layer, num_class=2, activation='relu'): + def __init__(self, size_layer, num_class=2, activation='relu' , initial_method = None): """Multilayer Perceptrons as a decoder Args: @@ -36,7 +36,7 @@ class MLP(nn.Module): self.hidden_active = activation else: raise ValueError("should set activation correctly: {}".format(activation)) - + initial_parameter(self, initial_method ) def forward(self, x): for layer in self.hiddens: x = self.hidden_active(layer(x)) diff --git a/fastNLP/modules/encoder/char_embedding.py b/fastNLP/modules/encoder/char_embedding.py index 72680e5b..1da63947 100644 --- a/fastNLP/modules/encoder/char_embedding.py +++ b/fastNLP/modules/encoder/char_embedding.py @@ -1,11 +1,12 @@ import torch import torch.nn.functional as F from torch import nn +# from torch.nn.init import xavier_uniform - +from fastNLP.modules.utils import initial_parameter class ConvCharEmbedding(nn.Module): - def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5)): + def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5),initial_method = None): """ Character Level Word Embedding :param char_emb_size: the size of character level embedding. Default: 50 @@ -20,6 +21,8 @@ class ConvCharEmbedding(nn.Module): nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True, padding=(0, 4)) for i in range(len(kernels))]) + initial_parameter(self,initial_method) + def forward(self, x): """ :param x: [batch_size * sent_length, word_length, char_emb_size] @@ -53,7 +56,7 @@ class LSTMCharEmbedding(nn.Module): :param hidden_size: int, the number of hidden units. Default: equal to char_emb_size. 
""" - def __init__(self, char_emb_size=50, hidden_size=None): + def __init__(self, char_emb_size=50, hidden_size=None , initial_method= None): super(LSTMCharEmbedding, self).__init__() self.hidden_size = char_emb_size if hidden_size is None else hidden_size @@ -62,7 +65,7 @@ class LSTMCharEmbedding(nn.Module): num_layers=1, bias=True, batch_first=True) - + initial_parameter(self, initial_method) def forward(self, x): """ :param x:[ n_batch*n_word, word_length, char_emb_size] diff --git a/fastNLP/modules/encoder/conv.py b/fastNLP/modules/encoder/conv.py index 06a31dd8..68536e5d 100644 --- a/fastNLP/modules/encoder/conv.py +++ b/fastNLP/modules/encoder/conv.py @@ -6,6 +6,7 @@ import torch.nn as nn from torch.nn.init import xavier_uniform_ # import torch.nn.functional as F +from fastNLP.modules.utils import initial_parameter class Conv(nn.Module): """ @@ -15,7 +16,7 @@ class Conv(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, - groups=1, bias=True, activation='relu'): + groups=1, bias=True, activation='relu',initial_method = None ): super(Conv, self).__init__() self.conv = nn.Conv1d( in_channels=in_channels, @@ -26,7 +27,7 @@ class Conv(nn.Module): dilation=dilation, groups=groups, bias=bias) - xavier_uniform_(self.conv.weight) + # xavier_uniform_(self.conv.weight) activations = { 'relu': nn.ReLU(), @@ -37,6 +38,7 @@ class Conv(nn.Module): raise Exception( 'Should choose activation function from: ' + ', '.join([x for x in activations])) + initial_parameter(self, initial_method) def forward(self, x): x = torch.transpose(x, 1, 2) # [N,L,C] -> [N,C,L] diff --git a/fastNLP/modules/encoder/conv_maxpool.py b/fastNLP/modules/encoder/conv_maxpool.py index f666e7f9..7aa897cf 100644 --- a/fastNLP/modules/encoder/conv_maxpool.py +++ b/fastNLP/modules/encoder/conv_maxpool.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.init import xavier_uniform_ - +from fastNLP.modules.utils import initial_parameter class ConvMaxpool(nn.Module): """ @@ -14,7 +14,7 @@ class ConvMaxpool(nn.Module): def __init__(self, in_channels, out_channels, kernel_sizes, stride=1, padding=0, dilation=1, - groups=1, bias=True, activation='relu'): + groups=1, bias=True, activation='relu',initial_method = None ): super(ConvMaxpool, self).__init__() # convolution @@ -47,6 +47,8 @@ class ConvMaxpool(nn.Module): raise Exception( "Undefined activation function: choose from: relu") + initial_parameter(self, initial_method) + def forward(self, x): # [N,L,C] -> [N,C,L] x = torch.transpose(x, 1, 2) diff --git a/fastNLP/modules/encoder/linear.py b/fastNLP/modules/encoder/linear.py index 9582d9f9..a7c5f6c3 100644 --- a/fastNLP/modules/encoder/linear.py +++ b/fastNLP/modules/encoder/linear.py @@ -1,6 +1,6 @@ import torch.nn as nn - +from fastNLP.modules.utils import initial_parameter class Linear(nn.Module): """ Linear module @@ -12,10 +12,10 @@ class Linear(nn.Module): bidirectional : If True, becomes a bidirectional RNN """ - def __init__(self, input_size, output_size, bias=True): + def __init__(self, input_size, output_size, bias=True,initial_method = None ): super(Linear, self).__init__() self.linear = nn.Linear(input_size, output_size, bias) - + initial_parameter(self, initial_method) def forward(self, x): x = self.linear(x) return x diff --git a/fastNLP/modules/encoder/lstm.py b/fastNLP/modules/encoder/lstm.py index bed6c276..5af09f29 100644 --- a/fastNLP/modules/encoder/lstm.py +++ b/fastNLP/modules/encoder/lstm.py @@ -1,6 +1,6 @@ 
import torch.nn as nn - +from fastNLP.modules.utils import initial_parameter class Lstm(nn.Module): """ LSTM module @@ -13,11 +13,13 @@ class Lstm(nn.Module): bidirectional : If True, becomes a bidirectional RNN. Default: False. """ - def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0, bidirectional=False): + def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0, bidirectional=False , initial_method = None): super(Lstm, self).__init__() self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True, dropout=dropout, bidirectional=bidirectional) - + initial_parameter(self, initial_method) def forward(self, x): x, _ = self.lstm(x) return x +if __name__ == "__main__": + lstm = Lstm(10) diff --git a/fastNLP/modules/encoder/masked_rnn.py b/fastNLP/modules/encoder/masked_rnn.py index 76f828a9..c1ef15d0 100644 --- a/fastNLP/modules/encoder/masked_rnn.py +++ b/fastNLP/modules/encoder/masked_rnn.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn import torch.nn.functional as F - +from fastNLP.modules.utils import initial_parameter def MaskedRecurrent(reverse=False): def forward(input, hidden, cell, mask, train=True, dropout=0): """ @@ -192,7 +192,7 @@ def AutogradMaskedStep(num_layers=1, dropout=0, train=True, lstm=False): class MaskedRNNBase(nn.Module): def __init__(self, Cell, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, - layer_dropout=0, step_dropout=0, bidirectional=False, **kwargs): + layer_dropout=0, step_dropout=0, bidirectional=False, initial_method = None , **kwargs): """ :param Cell: :param input_size: @@ -226,7 +226,7 @@ class MaskedRNNBase(nn.Module): cell = self.Cell(layer_input_size, hidden_size, self.bias, **kwargs) self.all_cells.append(cell) self.add_module('cell%d' % (layer * num_directions + direction), cell) # Max的代码写得真好看 - + initial_parameter(self, initial_method) def reset_parameters(self): for cell in self.all_cells: cell.reset_parameters() diff --git a/fastNLP/modules/encoder/variational_rnn.py b/fastNLP/modules/encoder/variational_rnn.py index b08bdd2d..fb75fabb 100644 --- a/fastNLP/modules/encoder/variational_rnn.py +++ b/fastNLP/modules/encoder/variational_rnn.py @@ -6,6 +6,7 @@ import torch.nn.functional as F from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend from torch.nn.parameter import Parameter +from fastNLP.modules.utils import initial_parameter def default_initializer(hidden_size): stdv = 1.0 / math.sqrt(hidden_size) @@ -172,7 +173,7 @@ def AutogradVarMaskedStep(num_layers=1, lstm=False): class VarMaskedRNNBase(nn.Module): def __init__(self, Cell, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, - dropout=(0, 0), bidirectional=False, initializer=None, **kwargs): + dropout=(0, 0), bidirectional=False, initializer=None,initial_method = None, **kwargs): super(VarMaskedRNNBase, self).__init__() self.Cell = Cell @@ -193,7 +194,7 @@ class VarMaskedRNNBase(nn.Module): cell = self.Cell(layer_input_size, hidden_size, self.bias, p=dropout, initializer=initializer, **kwargs) self.all_cells.append(cell) self.add_module('cell%d' % (layer * num_directions + direction), cell) - + initial_parameter(self, initial_method) def reset_parameters(self): for cell in self.all_cells: cell.reset_parameters() @@ -284,7 +285,7 @@ class VarFastLSTMCell(VarRNNCellBase): \end{array} """ - def __init__(self, input_size, hidden_size, bias=True, p=(0.5, 0.5), initializer=None): + def __init__(self, input_size, hidden_size, bias=True, p=(0.5, 0.5), 
initializer=None,initial_method =None): super(VarFastLSTMCell, self).__init__() self.input_size = input_size self.hidden_size = hidden_size @@ -311,7 +312,7 @@ class VarFastLSTMCell(VarRNNCellBase): self.p_hidden = p_hidden self.noise_in = None self.noise_hidden = None - + initial_parameter(self, initial_method) def reset_parameters(self): for weight in self.parameters(): if weight.dim() == 1: diff --git a/fastNLP/modules/utils.py b/fastNLP/modules/utils.py index 442944e7..22139668 100644 --- a/fastNLP/modules/utils.py +++ b/fastNLP/modules/utils.py @@ -2,8 +2,8 @@ from collections import defaultdict import numpy as np import torch - - +import torch.nn.init as init +import torch.nn as nn def mask_softmax(matrix, mask): if mask is None: result = torch.nn.functional.softmax(matrix, dim=-1) @@ -11,6 +11,51 @@ def mask_softmax(matrix, mask): raise NotImplementedError return result +def initial_parameter(net ,initial_method =None): + + if initial_method == 'xavier_uniform': + init_method = init.xavier_uniform_ + elif initial_method=='xavier_normal': + init_method = init.xavier_normal_ + elif initial_method == 'kaiming_normal' or initial_method =='msra': + init_method = init.kaiming_normal_ + elif initial_method == 'kaiming_uniform': + init_method = init.kaiming_uniform_ + elif initial_method == 'orthogonal': + init_method = init.orthogonal_ + elif initial_method == 'sparse': + init_method = init.sparse_ + elif initial_method =='normal': + init_method = init.normal_ + elif initial_method =='uniform': + init_method = init.uniform_ + else: + init_method = init.xavier_normal_ + def weights_init(m): + # classname = m.__class__.__name__ + if isinstance(m, nn.Conv2d) or isinstance(m,nn.Conv1d) or isinstance(m,nn.Conv3d): # for all the cnn + if initial_method != None: + init_method(m.weight.data) + else: + init.xavier_normal_(m.weight.data) + init.normal_(m.bias.data) + elif isinstance(m, nn.LSTM): + for w in m.parameters(): + if len(w.data.size())>1: + init_method(w.data) # weight + else: + init.normal_(w.data) # bias + elif hasattr(m, 'weight') and m.weight.requires_grad: + init_method(m.weight.data) + else: + for w in m.parameters() : + if w.requires_grad: + if len(w.data.size())>1: + init_method(w.data) # weight + else: + init.normal_(w.data) # bias + # print("init else") + net.apply(weights_init) def seq_mask(seq_len, max_len): mask = [torch.ge(torch.LongTensor(seq_len), i + 1) for i in range(max_len)] From 7bea47681b3f0e7ee1f5946c27529b94c097d52e Mon Sep 17 00:00:00 2001 From: 2017alan <17210240044@fudan.edu.cn> Date: Sat, 15 Sep 2018 17:18:51 +0800 Subject: [PATCH 5/7] set encoding to utf-8, otherwise setup.py fails on some machines.
--- setup.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 25a645c5..a7b5dc46 100644 --- a/setup.py +++ b/setup.py @@ -2,18 +2,18 @@ # coding=utf-8 from setuptools import setup, find_packages -with open('README.md') as f: +with open('README.md', encoding='utf-8') as f: readme = f.read() -with open('LICENSE') as f: +with open('LICENSE', encoding='utf-8') as f: license = f.read() -with open('requirements.txt') as f: +with open('requirements.txt', encoding='utf-8') as f: reqs = f.read() setup( name='fastNLP', - version='0.0.1', + version='0.0.3', description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', long_description=readme, license=license, From b3e8db74a665b358866f2f84e2c9060a475116e3 Mon Sep 17 00:00:00 2001 From: 2017alan <17210240044@fudan.edu.cn> Date: Sat, 15 Sep 2018 17:19:56 +0800 Subject: [PATCH 6/7] add self_attention for yelp classification example. --- .../config.cfg | 13 +++ .../main.py | 80 +++++++++++++++++++ 2 files changed, 93 insertions(+) create mode 100644 reproduction/LSTM+self_attention_sentiment_analysis/config.cfg create mode 100644 reproduction/LSTM+self_attention_sentiment_analysis/main.py diff --git a/reproduction/LSTM+self_attention_sentiment_analysis/config.cfg b/reproduction/LSTM+self_attention_sentiment_analysis/config.cfg new file mode 100644 index 00000000..2d31cd0d --- /dev/null +++ b/reproduction/LSTM+self_attention_sentiment_analysis/config.cfg @@ -0,0 +1,13 @@ +[train] +epochs = 30 +batch_size = 32 +pickle_path = "./save/" +validate = true +save_best_dev = true +model_saved_path = "./save/" +rnn_hidden_units = 300 +word_emb_dim = 300 +use_crf = true +use_cuda = false +loss_func = "cross_entropy" +num_classes = 5 \ No newline at end of file diff --git a/reproduction/LSTM+self_attention_sentiment_analysis/main.py b/reproduction/LSTM+self_attention_sentiment_analysis/main.py new file mode 100644 index 00000000..115d9a23 --- /dev/null +++ b/reproduction/LSTM+self_attention_sentiment_analysis/main.py @@ -0,0 +1,80 @@ + +import os + +import torch.nn.functional as F + +from fastNLP.loader.dataset_loader import ClassDatasetLoader as Dataset_loader +from fastNLP.loader.embed_loader import EmbedLoader as EmbedLoader +from fastNLP.loader.config_loader import ConfigSection +from fastNLP.loader.config_loader import ConfigLoader + +from fastNLP.models.base_model import BaseModel + +from fastNLP.core.preprocess import ClassPreprocess as Preprocess +from fastNLP.core.trainer import ClassificationTrainer + +from fastNLP.modules.encoder.embedding import Embedding as Embedding +from fastNLP.modules.encoder.lstm import Lstm +from fastNLP.modules.aggregation.self_attention import SelfAttention +from fastNLP.modules.decoder.MLP import MLP + + +train_data_path = 'small_train_data.txt' +dev_data_path = 'small_dev_data.txt' +# emb_path = 'glove.txt' + +lstm_hidden_size = 300 +embeding_size = 300 +attention_unit = 350 +attention_hops = 10 +class_num = 5 +nfc = 3000 +### data load ### +train_dataset = Dataset_loader(train_data_path) +train_data = train_dataset.load() + +dev_args = Dataset_loader(dev_data_path) +dev_data = dev_args.load() + +###### preprocess #### +preprocess = Preprocess() +word2index, label2index = preprocess.build_dict(train_data) +train_data, dev_data = preprocess.run(train_data, dev_data) + + + +# emb = EmbedLoader(emb_path) +# embedding = emb.load_embedding(emb_dim= embeding_size , emb_file= emb_path ,word_dict= word2index) +### construct vocab ### + +class 
SELF_ATTENTION_YELP_CLASSIFICATION(BaseModel): + def __init__(self, args=None): + super(SELF_ATTENTION_YELP_CLASSIFICATION,self).__init__() + self.embedding = Embedding(len(word2index) ,embeding_size , init_emb= None ) + self.lstm = Lstm(input_size = embeding_size,hidden_size = lstm_hidden_size ,bidirectional = True) + self.attention = SelfAttention(lstm_hidden_size * 2 ,dim =attention_unit ,num_vec=attention_hops) + self.mlp = MLP(size_layer=[lstm_hidden_size * 2*attention_hops ,nfc ,class_num ] ,num_class=class_num ,) + def forward(self,x): + x_emb = self.embedding(x) + output = self.lstm(x_emb) + after_attention, penalty = self.attention(output,x) + after_attention =after_attention.view(after_attention.size(0),-1) + output = self.mlp(after_attention) + return output + + def loss(self, predict, ground_truth): + print("predict:%s; g:%s" % (str(predict.size()), str(ground_truth.size()))) + print(ground_truth) + return F.cross_entropy(predict, ground_truth) + +train_args = ConfigSection() +ConfigLoader("good path").load_config('config.cfg',{"train": train_args}) +train_args['vocab'] = len(word2index) + + +trainer = ClassificationTrainer(**train_args.data) + +# for k in train_args.__dict__.keys(): +# print(k, train_args[k]) +model = SELF_ATTENTION_YELP_CLASSIFICATION(train_args) +trainer.train(model,train_data , dev_data) From 5c671078b6a0039b8071cd9c3984184831cb8989 Mon Sep 17 00:00:00 2001 From: Yige XU Date: Sat, 15 Sep 2018 18:13:17 +0800 Subject: [PATCH 7/7] Update preprocess.py --- fastNLP/core/preprocess.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastNLP/core/preprocess.py b/fastNLP/core/preprocess.py index e049c762..5e4122f9 100644 --- a/fastNLP/core/preprocess.py +++ b/fastNLP/core/preprocess.py @@ -239,7 +239,7 @@ class SeqLabelPreprocess(BasePreprocess): label2index: dict of {str, int} """ # In seq labeling, both word seq and label seq need to be padded to the same length in a mini-batch. - label2index = {} # DEFAULT_WORD_TO_INDEX.copy() + label2index = DEFAULT_WORD_TO_INDEX.copy() word2index = DEFAULT_WORD_TO_INDEX.copy() for example in data: for word, label in zip(example[0], example[1]):
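A minimal usage sketch of the changes above (illustrative only, not part of the patch series): it wires the padding-aware SelfAttention from PATCH 3/7 together with the initial_method argument from PATCH 4/7, assuming the module paths shown in the diffs, that index 0 marks padding, and the hyper-parameter values used in main.py.

import torch
from fastNLP.modules.encoder.lstm import Lstm
from fastNLP.modules.aggregation.self_attention import SelfAttention

bsz, seq_len, emb_dim, hidden = 4, 20, 300, 300
word_ids = torch.randint(1, 100, (bsz, seq_len)).long()  # fake word indices; 0 is assumed to be the padding index
word_ids[:, 15:] = 0                                      # pad the tail of every sentence
emb = torch.randn(bsz, seq_len, emb_dim)                  # stand-in for an embedding lookup of word_ids

# initial_method is the new keyword argument introduced in PATCH 4/7
lstm = Lstm(input_size=emb_dim, hidden_size=hidden, bidirectional=True, initial_method='xavier_uniform')
attention = SelfAttention(hidden * 2, dim=350, num_vec=10)

outp = lstm(emb)  # [bsz, len, 2 * hidden]
# forward now also takes the raw word ids, so padded positions (id == 0)
# receive a large negative score before the softmax
sent_emb, att = attention(outp, word_ids)  # [bsz, hop, 2 * hidden], [bsz, hop, len]
flat = sent_emb.view(bsz, -1)              # [bsz, hop * 2 * hidden], what main.py feeds into the MLP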