@@ -27,7 +27,7 @@ class Action(object):
         :return iteration: int, the number of steps in each epoch
                 generator: generator, to generate batch inputs
         """
-        n_samples = X.size()[0]
+        n_samples = X.shape[0]
         num_iter = n_samples // batch_size
         if Y is None:
             generator = self._batch_generate(batch_size, num_iter, X)
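This hunk swaps `X.size()[0]` (a torch-only call) for `X.shape[0]`, which works for both numpy arrays and torch tensors. The generator itself is not shown in the diff; below is a minimal sketch of what `_batch_generate` plausibly does, assuming it yields `num_iter` consecutive slices of size `batch_size`. The names mirror the call above, but the body is an assumption, not the actual implementation:

```python
def _batch_generate(self, batch_size, num_iter, X, Y=None):
    # Yield consecutive, equally sized slices; the trailing remainder
    # (n_samples % batch_size) is dropped, matching num_iter = n_samples // batch_size.
    for step in range(num_iter):
        start = batch_size * step
        end = batch_size * (step + 1)
        if Y is None:
            yield X[start:end]
        else:
            yield X[start:end], Y[start:end]
```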
@@ -6,7 +6,7 @@ from .tester import Tester

 class Trainer(Action):
     """
-    Trainer for common training logic of all models
+    Trainer is a common training pipeline shared among all models.
     """
     TrainConfig = namedtuple("config", ["epochs", "validate", "save_when_better",
                                         "log_per_step", "log_validation", "batch_size"])
@@ -23,12 +23,12 @@ class Trainer(Action):
         self.log_validation = train_args.log_validation
         self.batch_size = train_args.batch_size

-    def train(self, network, train_data, dev_data):
+    def train(self, network, train_data, dev_data=None):
         """
         :param network: the model controller
         :param train_data: raw data for training
         :param dev_data: raw data for validation
-        :return:
+        This method will call all the base methods of network (implemented in model.base_model).
         """
         train_x, train_y = network.prepare_input(train_data)
@@ -60,6 +60,8 @@ class Trainer(Action):

             #################### evaluate over dev set ###################
             if self.validate:
+                if dev_data is None:
+                    raise RuntimeError("No validation data provided.")
                 # give all controls to tester
                 evaluator.test(network, dev_data)
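With `dev_data` now defaulting to `None`, a caller that disables validation can omit the dev set entirely, while `validate=True` without a dev set fails fast instead of breaking deep inside the Tester. An illustrative call site (config values are examples; `model` and `train_data` stand in for any prepared model and dataset):

```python
config = Trainer.TrainConfig(epochs=5, validate=False, save_when_better=False,
                             log_per_step=10, log_validation=False, batch_size=256)
trainer = Trainer(config)
# model/train_data: any BaseModel subclass and its raw training data
trainer.train(model, train_data)   # fine: validation disabled, no dev set needed
# With validate=True and no dev_data, train() now raises
# RuntimeError("No validation data provided.")
```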
@@ -14,6 +14,11 @@ class BaseLoader(object):
             text = f.read()
         return text

+    def load_lines(self):
+        with open(self.data_path, "r", encoding="utf-8") as f:
+            text = f.readlines()
+        return text
+

 class ToyLoader0(BaseLoader):
     """
@@ -29,3 +34,4 @@ class ToyLoader0(BaseLoader):
         import re
         corpus = re.sub(r"<unk>", "unk", corpus)
         return corpus.split()
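The new `load_lines` returns the file as a list of strings, one per line with the trailing newline kept, which is the shape `WordSegModel.prepare_input` expects later in this diff. Usage, mirroring the test script at the end:

```python
from loader.base_loader import BaseLoader

# Constructor arguments (a name and a data path) follow the test script below.
train_data = BaseLoader("load_train", "./data_for_tests/cws_train").load_lines()
print(train_data[0])  # first line of the corpus, newline included
```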
@@ -2,29 +2,64 @@ import numpy as np

 class BaseModel(object):
-    """PyTorch base model for all models"""
+    """The base class of all models.
+
+    This class and its subclasses are actually "wrappers" of the PyTorch models.
+    They act as an interface between Trainer and the deep learning networks.
+    This interface provides the following methods to be called by Trainer.
+    - prepare_input
+    - mode
+    - define_optimizer
+    - data_forward
+    - grad_backward
+    - get_loss
+    """

     def __init__(self):
         pass

     def prepare_input(self, data):
         """
-        :param data: str, raw input vector(?)
+        Perform data transformation from raw input to vector/matrix inputs.
+        :param data: raw inputs
         :return (X, Y): tuple, input features and labels
         """
         raise NotImplementedError

     def mode(self, test=False):
+        """
+        Tell the network whether it is being trained or tested, as required by PyTorch.
+        :param test: bool
+        """
+        raise NotImplementedError
+
+    def define_optimizer(self):
+        """
+        Define the PyTorch optimizer specified by the model.
+        """
         raise NotImplementedError

     def data_forward(self, *x):
+        """
+        Forward pass of the data.
+        :param x: input feature matrix and label vector
+        :return: output of the model
+        """
         # required by PyTorch nn
         raise NotImplementedError

     def grad_backward(self):
+        """
+        Perform gradient descent to update the model parameters.
+        """
         raise NotImplementedError

     def get_loss(self, pred, truth):
+        """
+        Compute the loss given the model prediction and the ground truth.
+        The loss function is specified by the model.
+        :param pred: prediction label vector
+        :param truth: ground truth label vector
+        :return: a scalar
+        """
         raise NotImplementedError
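Taken together, the six methods above define the contract Trainer relies on. A minimal sketch of how a Trainer-like loop exercises a `BaseModel` subclass (illustrative only; batching, logging, and validation are omitted):

```python
def run_one_pass(network, data):
    X, Y = network.prepare_input(data)   # raw data -> feature/label matrices
    network.mode(test=False)             # put the underlying network in train mode
    network.define_optimizer()           # the model chooses its own optimizer
    pred = network.data_forward(X)       # forward pass
    loss = network.get_loss(pred, Y)     # model-specific loss, a scalar
    network.grad_backward()              # backward pass + parameter update
    return loss
```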
@@ -54,29 +89,70 @@ class ToyModel(BaseModel):
         self._loss = np.mean(np.square(pred - truth))
         return self._loss

+    def define_optimizer(self):
+        pass
+

 class Vocabulary(object):
-    """
-    A collection of lookup tables.
+    """A look-up table that allows you to access `Lexeme` objects. The `Vocab`
+    instance also provides access to the `StringStore`, and owns underlying
+    data that is shared between `Doc` objects.
     """

     def __init__(self):
-        self.word_set = None
-        self.word2idx = None
-        self.emb_matrix = None
-
-    def lookup(self, word):
-        if word in self.word_set:
-            return self.emb_matrix[self.word2idx[word]]
-        return LookupError("The key " + word + " does not exist.")
+        """Create the vocabulary.
+        RETURNS (Vocab): The newly constructed object.
+        """
+        self.data_frame = None

 class Document(object):
-    """
-    contains a sequence of tokens
-    each token is a character with linguistic attributes
+    """A sequence of Token objects. Access sentences and named entities, export
+    annotations to numpy arrays, losslessly serialize to compressed binary
+    strings. The `Doc` object holds an array of `Token` objects. The
+    Python-level `Token` and `Span` objects are views of this array, i.e.
+    they don't own the data themselves. -- spacy
     """

-    def __init__(self):
-        # wrap pandas.dataframe
-        self.dataframe = None
+    def __init__(self, vocab, words=None, spaces=None):
+        """Create a Doc object.
+        vocab (Vocab): A vocabulary object, which must match any models you
+            want to use (e.g. tokenizer, parser, entity recognizer).
+        words (list or None): A list of unicode strings to add to the document
+            as words. If `None`, defaults to an empty list.
+        spaces (list or None): A list of boolean values, of the same length as
+            words. True means that the word is followed by a space, False means
+            it is not. If `None`, defaults to `[True]*len(words)`.
+        RETURNS (Doc): The newly constructed object.
+        """
+        self.vocab = vocab
+        self.spaces = spaces
+        self.words = words if words is not None else []
+        if spaces is None:
+            self.spaces = [True] * len(self.words)
+        elif len(spaces) != len(self.words):
+            raise ValueError("lengths of spaces and words do not match")
+
+    def get_chunker(self, vocab):
+        return None
+
+    def push_back(self, vocab):
+        pass

 class Token(object):
+    """An individual token – i.e. a word, punctuation symbol, whitespace,
+    etc.
     """
-    def __init__(self):
-        # wrap pandas.dataframe
-        self.dataframe = None
+    def __init__(self, vocab, doc, offset):
+        """Construct a `Token` object.
+        vocab (Vocabulary): A storage container for lexical types.
+        doc (Document): The parent document.
+        offset (int): The index of the token within the document.
+        """
+        self.vocab = vocab
+        self.doc = doc
+        self.token = doc[offset]
+        self.i = offset
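A small illustration of the new constructor contract (purely illustrative; note that `Token.__init__` indexes its parent with `doc[offset]`, which assumes a `Document.__getitem__` that this diff does not define):

```python
vocab = Vocabulary()
doc = Document(vocab, words=["fast", "NLP", "toolkit"], spaces=[True, True, False])
print(doc.spaces)   # [True, True, False]

doc2 = Document(vocab, words=["hello", "world"])
print(doc2.spaces)  # defaults to [True, True]
```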
@@ -0,0 +1,135 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.autograd import Variable
+
+from model.base_model import BaseModel
+
+USE_GPU = True
+
+
+def to_var(x):
+    if torch.cuda.is_available() and USE_GPU:
+        x = x.cuda()
+    return Variable(x)
+
+
+class WordSegModel(BaseModel):
+    """
+    Model controller for WordSeg
+    """
+
+    def __init__(self):
+        super(WordSegModel, self).__init__()
+        self.id2word = None
+        self.word2id = None
+        self.id2tag = None
+        self.tag2id = None
+
+        self.lstm_batch_size = 8
+        self.lstm_seq_len = 32  # Trainer batch_size == lstm_batch_size * lstm_seq_len
+        self.hidden_dim = 100
+        self.lstm_num_layers = 2
+        self.vocab_size = 100
+        self.word_emb_dim = 100
+
+        self.model = WordSeg(self.hidden_dim, self.lstm_num_layers, self.vocab_size, self.word_emb_dim)
+        # initial (h_0, c_0), each of shape (num_layers, batch, hidden_size)
+        self.hidden = (to_var(torch.zeros(2, self.lstm_batch_size, self.word_emb_dim)),
+                       to_var(torch.zeros(2, self.lstm_batch_size, self.word_emb_dim)))
+
+        self.optimizer = None
+        self._loss = None
+
+    def prepare_input(self, data):
+        """
+        Perform word index lookup to convert string tokens into indices.
+        :param data: list of strings; each non-empty line holds "#"-separated fields,
+                     with the word in the first field and its {B, M, E, S} tag
+                     at the start of the third.
+        :return (words, tags): numpy column vectors of word indices and tag indices
+        """
+        word_list = []
+        tag_list = []
+        for line in data:
+            if len(line) > 2:
+                tokens = line.split("#")
+                word_list.append(tokens[0])
+                tag_list.append(tokens[2][0])
+        self.id2word = list(set(word_list))
+        self.word2id = {word: idx for idx, word in enumerate(self.id2word)}
+        self.id2tag = list(set(tag_list))
+        self.tag2id = {tag: idx for idx, tag in enumerate(self.id2tag)}
+        words = np.array([self.word2id[w] for w in word_list]).reshape(-1, 1)
+        tags = np.array([self.tag2id[t] for t in tag_list]).reshape(-1, 1)
+        return words, tags
+
+    def mode(self, test=False):
+        if test:
+            self.model.eval()
+        else:
+            self.model.train()
+
+    def data_forward(self, x):
+        """
+        :param x: word indices, a flat sequence of length batch_size
+                  (== lstm_batch_size * lstm_seq_len)
+        :return: network output after reshaping x to (lstm_batch_size, lstm_seq_len)
+        """
+        x = x.reshape(self.lstm_batch_size, self.lstm_seq_len)
+        output, self.hidden = self.model(x, self.hidden)
+        return output
+
+    def define_optimizer(self):
+        self.optimizer = optim.SGD(self.model.parameters(), lr=0.01, momentum=0.85)
+
+    def get_loss(self, pred, truth):
+        # nn.CrossEntropyLoss is a module: instantiate it, then apply it to (pred, truth)
+        self._loss = nn.CrossEntropyLoss()(pred, truth)
+        return self._loss
+
+    def grad_backward(self):
+        self.model.zero_grad()
+        self._loss.backward()
+        torch.nn.utils.clip_grad_norm(self.model.parameters(), 5, norm_type=2)
+        self.optimizer.step()
+
+
+class WordSeg(nn.Module):
+    """
+    PyTorch network for word segmentation
+    """
+
+    def __init__(self, hidden_dim, lstm_num_layers, vocab_size, word_emb_dim=100):
+        super(WordSeg, self).__init__()
+        self.vocab_size = vocab_size
+        self.word_emb_dim = word_emb_dim
+        self.lstm_num_layers = lstm_num_layers
+        self.hidden_dim = hidden_dim
+
+        self.word_emb = nn.Embedding(self.vocab_size, self.word_emb_dim)
+        self.lstm = nn.LSTM(input_size=self.word_emb_dim,
+                            hidden_size=self.word_emb_dim,
+                            num_layers=self.lstm_num_layers,
+                            bias=True,
+                            dropout=0.5,
+                            batch_first=True)
+        self.linear = nn.Linear(self.word_emb_dim, self.vocab_size)
+
+    def forward(self, x, hidden):
+        """
+        :param x: tensor of shape [batch_size, seq_len], vocabulary indices
+        :param hidden: tuple of initial (h_0, c_0) LSTM states
+        :return x: unnormalized scores (logits) over vocabulary entries
+                hidden: (hidden state, memory cell) returned by the LSTM
+        """
+        # [batch_size, seq_len]
+        x = self.word_emb(x)
+        # [batch_size, seq_len, word_emb_size]
+        x, hidden = self.lstm(x, hidden)
+        # [batch_size, seq_len, word_emb_size]
+        x = x.contiguous().view(x.shape[0] * x.shape[1], -1)
+        # [batch_size*seq_len, word_emb_size]
+        x = self.linear(x)
+        # [batch_size*seq_len, vocab_size]
+        return x, hidden
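A quick standalone shape check for the network above, with sizes matching the `WordSegModel` defaults (batch 8, sequence 32, embedding/hidden 100, vocabulary 100). This is a sanity-check sketch, not part of the diff, and assumes `WordSeg` is importable from `model.word_seg_model`:

```python
import torch
from torch.autograd import Variable

from model.word_seg_model import WordSeg

net = WordSeg(hidden_dim=100, lstm_num_layers=2, vocab_size=100, word_emb_dim=100)
x = Variable(torch.zeros(8, 32).long())  # [batch_size, seq_len] word indices
h0 = Variable(torch.zeros(2, 8, 100))    # (num_layers, batch, hidden_size)
c0 = Variable(torch.zeros(2, 8, 100))
out, (h, c) = net(x, (h0, c0))
print(out.size())                        # torch.Size([256, 100]) == [batch*seq, vocab]
```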
@@ -1,5 +1,6 @@
 import os
+import
 import
 import torch
 import torch.nn as nn
@@ -54,10 +55,10 @@ for epoch in range(num_epochs):
     cnn.train()
     for i, (sents, labels) in enumerate(train_loader):
         sents = Variable(sents)
-        labels = Variable(labels)
-        if cuda:
-            sents = sents.cuda()
-            labels = labels.cuda()
+        labels = Variable(labels)
+        if cuda:
+            sents = sents.cuda()
+            labels = labels.cuda()
         optimizer.zero_grad()
         outputs = cnn(sents)
         loss = criterion(outputs, labels)
@@ -0,0 +1,30 @@
+from action.tester import Tester
+from action.trainer import Trainer
+from loader.base_loader import BaseLoader
+from model.word_seg_model import WordSegModel
+
+
+def test_charlm():
+    # batch_size must equal lstm_batch_size * lstm_seq_len == 8 * 32 == 256,
+    # since WordSegModel.data_forward reshapes each batch to (8, 32)
+    train_config = Trainer.TrainConfig(epochs=5, validate=False, save_when_better=False,
+                                       log_per_step=10, log_validation=False, batch_size=256)
+    trainer = Trainer(train_config)
+
+    model = WordSegModel()
+
+    train_data = BaseLoader("load_train", "./data_for_tests/cws_train").load_lines()
+
+    trainer.train(model, train_data)
+    trainer.save_model(model)
+
+    test_config = Tester.TestConfig(save_output=False, validate_in_training=False,
+                                    save_dev_input=False, save_loss=False, batch_size=256)
+    tester = Tester(test_config)
+
+    test_data = BaseLoader("load_test", "./data_for_tests/cws_test").load_lines()
+    tester.test(model, test_data)
+
+
+if __name__ == "__main__":
+    test_charlm()