diff --git a/fastNLP/action/trainer.py b/fastNLP/action/trainer.py
index 437ab7d2..ac7138e5 100644
--- a/fastNLP/action/trainer.py
+++ b/fastNLP/action/trainer.py
@@ -7,6 +7,7 @@ import torch
 from fastNLP.action.action import Action
 from fastNLP.action.action import RandomSampler, Batchifier
 from fastNLP.action.tester import Tester
+from fastNLP.modules.utils import seq_mask
 
 
 class BaseTrainer(Action):
@@ -28,6 +29,7 @@ class BaseTrainer(Action):
         training parameters
         """
         super(BaseTrainer, self).__init__()
+        self.train_args = train_args
         self.n_epochs = train_args.epochs
         self.validate = train_args.validate
         self.batch_size = train_args.batch_size
@@ -163,8 +165,8 @@ class BaseTrainer(Action):
         :param data: list. Each entry is a sample, which is also a list of features and label(s).
             E.g.
                 [
-                    [[feature_1, feature_2, feature_3], [label_1. label_2]],  # sample 1
-                    [[feature_1, feature_2, feature_3], [label_1. label_2]],  # sample 2
+                    [[word_11, word_12, word_13], [label_11. label_12]],  # sample 1
+                    [[word_21, word_22, word_23], [label_21. label_22]],  # sample 2
                     ...
                 ]
         :return batch_x: list. Each entry is a list of features of a sample.
@@ -313,6 +315,39 @@ class WordSegTrainer(BaseTrainer):
         self.optimizer.step()
 
 
+class POSTrainer(BaseTrainer):
+    def __init__(self, train_args):
+        super(POSTrainer, self).__init__(train_args)
+        self.vocab_size = train_args.vocab_size
+        self.num_classes = train_args.num_classes
+        self.max_len = None
+        self.mask = None
+        self.batch_x = None
+
+    def prepare_input(self, data_path):
+        """
+        To do: Load pkl files of train/dev/test and embedding
+        """
+        data_train = _pickle.load(open(data_path + "data_train.pkl", "rb"))
+        data_dev = _pickle.load(open(data_path + "data_dev.pkl", "rb"))
+        return data_train, data_dev
+
+    def data_forward(self, network, x):
+        seq_len = [len(seq) for seq in x]
+        x = torch.LongTensor(x)
+        self.batch_size = x.size(0)
+        self.max_len = x.size(1)
+        self.mask = seq_mask(seq_len, self.max_len)
+        x = network(x)
+        self.batch_x = x
+        return x
+
+    def get_loss(self, predict, truth):
+        truth = torch.LongTensor(truth)
+        loss, prediction = self.loss_func(self.batch_x, predict, self.mask, self.batch_size, self.max_len)
+        return loss
+
+
 if __name__ == "__name__":
     train_args = BaseTrainer.TrainConfig(epochs=1, validate=False, batch_size=3, pickle_path="./")
     trainer = BaseTrainer(train_args)
diff --git a/fastNLP/loader/dataset_loader.py b/fastNLP/loader/dataset_loader.py
index 7132eb3b..284be715 100644
--- a/fastNLP/loader/dataset_loader.py
+++ b/fastNLP/loader/dataset_loader.py
@@ -15,7 +15,6 @@ class POSDatasetLoader(DatasetLoader):
 
     def __init__(self, data_name, data_path):
         super(POSDatasetLoader, self).__init__(data_name, data_path)
-        #self.data_set = self.load()
 
     def load(self):
         assert os.path.exists(self.data_path)
diff --git a/fastNLP/loader/preprocess.py b/fastNLP/loader/preprocess.py
index b8d88c35..8b9c6d88 100644
--- a/fastNLP/loader/preprocess.py
+++ b/fastNLP/loader/preprocess.py
@@ -46,19 +46,17 @@ class BasePreprocess(object):
 
 
 class POSPreprocess(BasePreprocess):
-
     """
     This class are used to preprocess the pos datasets.
-    In these datasets, each line is divided by '\t'
-    The first Col is the vocabulary.
-    The second Col is the labels.
+    In these datasets, each line are divided by '\t'
+    while the first Col is the vocabulary and the second
+    Col is the label.
     Different sentence are divided by an empty line.
     e.g:
     Tom     label1
     and     label2
     Jerry   label1
     .       label3
-    Hello   label4
     world   label5
     !       label3
 
@@ -71,11 +69,13 @@ class POSPreprocess(BasePreprocess):
         super(POSPreprocess, self).__init__(data, pickle_path)
         self.word_dict = None
         self.label_dict = None
+        self.data = data
+        self.pickle_path = pickle_path
         self.build_dict()
         self.word2id()
