From b3e8db74a665b358866f2f84e2c9060a475116e3 Mon Sep 17 00:00:00 2001
From: 2017alan <17210240044@fudan.edu.cn>
Date: Sat, 15 Sep 2018 17:19:56 +0800
Subject: [PATCH] Add LSTM + self-attention Yelp sentiment classification example

---
 .../config.cfg | 13 +++
 .../main.py    | 81 ++++++++++++++++++++
 2 files changed, 94 insertions(+)
 create mode 100644 reproduction/LSTM+self_attention_sentiment_analysis/config.cfg
 create mode 100644 reproduction/LSTM+self_attention_sentiment_analysis/main.py

diff --git a/reproduction/LSTM+self_attention_sentiment_analysis/config.cfg b/reproduction/LSTM+self_attention_sentiment_analysis/config.cfg
new file mode 100644
index 00000000..2d31cd0d
--- /dev/null
+++ b/reproduction/LSTM+self_attention_sentiment_analysis/config.cfg
@@ -0,0 +1,13 @@
+[train]
+epochs = 30
+batch_size = 32
+pickle_path = "./save/"
+validate = true
+save_best_dev = true
+model_saved_path = "./save/"
+rnn_hidden_units = 300
+word_emb_dim = 300
+use_crf = true
+use_cuda = false
+loss_func = "cross_entropy"
+num_classes = 5
\ No newline at end of file
diff --git a/reproduction/LSTM+self_attention_sentiment_analysis/main.py b/reproduction/LSTM+self_attention_sentiment_analysis/main.py
new file mode 100644
index 00000000..115d9a23
--- /dev/null
+++ b/reproduction/LSTM+self_attention_sentiment_analysis/main.py
@@ -0,0 +1,81 @@
+
+import torch.nn.functional as F
+
+from fastNLP.loader.dataset_loader import ClassDatasetLoader
+from fastNLP.loader.embed_loader import EmbedLoader
+from fastNLP.loader.config_loader import ConfigSection
+from fastNLP.loader.config_loader import ConfigLoader
+
+from fastNLP.models.base_model import BaseModel
+
+from fastNLP.core.preprocess import ClassPreprocess as Preprocess
+from fastNLP.core.trainer import ClassificationTrainer
+
+from fastNLP.modules.encoder.embedding import Embedding
+from fastNLP.modules.encoder.lstm import Lstm
+from fastNLP.modules.aggregation.self_attention import SelfAttention
+from fastNLP.modules.decoder.MLP import MLP
+
+train_data_path = 'small_train_data.txt'
+dev_data_path = 'small_dev_data.txt'
+# emb_path = 'glove.txt'
+
+lstm_hidden_size = 300
+embedding_size = 300
+attention_unit = 350
+attention_hops = 10
+class_num = 5
+nfc = 3000
+
+### load data ###
+train_dataset = ClassDatasetLoader(train_data_path)
+train_data = train_dataset.load()
+
+dev_dataset = ClassDatasetLoader(dev_data_path)
+dev_data = dev_dataset.load()
+
+### preprocess ###
+preprocess = Preprocess()
+word2index, label2index = preprocess.build_dict(train_data)
+train_data, dev_data = preprocess.run(train_data, dev_data)
+
+# Optionally initialize the embedding layer from pre-trained vectors:
+# emb = EmbedLoader(emb_path)
+# embedding = emb.load_embedding(emb_dim=embedding_size, emb_file=emb_path, word_dict=word2index)
+
+### define model ###
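+# Model: token embeddings feed a bidirectional LSTM; multi-hop self-attention
+# (cf. Lin et al. 2017, "A Structured Self-Attentive Sentence Embedding")
+# extracts attention_hops weighted views of the hidden states, which are
+# flattened and classified by a two-layer MLP.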
+class SELF_ATTENTION_YELP_CLASSIFICATION(BaseModel):
+    def __init__(self, args=None):
+        super(SELF_ATTENTION_YELP_CLASSIFICATION, self).__init__()
+        self.embedding = Embedding(len(word2index), embedding_size, init_emb=None)
+        self.lstm = Lstm(input_size=embedding_size, hidden_size=lstm_hidden_size, bidirectional=True)
+        self.attention = SelfAttention(lstm_hidden_size * 2, dim=attention_unit, num_vec=attention_hops)
+        self.mlp = MLP(size_layer=[lstm_hidden_size * 2 * attention_hops, nfc, class_num], num_class=class_num)
+
+    def forward(self, x):
+        x_emb = self.embedding(x)
+        output = self.lstm(x_emb)
+        after_attention, penalty = self.attention(output, x)
+        after_attention = after_attention.view(after_attention.size(0), -1)
+        output = self.mlp(after_attention)
+        return output
+
+    def loss(self, predict, ground_truth):
+        return F.cross_entropy(predict, ground_truth)
+
+### train ###
+train_args = ConfigSection()
+ConfigLoader("good path").load_config('config.cfg', {"train": train_args})
+train_args['vocab'] = len(word2index)
+
+trainer = ClassificationTrainer(**train_args.data)
+
+model = SELF_ATTENTION_YELP_CLASSIFICATION(train_args)
+trainer.train(model, train_data, dev_data)
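+
+# small_train_data.txt / small_dev_data.txt are assumed local samples of the
+# Yelp review data; point the paths above at the full splits for a real run.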