@@ -7,6 +7,7 @@ import torch
from fastNLP.action.action import Action
from fastNLP.action.action import RandomSampler, Batchifier
from fastNLP.action.tester import Tester
from fastNLP.modules.utils import seq_mask


class BaseTrainer(Action):
@@ -28,6 +29,7 @@ class BaseTrainer(Action):
        training parameters
        """
        super(BaseTrainer, self).__init__()
        self.train_args = train_args
        self.n_epochs = train_args.epochs
        self.validate = train_args.validate
        self.batch_size = train_args.batch_size
@@ -163,8 +165,8 @@ class BaseTrainer(Action):
        :param data: list. Each entry is a sample, which is also a list of features and label(s).
            E.g.
                [
                    [[word_11, word_12, word_13], [label_11, label_12]],  # sample 1
                    [[word_21, word_22, word_23], [label_21, label_22]],  # sample 2
                    ...
                ]
        :return batch_x: list. Each entry is a list of features of a sample.
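A minimal sketch of how a batch in this format could be split into inputs and labels; the helper below is illustrative only and not part of the diff:

def split_batch(batch):
    # batch: [[features, labels], ...] as described in the docstring above
    batch_x = [sample[0] for sample in batch]  # [[word_11, word_12, word_13], ...]
    batch_y = [sample[1] for sample in batch]  # [[label_11, label_12], ...]
    return batch_x, batch_y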
@@ -313,6 +315,39 @@ class WordSegTrainer(BaseTrainer):
        self.optimizer.step()

class POSTrainer(BaseTrainer):
    def __init__(self, train_args):
        super(POSTrainer, self).__init__(train_args)
        self.vocab_size = train_args.vocab_size
        self.num_classes = train_args.num_classes
        self.max_len = None
        self.mask = None
        self.batch_x = None

    def prepare_input(self, data_path):
        """
        Load pickled train/dev sets. TODO: also load the test set and the embedding.
        """
        with open(data_path + "data_train.pkl", "rb") as f:
            data_train = _pickle.load(f)
        with open(data_path + "data_dev.pkl", "rb") as f:
            data_dev = _pickle.load(f)
        return data_train, data_dev

    def data_forward(self, network, x):
        # x is assumed to be a batch of id sequences padded to the same length
        seq_len = [len(seq) for seq in x]
        x = torch.LongTensor(x)
        self.batch_size = x.size(0)
        self.max_len = x.size(1)
        self.mask = seq_mask(seq_len, self.max_len)
        x = network(x)
        self.batch_x = x
        return x

    def get_loss(self, predict, truth):
        # loss_func is expected to be the model's loss(x, y, mask, batch_size, max_len)
        truth = torch.LongTensor(truth)
        loss, prediction = self.loss_func(predict, truth, self.mask, self.batch_size, self.max_len)
        return loss
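For reference, data_train.pkl is produced by POSPreprocess.data_train() further down in this diff, so prepare_input gets back a list of (word id list, label id list) pairs, roughly:

# illustrative content of data_train.pkl (the exact ids depend on the built dictionaries)
data_train = [
    ([2, 5, 7], [0, 1, 0]),   # one sentence: word ids, label ids
    ([4, 9], [1, 1]),
]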

if __name__ == "__main__":
    train_args = BaseTrainer.TrainConfig(epochs=1, validate=False, batch_size=3, pickle_path="./")
    trainer = BaseTrainer(train_args)
@@ -15,7 +15,6 @@ class POSDatasetLoader(DatasetLoader):
    def __init__(self, data_name, data_path):
        super(POSDatasetLoader, self).__init__(data_name, data_path)

    def load(self):
        assert os.path.exists(self.data_path)
@@ -46,19 +46,17 @@ class BasePreprocess(object):

class POSPreprocess(BasePreprocess):
    """
    This class is used to preprocess POS datasets.
    In these datasets, each line is split by '\t': the first column is the word
    and the second column is its label.
    Different sentences are divided by an empty line.
    e.g:

        Tom label1
        and label2
        Jerry label1
        . label3

        Hello label4
        world label5
        ! label3
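For the toy data above, build_dict() would end up with dictionaries along these lines (word ids start after the default/reserved entries, so treat the exact numbers as illustrative):

word_dict  = {"Tom": 5, "and": 6, "Jerry": 7, ".": 8, "Hello": 9, "world": 10, "!": 11}
label_dict = {"label1": 0, "label2": 1, "label3": 2, "label4": 3, "label5": 4}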
@@ -71,11 +69,13 @@ class POSPreprocess(BasePreprocess):
        super(POSPreprocess, self).__init__(data, pickle_path)
        self.word_dict = None
        self.label_dict = None
        self.data = data
        self.pickle_path = pickle_path
        self.build_dict()
        self.word2id()
        self.vocab_size = self.id2word()
        self.class2id()
        self.num_classes = self.id2class()
        self.embedding()
        self.data_train()
        self.data_dev()
@@ -87,7 +87,8 @@ class POSPreprocess(BasePreprocess): | |||||
DEFAULT_RESERVED_LABEL[2]: 4} | DEFAULT_RESERVED_LABEL[2]: 4} | ||||
self.label_dict = {} | self.label_dict = {} | ||||
for w in self.data: | for w in self.data: | ||||
if len(w) == 0: | |||||
w = w.strip() | |||||
if len(w) <= 1: | |||||
continue | continue | ||||
word = w.split('\t') | word = w.split('\t') | ||||
@@ -95,10 +96,11 @@ class POSPreprocess(BasePreprocess): | |||||
index = len(self.word_dict) | index = len(self.word_dict) | ||||
self.word_dict[word[0]] = index | self.word_dict[word[0]] = index | ||||
for label in word[1: ]: | |||||
if label not in self.label_dict: | |||||
index = len(self.label_dict) | |||||
self.label_dict[label] = index | |||||
# for label in word[1: ]: | |||||
label = word[1] | |||||
if label not in self.label_dict: | |||||
index = len(self.label_dict) | |||||
self.label_dict[label] = index | |||||

    def pickle_exist(self, pickle_name):
        """
@@ -107,7 +109,7 @@ class POSPreprocess(BasePreprocess):
        """
        if not os.path.exists(self.pickle_path):
            os.makedirs(self.pickle_path)
        file_name = os.path.join(self.pickle_path, pickle_name)
        if os.path.exists(file_name):
            return True
        else:
@@ -118,42 +120,48 @@ class POSPreprocess(BasePreprocess):
            return
        # nothing will be done if word2id.pkl exists

        file_name = os.path.join(self.pickle_path, "word2id.pkl")
        with open(file_name, "wb") as f:
            _pickle.dump(self.word_dict, f)

    def id2word(self):
        if self.pickle_exist("id2word.pkl"):
            # id2word.pkl already exists: load it and just report the vocabulary size
            file_name = os.path.join(self.pickle_path, "id2word.pkl")
            with open(file_name, "rb") as f:
                id2word_dict = _pickle.load(f)
            return len(id2word_dict)

        id2word_dict = {}
        for word in self.word_dict:
            id2word_dict[self.word_dict[word]] = word
        file_name = os.path.join(self.pickle_path, "id2word.pkl")
        with open(file_name, "wb") as f:
            _pickle.dump(id2word_dict, f)
        return len(id2word_dict)

    def class2id(self):
        if self.pickle_exist("class2id.pkl"):
            return
        # nothing will be done if class2id.pkl exists

        file_name = os.path.join(self.pickle_path, "class2id.pkl")
        with open(file_name, "wb") as f:
            _pickle.dump(self.label_dict, f)

    def id2class(self):
        if self.pickle_exist("id2class.pkl"):
            # id2class.pkl already exists: load it and just report the number of classes
            file_name = os.path.join(self.pickle_path, "id2class.pkl")
            with open(file_name, "rb") as f:
                id2class_dict = _pickle.load(f)
            return len(id2class_dict)

        id2class_dict = {}
        for label in self.label_dict:
            id2class_dict[self.label_dict[label]] = label
        file_name = os.path.join(self.pickle_path, "id2class.pkl")
        with open(file_name, "wb") as f:
            _pickle.dump(id2class_dict, f)
        return len(id2class_dict)

    def embedding(self):
        if self.pickle_exist("embedding.pkl"):
@@ -168,22 +176,26 @@ class POSPreprocess(BasePreprocess):
        data_train = []
        sentence = []
        for w in self.data:
            w = w.strip()
            if len(w) <= 1:
                # an empty line marks the end of a sentence
                # (assumes the data ends with an empty line, otherwise the last sentence is dropped)
                wid = []
                lid = []
                for i in range(len(sentence)):
                    wid.append(self.word_dict[sentence[i][0]])
                    lid.append(self.label_dict[sentence[i][1]])
                data_train.append((wid, lid))
                sentence = []
                continue
            sentence.append(w.split('\t'))
        file_name = os.path.join(self.pickle_path, "data_train.pkl")
        with open(file_name, "wb") as f:
            _pickle.dump(data_train, f)

    def data_dev(self):
        pass

    def data_test(self):
        pass
@@ -4,9 +4,9 @@ import torch

class BaseModel(torch.nn.Module):
    """Base PyTorch model for all models.

    Three network modules are presented:
    - encoder module
    - aggregation module
    - decoder module
    Subclasses must implement these three modules with "components".
    """
@@ -15,21 +15,20 @@ class BaseModel(torch.nn.Module):

    def forward(self, *inputs):
        x = self.encode(*inputs)
        x = self.aggregate(x)
        x = self.decode(x)
        return x

    def encode(self, x):
        raise NotImplementedError

    def aggregate(self, x):
        raise NotImplementedError

    def decode(self, x):
        raise NotImplementedError
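A minimal sketch of the encode/aggregate/decode contract a subclass has to fill in (purely illustrative; the real SeqLabeling implementation appears later in this diff):

class TinyModel(BaseModel):
    # illustrative subclass: identity encoder/aggregator, linear decoder
    def __init__(self, in_dim, num_classes):
        super(TinyModel, self).__init__()
        self.fc = torch.nn.Linear(in_dim, num_classes)

    def encode(self, x):
        return x

    def aggregate(self, x):
        return x

    def decode(self, x):
        return self.fc(x)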

class Vocabulary(object):
    """A look-up table that allows you to access `Lexeme` objects. The `Vocab`
    instance also provides access to the `StringStore`, and owns underlying
@@ -93,3 +92,4 @@ class Token(object):
        self.doc = doc
        self.token = doc[offset]
        self.i = offset
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn
from torch.nn import functional as F

from fastNLP.models.base_model import BaseModel
from fastNLP.modules.CRF import ContionalRandomField


class SeqLabeling(BaseModel):
    """
    PyTorch Network for sequence labeling
    """
    def __init__(self, hidden_dim,
                 rnn_num_layers,
                 num_classes,
                 vocab_size,
                 word_emb_dim=100,
                 init_emb=None,
                 rnn_mode="gru",
                 bi_direction=False,
                 dropout=0.5,
                 use_crf=True):
        super(SeqLabeling, self).__init__()
        self.Emb = nn.Embedding(vocab_size, word_emb_dim)
        if init_emb is not None:
            self.Emb.weight = nn.Parameter(init_emb)
        self.num_classes = num_classes
        self.input_dim = word_emb_dim
        self.layers = rnn_num_layers
        self.hidden_dim = hidden_dim
        self.bi_direction = bi_direction
        self.dropout = dropout
        self.mode = rnn_mode

        if self.mode == "lstm":
            self.rnn = nn.LSTM(self.input_dim, self.hidden_dim, self.layers, batch_first=True,
                               bidirectional=self.bi_direction, dropout=self.dropout)
        elif self.mode == "gru":
            self.rnn = nn.GRU(self.input_dim, self.hidden_dim, self.layers, batch_first=True,
                              bidirectional=self.bi_direction, dropout=self.dropout)
        elif self.mode == "rnn":
            self.rnn = nn.RNN(self.input_dim, self.hidden_dim, self.layers, batch_first=True,
                              bidirectional=self.bi_direction, dropout=self.dropout)
        else:
            raise ValueError("unsupported rnn_mode: %s" % rnn_mode)

        if bi_direction:
            self.linear = nn.Linear(self.hidden_dim * 2, self.num_classes)
        else:
            self.linear = nn.Linear(self.hidden_dim, self.num_classes)
        self.use_crf = use_crf
        if self.use_crf:
            self.crf = ContionalRandomField(num_classes)
    def forward(self, x):
        x = self.embedding(x)
        x, hidden = self.encode(x)
        x = self.aggregate(x)
        x = self.decode(x)
        return x

    def embedding(self, x):
        return self.Emb(x)

    def encode(self, x):
        return self.rnn(x)

    def aggregate(self, x):
        return x

    def decode(self, x):
        x = self.linear(x)
        return x
    def loss(self, x, y, mask, batch_size, max_len):
        """
        Negative log likelihood loss.
        :param x: network output, [batch_size, max_len, num_classes]
        :param y: true labels, [batch_size, max_len]
        :param mask: [batch_size, max_len], 0 at padding positions
        :param batch_size: int
        :param max_len: int
        :return loss: a scalar tensor
                prediction: the decoded tag sequence, [batch_size, max_len]
        """
        if self.use_crf:
            total_loss = self.crf(x, y, mask)
            tag_seq = self.crf.viterbi_decode(x, mask)
        else:
            # per-token cross entropy; note that the mask is not applied in this branch
            loss_function = nn.NLLLoss(ignore_index=0, size_average=False)
            x = x.view(batch_size * max_len, -1)
            score = F.log_softmax(x, dim=1)
            total_loss = loss_function(score, y.view(batch_size * max_len))
            _, tag_seq = torch.max(score, dim=1)
            tag_seq = tag_seq.view(batch_size, max_len)
        return torch.mean(total_loss), tag_seq
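A hedged usage sketch tying the model to the seq_mask helper added in this diff (shapes follow the code above; whether the CRF module accepts exactly these shapes is not verified here):

# illustrative only
from fastNLP.modules.utils import seq_mask

model = SeqLabeling(hidden_dim=50, rnn_num_layers=1, num_classes=5, vocab_size=20)
x = torch.LongTensor([[2, 5, 7, 0], [4, 9, 0, 0]])   # padded word ids, [batch, max_len]
y = torch.LongTensor([[1, 3, 2, 0], [1, 2, 0, 0]])   # padded label ids
mask = seq_mask([3, 2], 4)                           # true lengths 3 and 2
pred = model(x)                                      # [batch, max_len, num_classes]
loss, tags = model.loss(pred, y, mask, batch_size=2, max_len=4)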
@@ -1,12 +1,13 @@
import time

import torch
import torch.nn as nn
import torch.optim as optim

import aggregation
import dataloader
import embedding
import encoder
import predict

WORD_NUM = 357361
WORD_SIZE = 100
@@ -16,6 +17,30 @@ R = 10
MLP_HIDDEN = 2000
CLASSES_NUM = 5

from fastNLP.models.base_model import BaseModel
from fastNLP.action.trainer import BaseTrainer


class MyNet(BaseModel):
    def __init__(self):
        super(MyNet, self).__init__()
        self.embedding = embedding.Lookuptable(WORD_NUM, WORD_SIZE)
        self.encoder = encoder.Lstm(WORD_SIZE, HIDDEN_SIZE, 1, 0.5, True)
        self.aggregation = aggregation.Selfattention(2 * HIDDEN_SIZE, D_A, R)
        self.predict = predict.MLP(R * HIDDEN_SIZE * 2, MLP_HIDDEN, CLASSES_NUM)
        self.penalty = None

    def encode(self, x):
        # embed the word ids, then run the LSTM encoder component
        return self.encoder(self.embedding(x))

    def aggregate(self, x):
        # self-attention aggregation also returns a penalty term, kept for the loss
        x, self.penalty = self.aggregation(x)
        return x

    def decode(self, x):
        return [self.predict(x), self.penalty]

class Net(nn.Module):
    """
    A model for sentiment analysis using lstm and self-attention
@@ -34,6 +59,19 @@ class Net(nn.Module):
        x = self.predict(x)
        return x, penalty

class MyTrainer(BaseTrainer):
    def __init__(self, args):
        super(MyTrainer, self).__init__(args)
        self.optimizer = None

    def define_optimizer(self):
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9)

    def define_loss(self):
        self.loss_func = nn.CrossEntropyLoss()


def train(model_dict=None, using_cuda=True, learning_rate=0.06,
          momentum=0.3, batch_size=32, epochs=5, coef=1.0, interval=10):
    """
@@ -7,3 +7,9 @@ def mask_softmax(matrix, mask):
    else:
        raise NotImplementedError
    return result


def seq_mask(seq_len, max_len):
    # mask[b, i] is 1 while i < seq_len[b] and 0 at padding positions, shape [batch_size, max_len]
    mask = [torch.ge(torch.LongTensor(seq_len), i + 1) for i in range(max_len)]
    mask = torch.stack(mask, 1)
    return mask
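For example, two sequences of lengths 3 and 1 padded to max_len 4 give a mask roughly like this (the exact dtype depends on the torch version):

mask = seq_mask([3, 1], 4)
# tensor([[1, 1, 1, 0],
#         [1, 0, 0, 0]])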
@@ -0,0 +1,29 @@
from fastNLP.action.trainer import POSTrainer
from fastNLP.loader.dataset_loader import POSDatasetLoader
from fastNLP.loader.preprocess import POSPreprocess
from fastNLP.models.sequence_modeling import SeqLabeling

data_name = "people"
data_path = "data/people.txt"
pickle_path = "data"

if __name__ == "__main__":
    # Data Loader
    pos = POSDatasetLoader(data_name, data_path)
    train_data = pos.load_lines()

    # Preprocessor
    p = POSPreprocess(train_data, pickle_path)
    vocab_size = p.vocab_size
    num_classes = p.num_classes

    # Trainer
    train_args = POSTrainer.TrainConfig(epochs=20, batch_size=1, num_classes=num_classes,
                                        vocab_size=vocab_size, pickle_path=pickle_path)
    trainer = POSTrainer(train_args)

    # Model
    model = SeqLabeling(100, 1, num_classes, vocab_size, bi_direction=True)

    # Start training
    trainer.train(model)