@@ -15,15 +15,26 @@ class Action(object):
         raise NotImplementedError
 
     def log(self, args):
-        self.logger.log(args)
+        """
+        Basic operations shared between Trainer and Tester.
+        """
+        print("call logger.log")
 
     def batchify(self, X, Y=None):
-        # a generator
-        raise NotImplementedError
+        """
+        :param X: input features
+        :param Y: labels, optional
+        :return iteration: int, the number of steps in each epoch
+                generator: generator, yields batch inputs
+        """
+        data = [X]
+        if Y is not None:
+            data = [X, Y]
+        return 2, self._batch_generate(data)
+
+    def _batch_generate(self, data):
+        # toy batching: two fixed batches of ten samples each
+        step = 10
+        for i in range(2):
+            start = i * step
+            end = (i + 1) * step
+            # yield a slice of X (and of Y when it was given)
+            yield tuple(d[start:end] for d in data)
 
     def make_log(self, *args):
-        raise NotImplementedError
+        return "log"
@@ -12,7 +12,7 @@ class Tester(Action):
         """
         super(Tester, self).__init__()
         self.test_args = test_args
-        self.args_dict = {name: value for name, value in self.test_args.__dict__.iteritems()}
+        # self.args_dict = {name: value for name, value in self.test_args.__dict__.iteritems()}
         self.mean_loss = None
         self.output = None
 
@@ -54,4 +54,4 @@ class Tester(Action):
 
     def make_output(self, batch_output):
         # construct full prediction with batch outputs
-        raise NotImplementedError
+        return np.concatenate((batch_output[0], batch_output[1]), axis=0)
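For reference, a small standalone illustration of what the new make_output does: it stitches per-batch predictions back into one array along the sample axis (this assumes numpy is available as np in the tester module; the shapes below mirror the toy 10-sample batches).

import numpy as np

# two batches of predictions, shape (10, 1) each, as the toy model would produce
batch_output = [np.zeros((10, 1)), np.ones((10, 1))]

# same operation as Tester.make_output: concatenate along the sample axis
full_prediction = np.concatenate((batch_output[0], batch_output[1]), axis=0)
print(full_prediction.shape)   # (20, 1)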
@@ -1,5 +1,5 @@
-from action.action import Action
-from action.tester import Tester
+from .action import Action
+from .tester import Tester
 
 
 class Trainer(Action):
@@ -13,10 +13,10 @@ class Trainer(Action):
         """
         super(Trainer, self).__init__()
         self.train_args = train_args
-        self.args_dict = {name: value for name, value in self.train_args.__dict__.iteritems()}
+        # self.args_dict = {name: value for name, value in self.train_args.__dict__.iteritems()}
         self.n_epochs = self.train_args.epochs
-        self.validate = True
-        self.save_when_better = True
+        self.validate = self.train_args.validate
+        self.save_when_better = self.train_args.save_when_better
 
     def train(self, network, data, dev_data):
         X, Y = network.prepare_input(data)
@@ -51,10 +51,10 @@ class Trainer(Action):
         # finish training
 
     def make_log(self, *args):
-        raise NotImplementedError
+        print("logged")
 
     def make_valid_log(self, *args):
-        raise NotImplementedError
+        print("logged")
 
     def save_model(self, model):
-        raise NotImplementedError
+        print("model saved")
@@ -1,3 +1,6 @@
+import numpy as np
+
+
 class BaseModel(object):
     """base model for all models"""
 
@@ -5,6 +8,10 @@ class BaseModel(object):
         pass
 
     def prepare_input(self, data):
+        """
+        :param data: raw input data, e.g. a 2-D array of features and labels
+        :return: (X, Y), a tuple of input features and labels
+        """
         raise NotImplementedError
 
     def mode(self, test=False):
@@ -20,6 +27,33 @@ class BaseModel(object):
         raise NotImplementedError
 
 
+class ToyModel(BaseModel):
+    """This is for code testing."""
+
+    def __init__(self):
+        super(ToyModel, self).__init__()
+        self.test_mode = False
+        self.weight = np.random.rand(5, 1)
+        self.bias = np.random.rand()
+        self._loss = 0
+
+    def prepare_input(self, data):
+        # treat the last column as the label and the rest as features
+        return data[:, :-1], data[:, -1]
+
+    def mode(self, test=False):
+        self.test_mode = test
+
+    def data_forward(self, x):
+        return np.matmul(x, self.weight) + self.bias
+
+    def grad_backward(self):
+        print("loss gradient backward")
+
+    def loss(self, pred, truth):
+        # flatten pred (batch, 1) so it does not broadcast against truth (batch,)
+        self._loss = np.mean(np.square(pred.reshape(-1) - truth))
+        return self._loss
+
+
 class Vocabulary(object):
     """
     A collection of lookup tables.
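A quick, hypothetical sanity check of the new ToyModel on random data; the shapes mirror the test below, with six columns of which the last is treated as the label.

import numpy as np

from model.base_model import ToyModel

model = ToyModel()
data = np.random.rand(20, 6)        # 20 samples: 5 feature columns + 1 label column

X, Y = model.prepare_input(data)    # X: (20, 5), Y: (20,)
pred = model.data_forward(X)        # (20, 1) linear projection with random weights
print(model.loss(pred, Y))          # scalar mean squared error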
@@ -0,0 +1,21 @@
+from collections import namedtuple
+
+import numpy as np
+
+from action.trainer import Trainer
+from model.base_model import ToyModel
+
+
+def test_trainer():
+    Config = namedtuple("config", ["epochs", "validate", "save_when_better"])
+    train_config = Config(epochs=5, validate=True, save_when_better=True)
+    trainer = Trainer(train_config)
+
+    net = ToyModel()
+    data = np.random.rand(20, 6)
+    dev_data = np.random.rand(20, 6)
+    trainer.train(net, data, dev_data)
+
+
+if __name__ == "__main__":
+    test_trainer()