-        self.id2word()
+        self.vocab_size = self.id2word()
         self.class2id()
-        self.id2class()
+        self.num_classes = self.id2class()
         self.embedding()
         self.data_train()
         self.data_dev()
@@ -87,7 +87,8 @@
                           DEFAULT_RESERVED_LABEL[2]: 4}
         self.label_dict = {}
         for w in self.data:
-            if len(w) == 0:
+            w = w.strip()
+            if len(w) <= 1:
                 continue
 
             word = w.split('\t')
@@ -95,10 +96,11 @@
                 index = len(self.word_dict)
                 self.word_dict[word[0]] = index
-            for label in word[1: ]:
-                if label not in self.label_dict:
-                    index = len(self.label_dict)
-                    self.label_dict[label] = index
+            # for label in word[1: ]:
+            label = word[1]
+            if label not in self.label_dict:
+                index = len(self.label_dict)
+                self.label_dict[label] = index
     def pickle_exist(self, pickle_name):
         """
         :param pickle_name: the filename of target pickle file
         :return:
@@ -107,7 +109,7 @@
         """
         if not os.path.exists(self.pickle_path):
             os.makedirs(self.pickle_path)
-        file_name = self.pickle_path + pickle_name
+        file_name = os.path.join(self.pickle_path, pickle_name)
        if os.path.exists(file_name):
             return True
         else:
@@ -118,42 +120,48 @@
             return
         # nothing will be done if word2id.pkl exists
 
-        file_name = self.pickle_path + "word2id.pkl"
-        with open(file_name, "wb", encoding='utf-8') as f:
+        file_name = os.path.join(self.pickle_path, "word2id.pkl")
+        with open(file_name, "wb") as f:
             _pickle.dump(self.word_dict, f)
 
     def id2word(self):
         if self.pickle_exist("id2word.pkl"):
-            return
+            file_name = os.path.join(self.pickle_path, "id2word.pkl")
+            id2word_dict = _pickle.load(open(file_name, "rb"))
+            return len(id2word_dict)
         # nothing will be done if id2word.pkl exists
 
         id2word_dict = {}
         for word in self.word_dict:
             id2word_dict[self.word_dict[word]] = word
-        file_name = self.pickle_path + "id2word.pkl"
-        with open(file_name, "wb", encoding='utf-8') as f:
+        file_name = os.path.join(self.pickle_path, "id2word.pkl")
+        with open(file_name, "wb") as f:
             _pickle.dump(id2word_dict, f)
+        return len(id2word_dict)
 
     def class2id(self):
         if self.pickle_exist("class2id.pkl"):
             return
         # nothing will be done if class2id.pkl exists
 
-        file_name = self.pickle_path + "class2id.pkl"
-        with open(file_name, "wb", encoding='utf-8') as f:
+        file_name = os.path.join(self.pickle_path, "class2id.pkl")
+        with open(file_name, "wb") as f:
             _pickle.dump(self.label_dict, f)
 
     def id2class(self):
         if self.pickle_exist("id2class.pkl"):
-            return
+            file_name = os.path.join(self.pickle_path, "id2class.pkl")
+            id2class_dict = _pickle.load(open(file_name, "rb"))
+            return len(id2class_dict)
         # nothing will be done if id2class.pkl exists
 
         id2class_dict = {}
         for label in self.label_dict:
             id2class_dict[self.label_dict[label]] = label
-        file_name = self.pickle_path + "id2class.pkl"
-        with open(file_name, "wb", encoding='utf-8') as f:
+        file_name = os.path.join(self.pickle_path, "id2class.pkl")
+        with open(file_name, "wb") as f:
             _pickle.dump(id2class_dict, f)
+        return len(id2class_dict)
 
     def embedding(self):
         if self.pickle_exist("embedding.pkl"):
@@ -168,22 +176,26 @@
         data_train = []
         sentence = []
         for w in self.data:
-            if len(w) == 0:
+            w = w.strip()
+            if len(w) <= 1:
                 wid = []
                 lid = []
                 for i in range(len(sentence)):
+                    # if sentence[i][0]=="":
+                    #     print("")
                     wid.append(self.word_dict[sentence[i][0]])
                     lid.append(self.label_dict[sentence[i][1]])
                 data_train.append((wid, lid))
                 sentence = []
+                continue
             sentence.append(w.split('\t'))
 
-        file_name = self.pickle_path + "data_train.pkl"
-        with open(file_name, "wb", encoding='utf-8') as f:
+        file_name = os.path.join(self.pickle_path, "data_train.pkl")
+        with open(file_name, "wb") as f:
             _pickle.dump(data_train, f)
 
     def data_dev(self):
         pass
 
     def data_test(self):
