@@ -7,6 +7,7 @@ import torch | |||
from fastNLP.action.action import Action | |||
from fastNLP.action.action import RandomSampler, Batchifier | |||
from fastNLP.action.tester import Tester | |||
from fastNLP.modules.utils import seq_mask | |||
class BaseTrainer(Action): | |||
@@ -28,6 +29,7 @@ class BaseTrainer(Action): | |||
training parameters | |||
""" | |||
super(BaseTrainer, self).__init__() | |||
self.train_args = train_args | |||
self.n_epochs = train_args.epochs | |||
self.validate = train_args.validate | |||
self.batch_size = train_args.batch_size | |||
@@ -163,8 +165,8 @@ class BaseTrainer(Action): | |||
:param data: list. Each entry is a sample, which is also a list of features and label(s). | |||
E.g. | |||
[ | |||
[[feature_1, feature_2, feature_3], [label_1, label_2]], # sample 1 | |||
[[feature_1, feature_2, feature_3], [label_1, label_2]], # sample 2 | |||
[[word_11, word_12, word_13], [label_11, label_12]], # sample 1 | |||
[[word_21, word_22, word_23], [label_21, label_22]], # sample 2 | |||
... | |||
] | |||
:return batch_x: list. Each entry is a list of features of a sample. | |||
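As an illustration of the split described in this docstring (a hypothetical snippet; the method body itself is outside this hunk):

```python
# Illustration only: splitting one batch in the format above into inputs and labels.
batch = [
    [[1, 2, 3], [0, 1]],   # sample 1: word ids, label ids
    [[4, 5, 6], [1, 1]],   # sample 2
]
batch_x = [sample[0] for sample in batch]   # [[1, 2, 3], [4, 5, 6]]
batch_y = [sample[1] for sample in batch]   # [[0, 1], [1, 1]]
```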
@@ -313,6 +315,39 @@ class WordSegTrainer(BaseTrainer): | |||
self.optimizer.step() | |||
class POSTrainer(BaseTrainer): | |||
def __init__(self, train_args): | |||
super(POSTrainer, self).__init__(train_args) | |||
self.vocab_size = train_args.vocab_size | |||
self.num_classes = train_args.num_classes | |||
self.max_len = None | |||
self.mask = None | |||
self.batch_x = None | |||
def prepare_input(self, data_path): | |||
""" | |||
To do: Load pkl files of train/dev/test and embedding | |||
""" | |||
data_train = _pickle.load(open(data_path + "data_train.pkl", "rb")) | |||
data_dev = _pickle.load(open(data_path + "data_dev.pkl", "rb")) | |||
return data_train, data_dev | |||
def data_forward(self, network, x): | |||
seq_len = [len(seq) for seq in x] | |||
x = torch.LongTensor(x) | |||
self.batch_size = x.size(0) | |||
self.max_len = x.size(1) | |||
self.mask = seq_mask(seq_len, self.max_len) | |||
x = network(x) | |||
self.batch_x = x | |||
return x | |||
def get_loss(self, predict, truth): | |||
truth = torch.LongTensor(truth) | |||
loss, prediction = self.loss_func(self.batch_x, truth, self.mask, self.batch_size, self.max_len) | |||
return loss | |||
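A minimal sketch of how these hooks could fit together in one update step. This is hypothetical glue code, assuming a `POSTrainer` named `trainer` and a sequence-labeling `model` have been built as in the test script at the end of this PR, that `trainer.loss_func` is bound to the model's `loss` method, and that `trainer.define_optimizer()` has already run:

```python
# Hypothetical single training step; `trainer` and `model` are assumed to exist (see test script).
batch_x = [[2, 5, 7], [4, 1, 3]]   # word ids, one padded row per sentence
batch_y = [[1, 2, 0], [3, 1, 2]]   # label ids aligned with batch_x

predict = trainer.data_forward(model, batch_x)   # per-token scores; also cached as trainer.batch_x
loss = trainer.get_loss(predict, batch_y)        # masked loss computed by the model
loss.backward()
trainer.optimizer.step()
```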
if __name__ == "__main__": | |||
train_args = BaseTrainer.TrainConfig(epochs=1, validate=False, batch_size=3, pickle_path="./") | |||
trainer = BaseTrainer(train_args) | |||
@@ -15,7 +15,6 @@ class POSDatasetLoader(DatasetLoader): | |||
def __init__(self, data_name, data_path): | |||
super(POSDatasetLoader, self).__init__(data_name, data_path) | |||
#self.data_set = self.load() | |||
def load(self): | |||
assert os.path.exists(self.data_path) | |||
@@ -46,19 +46,17 @@ class BasePreprocess(object): | |||
class POSPreprocess(BasePreprocess): | |||
""" | |||
This class is used to preprocess the POS datasets. | |||
In these datasets, each line is divided by '\t' | |||
The first Col is the vocabulary. | |||
The second Col is the labels. | |||
In these datasets, each line is divided by '\t' | |||
while the first Col is the vocabulary and the second | |||
Col is the label. | |||
Different sentences are divided by an empty line. | |||
e.g: | |||
Tom label1 | |||
and label2 | |||
Jerry label1 | |||
. label3 | |||
Hello label4 | |||
world label5 | |||
! label3 | |||
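For reference, a small self-contained sketch (a hypothetical helper, not part of this PR) that reads a file in the layout shown above into per-sentence (words, labels) pairs:

```python
# Hypothetical reader for the format above: one "word\tlabel" pair per line,
# sentences separated by a blank line.
def read_pos_file(path):
    sentences, words, labels = [], [], []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:                       # blank line ends a sentence
                if words:
                    sentences.append((words, labels))
                    words, labels = [], []
                continue
            word, label = line.split("\t")[:2]
            words.append(word)
            labels.append(label)
    if words:                                  # last sentence without a trailing blank line
        sentences.append((words, labels))
    return sentences
```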
@@ -71,11 +69,13 @@ class POSPreprocess(BasePreprocess): | |||
super(POSPreprocess, self).__init__(data, pickle_path) | |||
self.word_dict = None | |||
self.label_dict = None | |||
self.data = data | |||
self.pickle_path = pickle_path | |||
self.build_dict() | |||
self.word2id() | |||
self.id2word() | |||
self.vocab_size = self.id2word() | |||
self.class2id() | |||
self.id2class() | |||
self.num_classes = self.id2class() | |||
self.embedding() | |||
self.data_train() | |||
self.data_dev() | |||
@@ -87,7 +87,8 @@ class POSPreprocess(BasePreprocess): | |||
DEFAULT_RESERVED_LABEL[2]: 4} | |||
self.label_dict = {} | |||
for w in self.data: | |||
if len(w) == 0: | |||
w = w.strip() | |||
if len(w) <= 1: | |||
continue | |||
word = w.split('\t') | |||
@@ -95,10 +96,11 @@ class POSPreprocess(BasePreprocess): | |||
index = len(self.word_dict) | |||
self.word_dict[word[0]] = index | |||
for label in word[1: ]: | |||
if label not in self.label_dict: | |||
index = len(self.label_dict) | |||
self.label_dict[label] = index | |||
# for label in word[1: ]: | |||
label = word[1] | |||
if label not in self.label_dict: | |||
index = len(self.label_dict) | |||
self.label_dict[label] = index | |||
def pickle_exist(self, pickle_name): | |||
""" | |||
@@ -107,7 +109,7 @@ class POSPreprocess(BasePreprocess): | |||
""" | |||
if not os.path.exists(self.pickle_path): | |||
os.makedirs(self.pickle_path) | |||
file_name = self.pickle_path + pickle_name | |||
file_name = os.path.join(self.pickle_path, pickle_name) | |||
if os.path.exists(file_name): | |||
return True | |||
else: | |||
@@ -118,42 +120,48 @@ class POSPreprocess(BasePreprocess): | |||
return | |||
# nothing will be done if word2id.pkl exists | |||
file_name = self.pickle_path + "word2id.pkl" | |||
with open(file_name, "wb", encoding='utf-8') as f: | |||
file_name = os.path.join(self.pickle_path, "word2id.pkl") | |||
with open(file_name, "wb") as f: | |||
_pickle.dump(self.word_dict, f) | |||
def id2word(self): | |||
if self.pickle_exist("id2word.pkl"): | |||
return | |||
file_name = os.path.join(self.pickle_path, "id2word.pkl") | |||
id2word_dict = _pickle.load(open(file_name, "rb")) | |||
return len(id2word_dict) | |||
# otherwise, build the id-to-word mapping and pickle it | |||
id2word_dict = {} | |||
for word in self.word_dict: | |||
id2word_dict[self.word_dict[word]] = word | |||
file_name = self.pickle_path + "id2word.pkl" | |||
with open(file_name, "wb", encoding='utf-8') as f: | |||
file_name = os.path.join(self.pickle_path, "id2word.pkl") | |||
with open(file_name, "wb") as f: | |||
_pickle.dump(id2word_dict, f) | |||
return len(id2word_dict) | |||
def class2id(self): | |||
if self.pickle_exist("class2id.pkl"): | |||
return | |||
# nothing will be done if class2id.pkl exists | |||
file_name = self.pickle_path + "class2id.pkl" | |||
with open(file_name, "wb", encoding='utf-8') as f: | |||
file_name = os.path.join(self.pickle_path, "class2id.pkl") | |||
with open(file_name, "wb") as f: | |||
_pickle.dump(self.label_dict, f) | |||
def id2class(self): | |||
if self.pickle_exist("id2class.pkl"): | |||
return | |||
file_name = os.path.join(self.pickle_path, "id2class.pkl") | |||
id2class_dict = _pickle.load(open(file_name, "rb")) | |||
return len(id2class_dict) | |||
# otherwise, build the id-to-class mapping and pickle it | |||
id2class_dict = {} | |||
for label in self.label_dict: | |||
id2class_dict[self.label_dict[label]] = label | |||
file_name = self.pickle_path + "id2class.pkl" | |||
with open(file_name, "wb", encoding='utf-8') as f: | |||
file_name = os.path.join(self.pickle_path, "id2class.pkl") | |||
with open(file_name, "wb") as f: | |||
_pickle.dump(id2class_dict, f) | |||
return len(id2class_dict) | |||
def embedding(self): | |||
if self.pickle_exist("embedding.pkl"): | |||
@@ -168,22 +176,26 @@ class POSPreprocess(BasePreprocess): | |||
data_train = [] | |||
sentence = [] | |||
for w in self.data: | |||
if len(w) == 0: | |||
w = w.strip() | |||
if len(w) <= 1: | |||
wid = [] | |||
lid = [] | |||
for i in range(len(sentence)): | |||
# if sentence[i][0]=="": | |||
# print("") | |||
wid.append(self.word_dict[sentence[i][0]]) | |||
lid.append(self.label_dict[sentence[i][1]]) | |||
data_train.append((wid, lid)) | |||
sentence = [] | |||
continue | |||
sentence.append(w.split('\t')) | |||
file_name = self.pickle_path + "data_train.pkl" | |||
with open(file_name, "wb", encoding='utf-8') as f: | |||
file_name = os.path.join(self.pickle_path, "data_train.pkl") | |||
with open(file_name, "wb") as f: | |||
_pickle.dump(data_train, f) | |||
def data_dev(self): | |||
pass | |||
def data_test(self): | |||
pass | |||
pass |
@@ -4,9 +4,9 @@ import torch | |||
class BaseModel(torch.nn.Module): | |||
"""Base PyTorch model for all models. | |||
Three network modules are presented: | |||
- embedding module | |||
- encoder module | |||
- aggregation module | |||
- output module | |||
- decoder module | |||
Subclasses must implement these three modules with "components". | |||
""" | |||
@@ -15,21 +15,20 @@ class BaseModel(torch.nn.Module): | |||
def forward(self, *inputs): | |||
x = self.encode(*inputs) | |||
x = self.aggregation(x) | |||
x = self.output(x) | |||
x = self.aggregate(x) | |||
x = self.decode(x) | |||
return x | |||
def encode(self, x): | |||
raise NotImplementedError | |||
def aggregation(self, x): | |||
def aggregate(self, x): | |||
raise NotImplementedError | |||
def output(self, x): | |||
def decode(self, x): | |||
raise NotImplementedError | |||
class Vocabulary(object): | |||
"""A look-up table that allows you to access `Lexeme` objects. The `Vocab` | |||
instance also provides access to the `StringStore`, and owns underlying | |||
@@ -93,3 +92,4 @@ class Token(object): | |||
self.doc = doc | |||
self.token = doc[offset] | |||
self.i = offset | |||
@@ -0,0 +1,98 @@ | |||
import torch | |||
import torch.nn as nn | |||
from torch.nn import functional as F | |||
from fastNLP.models.base_model import BaseModel | |||
from fastNLP.modules.CRF import ContionalRandomField | |||
class SeqLabeling(BaseModel): | |||
""" | |||
PyTorch Network for sequence labeling | |||
""" | |||
def __init__(self, hidden_dim, | |||
rnn_num_layerd, | |||
num_classes, | |||
vocab_size, | |||
word_emb_dim=100, | |||
init_emb=None, | |||
rnn_mode="gru", | |||
bi_direction=False, | |||
dropout=0.5, | |||
use_crf=True): | |||
super(SeqLabeling, self).__init__() | |||
self.Emb = nn.Embedding(vocab_size, word_emb_dim) | |||
if init_emb is not None: | |||
self.Emb.weight = nn.Parameter(init_emb) | |||
self.num_classes = num_classes | |||
self.input_dim = word_emb_dim | |||
self.layers = rnn_num_layerd | |||
self.hidden_dim = hidden_dim | |||
self.bi_direction = bi_direction | |||
self.dropout = dropout | |||
self.mode = rnn_mode | |||
if self.mode == "lstm": | |||
self.rnn = nn.LSTM(self.input_dim, self.hidden_dim, self.layers, batch_first=True, | |||
bidirectional=self.bi_direction, dropout=self.dropout) | |||
elif self.mode == "gru": | |||
self.rnn = nn.GRU(self.input_dim, self.hidden_dim, self.layers, batch_first=True, | |||
bidirectional=self.bi_direction, dropout=self.dropout) | |||
elif self.mode == "rnn": | |||
self.rnn = nn.RNN(self.input_dim, self.hidden_dim, self.layers, batch_first=True, | |||
bidirectional=self.bi_direction, dropout=self.dropout) | |||
else: | |||
raise ValueError("unsupported rnn_mode: %s" % self.mode) | |||
if bi_direction: | |||
self.linear = nn.Linear(self.hidden_dim * 2, self.num_classes) | |||
else: | |||
self.linear = nn.Linear(self.hidden_dim, self.num_classes) | |||
self.use_crf = use_crf | |||
if self.use_crf: | |||
self.crf = ContionalRandomField(num_classes) | |||
def forward(self, x): | |||
x = self.embedding(x) | |||
x, hidden = self.encode(x) | |||
x = self.aggregate(x) | |||
x = self.decode(x) | |||
return x | |||
def embedding(self, x): | |||
return self.Emb(x) | |||
def encode(self, x): | |||
return self.rnn(x) | |||
def aggregate(self, x): | |||
return x | |||
def decode(self, x): | |||
x = self.linear(x) | |||
return x | |||
def loss(self, x, y, mask, batch_size, max_len): | |||
""" | |||
Negative log likelihood loss. | |||
:param x: network scores, [batch_size, max_len, num_classes] | |||
:param y: gold labels, [batch_size, max_len] | |||
:param mask: mask for padded positions, [batch_size, max_len] | |||
:param batch_size: int | |||
:param max_len: int | |||
:return loss: scalar tensor | |||
prediction: decoded tag sequence, [batch_size, max_len] | |||
""" | |||
if self.use_crf: | |||
total_loss = self.crf(x, y, mask) | |||
tag_seq = self.crf.viterbi_decode(x, mask) | |||
else: | |||
# token-level NLL loss over the flattened sequence; padding label 0 is ignored | |||
loss_function = nn.NLLLoss(ignore_index=0, size_average=False) | |||
x = x.view(batch_size * max_len, -1) | |||
score = F.log_softmax(x, dim=1) | |||
total_loss = loss_function(score, y.view(batch_size * max_len)) | |||
_, tag_seq = torch.max(score, dim=1) | |||
tag_seq = tag_seq.view(batch_size, max_len) | |||
return torch.mean(total_loss), tag_seq |
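A minimal usage sketch for the class above, with toy sizes and made-up data; `use_crf=False` keeps the example on the NLL branch shown in this file, so no assumptions about the CRF module are needed:

```python
import torch

from fastNLP.models.sequencce_modeling import SeqLabeling   # module path as used in the test script below
from fastNLP.modules.utils import seq_mask

# hidden_dim=50, 1 RNN layer, 5 classes, vocab of 100 (hypothetical sizes)
model = SeqLabeling(50, 1, 5, 100, bi_direction=True, use_crf=False)

x = torch.LongTensor([[2, 4, 6], [3, 5, 0]])   # word ids, batch of 2, padded to length 3
y = torch.LongTensor([[1, 2, 3], [4, 1, 0]])   # gold labels, 0 = padding
mask = seq_mask([3, 2], 3)                     # 1 for real tokens, 0 for padding

scores = model(x)                              # [batch_size, max_len, num_classes]
loss, tags = model.loss(scores, y, mask, batch_size=2, max_len=3)
loss.backward()
```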
@@ -1,12 +1,13 @@ | |||
import torch | |||
import torch.nn as nn | |||
import encoder | |||
import time | |||
import aggregation | |||
import dataloader | |||
import embedding | |||
import encoder | |||
import predict | |||
import torch | |||
import torch.nn as nn | |||
import torch.optim as optim | |||
import time | |||
import dataloader | |||
WORD_NUM = 357361 | |||
WORD_SIZE = 100 | |||
@@ -16,6 +17,30 @@ R = 10 | |||
MLP_HIDDEN = 2000 | |||
CLASSES_NUM = 5 | |||
from fastNLP.models.base_model import BaseModel | |||
from fastNLP.action.trainer import BaseTrainer | |||
class MyNet(BaseModel): | |||
def __init__(self): | |||
super(MyNet, self).__init__() | |||
self.embedding = embedding.Lookuptable(WORD_NUM, WORD_SIZE) | |||
self.encoder = encoder.Lstm(WORD_SIZE, HIDDEN_SIZE, 1, 0.5, True) | |||
self.aggregation = aggregation.Selfattention(2 * HIDDEN_SIZE, D_A, R) | |||
self.predict = predict.MLP(R * HIDDEN_SIZE * 2, MLP_HIDDEN, CLASSES_NUM) | |||
self.penalty = None | |||
def encode(self, x): | |||
return self.encoder(self.embedding(x)) | |||
def aggregate(self, x): | |||
x, self.penalty = self.aggregation(x) | |||
return x | |||
def decode(self, x): | |||
return [self.predict(x), self.penalty] | |||
class Net(nn.Module): | |||
""" | |||
A model for sentiment analysis using lstm and self-attention | |||
@@ -34,6 +59,19 @@ class Net(nn.Module): | |||
x = self.predict(x) | |||
return x, penalty | |||
class MyTrainer(BaseTrainer): | |||
def __init__(self, args): | |||
super(MyTrainer, self).__init__(args) | |||
self.optimizer = None | |||
def define_optimizer(self): | |||
self.optimizer = optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9) | |||
def define_loss(self): | |||
self.loss_func = nn.CrossEntropyLoss() | |||
def train(model_dict=None, using_cuda=True, learning_rate=0.06,\ | |||
momentum=0.3, batch_size=32, epochs=5, coef=1.0, interval=10): | |||
""" | |||
@@ -7,3 +7,9 @@ def mask_softmax(matrix, mask): | |||
else: | |||
raise NotImplementedError | |||
return result | |||
def seq_mask(seq_len, max_len): | |||
mask = [torch.ge(torch.LongTensor(seq_len), i + 1) for i in range(max_len)] | |||
mask = torch.stack(mask, 1) | |||
return mask |
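For clarity, the mask produced for two sequences of lengths 3 and 1 padded to length 3 looks like this (a usage sketch; the exact dtype of the result depends on the PyTorch version):

```python
from fastNLP.modules.utils import seq_mask

mask = seq_mask([3, 1], 3)   # lengths 3 and 1, padded to max_len 3
print(mask)
# [[1, 1, 1],
#  [1, 0, 0]]   (bool on recent PyTorch, uint8/ByteTensor on older releases)
```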
@@ -0,0 +1,29 @@ | |||
from fastNLP.action.trainer import POSTrainer | |||
from fastNLP.loader.dataset_loader import POSDatasetLoader | |||
from fastNLP.loader.preprocess import POSPreprocess | |||
from fastNLP.models.sequencce_modeling import SeqLabeling | |||
data_name = "people" | |||
data_path = "data/people.txt" | |||
pickle_path = "data" | |||
if __name__ == "__main__": | |||
# Data Loader | |||
pos = POSDatasetLoader(data_name, data_path) | |||
train_data = pos.load_lines() | |||
# Preprocessor | |||
p = POSPreprocess(train_data, pickle_path) | |||
vocab_size = p.vocab_size | |||
num_classes = p.num_classes | |||
# Trainer | |||
train_args = POSTrainer.TrainConfig(epochs=20, batch_size=1, num_classes=num_classes, | |||
vocab_size=vocab_size, pickle_path=pickle_path) | |||
trainer = POSTrainer(train_args) | |||
# Model | |||
model = SeqLabeling(100, 1, num_classes, vocab_size, bi_direction=True) | |||
# Start training. | |||
trainer.train(model) |