
finished POSTrainer

tags/v0.1.0
FengZiYjun 7 years ago
commit d7a8217132
8 changed files with 259 additions and 42 deletions
  1. fastNLP/action/trainer.py  +37  -2
  2. fastNLP/loader/dataset_loader.py  +0  -1
  3. fastNLP/loader/preprocess.py  +39  -27
  4. fastNLP/models/base_model.py  +7  -7
  5. fastNLP/models/sequencce_modeling.py  +98  -0
  6. fastNLP/modules/prototype/example.py  +43  -5
  7. fastNLP/modules/utils.py  +6  -0
  8. test/test_POS_pipeline.py  +29  -0

fastNLP/action/trainer.py  +37  -2

@@ -7,6 +7,7 @@ import torch
from fastNLP.action.action import Action
from fastNLP.action.action import RandomSampler, Batchifier
from fastNLP.action.tester import Tester
from fastNLP.modules.utils import seq_mask




class BaseTrainer(Action):
@@ -28,6 +29,7 @@ class BaseTrainer(Action):
training parameters
"""
super(BaseTrainer, self).__init__()
self.train_args = train_args
self.n_epochs = train_args.epochs
self.validate = train_args.validate
self.batch_size = train_args.batch_size
@@ -163,8 +165,8 @@ class BaseTrainer(Action):
:param data: list. Each entry is a sample, which is also a list of features and label(s).
E.g.
[
- [[feature_1, feature_2, feature_3], [label_1. label_2]], # sample 1
- [[feature_1, feature_2, feature_3], [label_1. label_2]], # sample 2
+ [[word_11, word_12, word_13], [label_11. label_12]], # sample 1
+ [[word_21, word_22, word_23], [label_21. label_22]], # sample 2
...
]
:return batch_x: list. Each entry is a list of features of a sample.
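To make the format above concrete, a small hedged illustration (the ids are made up) of how such samples split into batch_x and batch_y:

# hypothetical batch of two POS samples: ([word ids], [label ids])
samples = [
    [[4, 5, 6], [1, 2, 1]],  # sample 1
    [[7, 8, 9], [1, 1, 3]],  # sample 2
]
batch_x = [s[0] for s in samples]  # [[4, 5, 6], [7, 8, 9]]
batch_y = [s[1] for s in samples]  # [[1, 2, 1], [1, 1, 3]]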
@@ -313,6 +315,39 @@ class WordSegTrainer(BaseTrainer):
self.optimizer.step()




class POSTrainer(BaseTrainer):
def __init__(self, train_args):
super(POSTrainer, self).__init__(train_args)
self.vocab_size = train_args.vocab_size
self.num_classes = train_args.num_classes
self.max_len = None
self.mask = None
self.batch_x = None

def prepare_input(self, data_path):
"""
To do: Load pkl files of train/dev/test and embedding
"""
data_train = _pickle.load(open(data_path + "data_train.pkl", "rb"))
data_dev = _pickle.load(open(data_path + "data_dev.pkl", "rb"))
return data_train, data_dev

def data_forward(self, network, x):
seq_len = [len(seq) for seq in x]
x = torch.LongTensor(x)
self.batch_size = x.size(0)
self.max_len = x.size(1)
self.mask = seq_mask(seq_len, self.max_len)
x = network(x)
self.batch_x = x
return x

def get_loss(self, predict, truth):
truth = torch.LongTensor(truth)
loss, prediction = self.loss_func(self.batch_x, predict, self.mask, self.batch_size, self.max_len)
return loss
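A rough sketch of what data_forward sets up for one batch (the batch is hypothetical; seq_mask is the helper added to fastNLP/modules/utils.py in this commit):

import torch
from fastNLP.modules.utils import seq_mask

batch_x = [[2, 4, 5], [6, 7, 0]]         # hypothetical, already-padded word-id sequences
seq_len = [len(seq) for seq in batch_x]  # per-sequence lengths, as computed in data_forward
x = torch.LongTensor(batch_x)            # shape (batch_size, max_len) = (2, 3)
mask = seq_mask(seq_len, x.size(1))      # 1 for real tokens, 0 for padding
# network(x) then yields per-token scores; get_loss passes them on to self.loss_func
# together with the cached mask, batch_size and max_len.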


if __name__ == "__name__": if __name__ == "__name__":
train_args = BaseTrainer.TrainConfig(epochs=1, validate=False, batch_size=3, pickle_path="./") train_args = BaseTrainer.TrainConfig(epochs=1, validate=False, batch_size=3, pickle_path="./")
trainer = BaseTrainer(train_args) trainer = BaseTrainer(train_args)


fastNLP/loader/dataset_loader.py  +0  -1

@@ -15,7 +15,6 @@ class POSDatasetLoader(DatasetLoader):


def __init__(self, data_name, data_path):
super(POSDatasetLoader, self).__init__(data_name, data_path)
- #self.data_set = self.load()


def load(self):
assert os.path.exists(self.data_path)


fastNLP/loader/preprocess.py  +39  -27

@@ -46,19 +46,17 @@ class BasePreprocess(object):
class POSPreprocess(BasePreprocess):
"""
This class are used to preprocess the pos datasets.
- In these datasets, each line is divided by '\t'
- The first Col is the vocabulary.
- The second Col is the labels.
+ In these datasets, each line are divided by '\t'
+ while the first Col is the vocabulary and the second
+ Col is the label.
Different sentence are divided by an empty line.
e.g:
Tom label1
and label2
Jerry label1
. label3
Hello label4
world label5
! label3
@@ -71,11 +69,13 @@ class POSPreprocess(BasePreprocess):
super(POSPreprocess, self).__init__(data, pickle_path)
self.word_dict = None
self.label_dict = None
+ self.data = data
+ self.pickle_path = pickle_path
self.build_dict()
self.word2id()
- self.id2word()
+ self.vocab_size = self.id2word()
self.class2id()
- self.id2class()
+ self.num_classes = self.id2class()
self.embedding()
self.data_train()
self.data_dev()
@@ -87,7 +87,8 @@ class POSPreprocess(BasePreprocess):
DEFAULT_RESERVED_LABEL[2]: 4}
self.label_dict = {}
for w in self.data:
- if len(w) == 0:
+ w = w.strip()
+ if len(w) <= 1:
continue
word = w.split('\t')
@@ -95,10 +96,11 @@ class POSPreprocess(BasePreprocess):
index = len(self.word_dict)
self.word_dict[word[0]] = index
- for label in word[1: ]:
- if label not in self.label_dict:
- index = len(self.label_dict)
- self.label_dict[label] = index
+ # for label in word[1: ]:
+ label = word[1]
+ if label not in self.label_dict:
+ index = len(self.label_dict)
+ self.label_dict[label] = index

def pickle_exist(self, pickle_name):
"""
@@ -107,7 +109,7 @@ class POSPreprocess(BasePreprocess):
""" """
if not os.path.exists(self.pickle_path): if not os.path.exists(self.pickle_path):
os.makedirs(self.pickle_path) os.makedirs(self.pickle_path)
file_name = self.pickle_path + pickle_name
file_name = os.path.join(self.pickle_path, pickle_name)
if os.path.exists(file_name): if os.path.exists(file_name):
return True return True
else: else:
@@ -118,42 +120,48 @@ class POSPreprocess(BasePreprocess):
return
# nothing will be done if word2id.pkl exists
- file_name = self.pickle_path + "word2id.pkl"
- with open(file_name, "wb", encoding='utf-8') as f:
+ file_name = os.path.join(self.pickle_path, "word2id.pkl")
+ with open(file_name, "wb") as f:
_pickle.dump(self.word_dict, f)

def id2word(self):
if self.pickle_exist("id2word.pkl"):
- return
+ file_name = os.path.join(self.pickle_path, "id2word.pkl")
+ id2word_dict = _pickle.load(open(file_name, "rb"))
+ return len(id2word_dict)
# nothing will be done if id2word.pkl exists
id2word_dict = {}
for word in self.word_dict:
id2word_dict[self.word_dict[word]] = word
- file_name = self.pickle_path + "id2word.pkl"
- with open(file_name, "wb", encoding='utf-8') as f:
+ file_name = os.path.join(self.pickle_path, "id2word.pkl")
+ with open(file_name, "wb") as f:
_pickle.dump(id2word_dict, f)
+ return len(id2word_dict)

def class2id(self):
if self.pickle_exist("class2id.pkl"):
return
# nothing will be done if class2id.pkl exists
- file_name = self.pickle_path + "class2id.pkl"
- with open(file_name, "wb", encoding='utf-8') as f:
+ file_name = os.path.join(self.pickle_path, "class2id.pkl")
+ with open(file_name, "wb") as f:
_pickle.dump(self.label_dict, f)

def id2class(self):
if self.pickle_exist("id2class.pkl"):
- return
+ file_name = os.path.join(self.pickle_path, "id2class.pkl")
+ id2class_dict = _pickle.load(open(file_name, "rb"))
+ return len(id2class_dict)
# nothing will be done if id2class.pkl exists
id2class_dict = {}
for label in self.label_dict:
id2class_dict[self.label_dict[label]] = label
- file_name = self.pickle_path + "id2class.pkl"
- with open(file_name, "wb", encoding='utf-8') as f:
+ file_name = os.path.join(self.pickle_path, "id2class.pkl")
+ with open(file_name, "wb") as f:
_pickle.dump(id2class_dict, f)
+ return len(id2class_dict)

def embedding(self):
if self.pickle_exist("embedding.pkl"):
@@ -168,22 +176,26 @@ class POSPreprocess(BasePreprocess):
data_train = []
sentence = []
for w in self.data:
- if len(w) == 0:
+ w = w.strip()
+ if len(w) <= 1:
wid = []
lid = []
for i in range(len(sentence)):
+ # if sentence[i][0]=="":
+ # print("")
wid.append(self.word_dict[sentence[i][0]])
lid.append(self.label_dict[sentence[i][1]])
data_train.append((wid, lid))
sentence = []
+ continue
sentence.append(w.split('\t'))
- file_name = self.pickle_path + "data_train.pkl"
- with open(file_name, "wb", encoding='utf-8') as f:
+ file_name = os.path.join(self.pickle_path, "data_train.pkl")
+ with open(file_name, "wb") as f:
_pickle.dump(data_train, f)

def data_dev(self):
pass

def data_test(self):
pass
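For orientation, a small hedged sketch of what the preprocessing yields for the two sentences from the docstring (the exact ids depend on the reserved entries, so the numbers below are illustrative only):

# input in the tab-separated format described above; sentences end with an empty line
lines = [
    "Tom\tlabel1", "and\tlabel2", "Jerry\tlabel1", ".\tlabel3", "",
    "Hello\tlabel4", "world\tlabel5", "!\tlabel3", "",
]
# build_dict(): word_dict starts with the reserved entries (ids 0-4), so e.g.
#   {"Tom": 5, "and": 6, "Jerry": 7, ".": 8, "Hello": 9, "world": 10, "!": 11}
# label_dict starts empty, so e.g. {"label1": 0, "label2": 1, "label3": 2, "label4": 3, "label5": 4}
# data_train() then pickles one (word_ids, label_ids) pair per sentence:
#   [([5, 6, 7, 8], [0, 1, 0, 2]), ([9, 10, 11], [3, 4, 2])]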

fastNLP/models/base_model.py  +7  -7

@@ -4,9 +4,9 @@ import torch
class BaseModel(torch.nn.Module):
"""Base PyTorch model for all models.
Three network modules presented:
- - embedding module
+ - encoder module
- aggregation module
- - output module
+ - decoder module
Subclasses must implement these three modules with "components".
"""


@@ -15,21 +15,20 @@ class BaseModel(torch.nn.Module):


def forward(self, *inputs):
x = self.encode(*inputs)
- x = self.aggregation(x)
- x = self.output(x)
+ x = self.aggregate(x)
+ x = self.decode(x)
return x

def encode(self, x):
raise NotImplementedError

- def aggregation(self, x):
+ def aggregate(self, x):
raise NotImplementedError

- def output(self, x):
+ def decode(self, x):
raise NotImplementedError





class Vocabulary(object):
"""A look-up table that allows you to access `Lexeme` objects. The `Vocab`
instance also provides access to the `StringStore`, and owns underlying
@@ -93,3 +92,4 @@ class Token(object):
self.doc = doc
self.token = doc[offset]
self.i = offset
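A minimal sketch of the encode / aggregate / decode contract after this renaming; the subclass below is hypothetical and only illustrates how forward() chains the three hooks:

import torch
import torch.nn as nn
from fastNLP.models.base_model import BaseModel

class TinyClassifier(BaseModel):
    """Hypothetical subclass: embed, mean-pool, then project to class scores."""
    def __init__(self, vocab_size=100, emb_dim=16, num_classes=2):
        super(TinyClassifier, self).__init__()  # assumes BaseModel.__init__ takes no arguments
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.out = nn.Linear(emb_dim, num_classes)

    def encode(self, x):
        return self.emb(x)        # (batch, seq_len, emb_dim)

    def aggregate(self, x):
        return x.mean(dim=1)      # pool over the sequence

    def decode(self, x):
        return self.out(x)        # (batch, num_classes)

scores = TinyClassifier()(torch.LongTensor([[1, 2, 3]]))  # forward() runs encode -> aggregate -> decode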


fastNLP/models/sequencce_modeling.py  +98  -0

@@ -0,0 +1,98 @@
import torch
import torch.nn as nn
from torch.nn import functional as F

from fastNLP.models.base_model import BaseModel
from fastNLP.modules.CRF import ContionalRandomField


class SeqLabeling(BaseModel):
"""
PyTorch Network for sequence labeling
"""

def __init__(self, hidden_dim,
rnn_num_layerd,
num_classes,
vocab_size,
word_emb_dim=100,
init_emb=None,
rnn_mode="gru",
bi_direction=False,
dropout=0.5,
use_crf=True):
super(SeqLabeling, self).__init__()

self.Emb = nn.Embedding(vocab_size, word_emb_dim)
if init_emb:
self.Emb.weight = nn.Parameter(init_emb)

self.num_classes = num_classes
self.input_dim = word_emb_dim
self.layers = rnn_num_layerd
self.hidden_dim = hidden_dim
self.bi_direction = bi_direction
self.dropout = dropout
self.mode = rnn_mode

if self.mode == "lstm":
self.rnn = nn.LSTM(self.input_dim, self.hidden_dim, self.layers, batch_first=True,
bidirectional=self.bi_direction, dropout=self.dropout)
elif self.mode == "gru":
self.rnn = nn.GRU(self.input_dim, self.hidden_dim, self.layers, batch_first=True,
bidirectional=self.bi_direction, dropout=self.dropout)
elif self.mode == "rnn":
self.rnn = nn.RNN(self.input_dim, self.hidden_dim, self.layers, batch_first=True,
bidirectional=self.bi_direction, dropout=self.dropout)
else:
raise Exception
if bi_direction:
self.linear = nn.Linear(self.hidden_dim * 2, self.num_classes)
else:
self.linear = nn.Linear(self.hidden_dim, self.num_classes)
self.use_crf = use_crf
if self.use_crf:
self.crf = ContionalRandomField(num_classes)

def forward(self, x):

x = self.embedding(x)
x, hidden = self.encode(x)
x = self.aggregation(x)
x = self.output(x)
return x

def embedding(self, x):
return self.Emb(x)

def encode(self, x):
return self.rnn(x)

def aggregate(self, x):
return x

def decode(self, x):
x = self.linear(x)
return x

def loss(self, x, y, mask, batch_size, max_len):
"""
Negative log likelihood loss.
:param x:
:param y:
:param seq_len:
:return loss:
prediction:
"""
if self.use_crf:
total_loss = self.crf(x, y, mask)
tag_seq = self.crf.viterbi_decode(x, mask)
else:
# error
loss_function = nn.NLLLoss(ignore_index=0, size_average=False)
x = x.view(batch_size * max_len, -1)
score = F.log_softmax(x)
total_loss = loss_function(score, y.view(batch_size * max_len))
_, tag_seq = torch.max(score)
tag_seq = tag_seq.view(batch_size, max_len)
return torch.mean(total_loss), tag_seq
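As a rough usage sketch (sizes and inputs are hypothetical; the calls go through the sub-modules defined above rather than forward()):

import torch
from fastNLP.models.sequencce_modeling import SeqLabeling

# hypothetical sizes: 10-word vocabulary, 5 tag classes, batch of 2, max_len 3
model = SeqLabeling(100, 1, num_classes=5, vocab_size=10, bi_direction=True)
x = torch.LongTensor([[2, 4, 5], [6, 7, 0]])      # padded word ids
emb = model.embedding(x)                          # (2, 3, 100)
output, hidden = model.encode(emb)                # bi-GRU output, (2, 3, 200)
scores = model.decode(model.aggregate(output))    # per-token tag scores, (2, 3, 5)
# model.loss(scores, truth, mask, batch_size, max_len) then returns (mean loss, predicted tag_seq)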

fastNLP/modules/prototype/example.py  +43  -5

@@ -1,12 +1,13 @@
- import torch
- import torch.nn as nn
- import encoder
+ import time

import aggregation
+ import dataloader
import embedding
+ import encoder
import predict
+ import torch
+ import torch.nn as nn
import torch.optim as optim
- import time
- import dataloader

WORD_NUM = 357361
WORD_SIZE = 100
@@ -16,6 +17,30 @@ R = 10
MLP_HIDDEN = 2000
CLASSES_NUM = 5


from fastNLP.models.base_model import BaseModel
from fastNLP.action.trainer import BaseTrainer


class MyNet(BaseModel):
def __init__(self):
super(MyNet, self).__init__()
self.embedding = embedding.Lookuptable(WORD_NUM, WORD_SIZE)
self.encoder = encoder.Lstm(WORD_SIZE, HIDDEN_SIZE, 1, 0.5, True)
self.aggregation = aggregation.Selfattention(2 * HIDDEN_SIZE, D_A, R)
self.predict = predict.MLP(R * HIDDEN_SIZE * 2, MLP_HIDDEN, CLASSES_NUM)
self.penalty = None

def encode(self, x):
return self.encode(self.embedding(x))

def aggregate(self, x):
x, self.penalty = self.aggregate(x)
return x

def decode(self, x):
return [self.predict(x), self.penalty]


class Net(nn.Module):
"""
A model for sentiment analysis using lstm and self-attention
@@ -34,6 +59,19 @@ class Net(nn.Module):
x = self.predict(x)
return x, penalty



class MyTrainer(BaseTrainer):
def __init__(self, args):
super(MyTrainer, self).__init__(args)
self.optimizer = None

def define_optimizer(self):
self.optimizer = optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9)

def define_loss(self):
self.loss_func = nn.CrossEntropyLoss()


def train(model_dict=None, using_cuda=True, learning_rate=0.06,\
momentum=0.3, batch_size=32, epochs=5, coef=1.0, interval=10):
"""


fastNLP/modules/utils.py  +6  -0

@@ -7,3 +7,9 @@ def mask_softmax(matrix, mask):
else:
raise NotImplementedError
return result


def seq_mask(seq_len, max_len):
mask = [torch.ge(torch.LongTensor(seq_len), i + 1) for i in range(max_len)]
mask = torch.stack(mask, 1)
return mask
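For reference, a quick illustration of the mask seq_mask produces for a batch of two sequences with true lengths 3 and 1, padded to max_len 3:

import torch
from fastNLP.modules.utils import seq_mask

print(seq_mask([3, 1], 3))
# tensor([[1, 1, 1],
#         [1, 0, 0]])  # row i keeps the first seq_len[i] positions; dtype depends on the torch version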

test/test_POS_pipeline.py  +29  -0

@@ -0,0 +1,29 @@
from fastNLP.action.trainer import POSTrainer
from fastNLP.loader.dataset_loader import POSDatasetLoader
from fastNLP.loader.preprocess import POSPreprocess
from fastNLP.models.sequencce_modeling import SeqLabeling

data_name = "people"
data_path = "data/people.txt"
pickle_path = "data"

if __name__ == "__main__":
# Data Loader
pos = POSDatasetLoader(data_name, data_path)
train_data = pos.load_lines()

# Preprocessor
p = POSPreprocess(train_data, pickle_path)
vocab_size = p.vocab_size
num_classes = p.num_classes

# Trainer
train_args = POSTrainer.TrainConfig(epochs=20, batch_size=1, num_classes=num_classes,
vocab_size=vocab_size, pickle_path=pickle_path)
trainer = POSTrainer(train_args)

# Model
model = SeqLabeling(100, 1, num_classes, vocab_size, bi_direction=True)

# Start training.
trainer.train(model)