-        pass
+        pass
\ No newline at end of file
diff --git a/fastNLP/models/base_model.py b/fastNLP/models/base_model.py
index 9249e2e3..54e28687 100644
--- a/fastNLP/models/base_model.py
+++ b/fastNLP/models/base_model.py
@@ -4,9 +4,9 @@ import torch
 
 class BaseModel(torch.nn.Module):
     """Base PyTorch model for all models.
     Three network modules presented:
-        - embedding module
+        - encoder module
         - aggregation module
-        - output module
+        - decoder module
     Subclasses must implement these three modules with "components".
     """
@@ -15,21 +15,20 @@
 
     def forward(self, *inputs):
         x = self.encode(*inputs)
-        x = self.aggregation(x)
-        x = self.output(x)
+        x = self.aggregate(x)
+        x = self.decode(x)
         return x
 
     def encode(self, x):
         raise NotImplementedError
 
-    def aggregation(self, x):
+    def aggregate(self, x):
         raise NotImplementedError
 
-    def output(self, x):
+    def decode(self, x):
         raise NotImplementedError
 
-
 
 class Vocabulary(object):
     """A look-up table that allows you to access `Lexeme` objects.
     The `Vocab` instance also provides access to the `StringStore`, and owns underlying
@@ -93,3 +92,4 @@ class Token(object):
         self.doc = doc
         self.token = doc[offset]
         self.i = offset
+
diff --git a/fastNLP/models/sequencce_modeling.py b/fastNLP/models/sequencce_modeling.py
new file mode 100644
index 00000000..af6931e4
--- /dev/null
+++ b/fastNLP/models/sequencce_modeling.py
@@ -0,0 +1,98 @@
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+from fastNLP.models.base_model import BaseModel
+from fastNLP.modules.CRF import ContionalRandomField
+
+
+class SeqLabeling(BaseModel):
+    """
+    PyTorch Network for sequence labeling
+    """
+
+    def __init__(self, hidden_dim,
+                 rnn_num_layerd,
+                 num_classes,
+                 vocab_size,
+                 word_emb_dim=100,
+                 init_emb=None,
+                 rnn_mode="gru",
+                 bi_direction=False,
+                 dropout=0.5,
+                 use_crf=True):
+        super(SeqLabeling, self).__init__()
+
+        self.Emb = nn.Embedding(vocab_size, word_emb_dim)
+        if init_emb:
+            self.Emb.weight = nn.Parameter(init_emb)
+
+        self.num_classes = num_classes
+        self.input_dim = word_emb_dim
+        self.layers = rnn_num_layerd
+        self.hidden_dim = hidden_dim
+        self.bi_direction = bi_direction
+        self.dropout = dropout
+        self.mode = rnn_mode
+
+        if self.mode == "lstm":
+            self.rnn = nn.LSTM(self.input_dim, self.hidden_dim, self.layers, batch_first=True,
+                               bidirectional=self.bi_direction, dropout=self.dropout)
+        elif self.mode == "gru":
+            self.rnn = nn.GRU(self.input_dim, self.hidden_dim, self.layers, batch_first=True,
+                              bidirectional=self.bi_direction, dropout=self.dropout)
+        elif self.mode == "rnn":
+            self.rnn = nn.RNN(self.input_dim, self.hidden_dim, self.layers, batch_first=True,
+                              bidirectional=self.bi_direction, dropout=self.dropout)
+        else:
+            raise Exception
+        if bi_direction:
+            self.linear = nn.Linear(self.hidden_dim * 2, self.num_classes)
+        else:
+            self.linear = nn.Linear(self.hidden_dim, self.num_classes)
+        self.use_crf = use_crf
+        if self.use_crf:
+            self.crf = ContionalRandomField(num_classes)
+
+    def forward(self, x):
+
+        x = self.embedding(x)
+        x, hidden = self.encode(x)
+        x = self.aggregate(x)
+        x = self.decode(x)
+        return x
+
+    def embedding(self, x):
+        return self.Emb(x)
+
+    def encode(self, x):
+        return self.rnn(x)
+
+    def aggregate(self, x):
+        return x
+
+    def decode(self, x):
+        x = self.linear(x)
+        return x
+
+    def loss(self, x, y, mask, batch_size, max_len):
+        """
+        Negative log likelihood loss.
+        :param x:
+        :param y:
+        :param seq_len:
+        :return loss:
+                prediction:
+        """
+        if self.use_crf:
+            total_loss = self.crf(x, y, mask)
+            tag_seq = self.crf.viterbi_decode(x, mask)
+        else:
+            # error
+            loss_function = nn.NLLLoss(ignore_index=0, size_average=False)
+            x = x.view(batch_size * max_len, -1)
+            score = F.log_softmax(x)
+            total_loss = loss_function(score, y.view(batch_size * max_len))
+            _, tag_seq = torch.max(score)
+            tag_seq = tag_seq.view(batch_size, max_len)
+        return torch.mean(total_loss), tag_seq
diff --git a/fastNLP/modules/prototype/example.py b/fastNLP/modules/prototype/example.py
index a19898c6..d23a0ec2 100644
--- a/fastNLP/modules/prototype/example.py
+++ b/fastNLP/modules/prototype/example.py
@@ -1,12 +1,13 @@
-import torch
-import torch.nn as nn
-import encoder
+import time
+
 import aggregation
+import dataloader
 import embedding
+import encoder
 import predict
+import torch
+import torch.nn as nn
 import torch.optim as optim
-import time
-import dataloader
 
 WORD_NUM = 357361
 WORD_SIZE = 100
@@ -16,6 +17,30 @@ R = 10
 MLP_HIDDEN = 2000
 CLASSES_NUM = 5
 
+from fastNLP.models.base_model import BaseModel
+from fastNLP.action.trainer import BaseTrainer
+
+
+class MyNet(BaseModel):
+    def __init__(self):
+        super(MyNet, self).__init__()
+        self.embedding = embedding.Lookuptable(WORD_NUM, WORD_SIZE)
+        self.encoder = encoder.Lstm(WORD_SIZE, HIDDEN_SIZE, 1, 0.5, True)
+        self.aggregation = aggregation.Selfattention(2 * HIDDEN_SIZE, D_A, R)
+        self.predict = predict.MLP(R * HIDDEN_SIZE * 2, MLP_HIDDEN, CLASSES_NUM)
+        self.penalty = None
+
+    def encode(self, x):
+        return self.encoder(self.embedding(x))
+
+    def aggregate(self, x):
+        x, self.penalty = self.aggregation(x)
+        return x
+
+    def decode(self, x):
+        return [self.predict(x), self.penalty]
+
+
 class Net(nn.Module):
     """
     A model for sentiment analysis using lstm and self-attention
@@ -34,6 +59,19 @@
         x = self.predict(x)
         return x, penalty
 
+
+class MyTrainer(BaseTrainer):
+    def __init__(self, args):
+        super(MyTrainer, self).__init__(args)
+        self.optimizer = None
+
+    def define_optimizer(self):
+        self.optimizer = optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9)
+
+    def define_loss(self):
+        self.loss_func = nn.CrossEntropyLoss()
+
+
 def train(model_dict=None, using_cuda=True, learning_rate=0.06,\
         momentum=0.3, batch_size=32, epochs=5, coef=1.0, interval=10):
     """
diff --git a/fastNLP/modules/utils.py b/fastNLP/modules/utils.py
index 15afe883..a6b31a20 100644
--- a/fastNLP/modules/utils.py
+++ b/fastNLP/modules/utils.py
@@ -7,3 +7,9 @@ def mask_softmax(matrix, mask):
     else:
         raise NotImplementedError
     return result
+
+
+def seq_mask(seq_len, max_len):
+    mask = [torch.ge(torch.LongTensor(seq_len), i + 1) for i in range(max_len)]
+    mask = torch.stack(mask, 1)
+    return mask
diff --git a/test/test_POS_pipeline.py b/test/test_POS_pipeline.py
new file mode 100644
index 00000000..db4232e7
--- /dev/null
+++ b/test/test_POS_pipeline.py
@@ -0,0 +1,29 @@
+from fastNLP.action.trainer import POSTrainer
+from fastNLP.loader.dataset_loader import POSDatasetLoader
+from fastNLP.loader.preprocess import POSPreprocess
+from fastNLP.models.sequencce_modeling import SeqLabeling
+
+data_name = "people"
+data_path = "data/people.txt"
"data/people.txt" +pickle_path = "data" + +if __name__ == "__main__": + # Data Loader + pos = POSDatasetLoader(data_name, data_path) + train_data = pos.load_lines() + + # Preprocessor + p = POSPreprocess(train_data, pickle_path) + vocab_size = p.vocab_size + num_classes = p.num_classes + + # Trainer + train_args = POSTrainer.TrainConfig(epochs=20, batch_size=1, num_classes=num_classes, + vocab_size=vocab_size, pickle_path=pickle_path) + trainer = POSTrainer(train_args) + + # Model + model = SeqLabeling(100, 1, num_classes, vocab_size, bi_direction=True) + + # Start training. + trainer.train(model)