@@ -5,6 +5,7 @@ class Action(object):

    def __init__(self):
        super(Action, self).__init__()
        self.logger = None

    def load_config(self, args):
        pass

@@ -13,4 +14,15 @@ class Action(object):
        pass

    def log(self, args):
        self.logger.log(args)

    """
    Basic operations shared between Trainer and Tester.
    """

    def batchify(self, X, Y=None):
        # a generator
        pass

    def make_log(self, *args):
        pass
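
The batchify stub above only notes "a generator". A minimal sketch of what it could look like, assuming X and Y are numpy arrays and that a self.batch_size attribute is provided by the subclass (both assumptions, not part of this diff):

    # Sketch only: assumes numpy-array inputs and a self.batch_size attribute.
    def batchify(self, X, Y=None):
        n_samples = len(X)
        iterations = (n_samples + self.batch_size - 1) // self.batch_size

        def generator():
            for start in range(0, n_samples, self.batch_size):
                end = start + self.batch_size
                if Y is None:
                    yield X[start:end]
                else:
                    yield X[start:end], Y[start:end]

        return iterations, generator()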
@@ -1,9 +1,45 @@
import numpy as np

from action.action import Action


class Tester(Action):
    """docstring for Tester"""

    def __init__(self, test_args):
        """
        :param test_args: named tuple
        """
        super(Tester, self).__init__()
        self.test_args = test_args
        self.args_dict = {name: value for name, value in self.test_args.__dict__.items()}
        self.mean_loss = None

    def test(self, network, data):
        # transform the data into network input and labels
        X, Y = network.prepare_input(data)

        # split into batches of self.batch_size
        iterations, test_batch_generator = self.batchify(X, Y)

        loss_history = list()
        # turn on the testing mode of the network
        network.mode(test=True)

        for step in range(iterations):
            batch_x, batch_y = next(test_batch_generator)

            # forward pass from test input to predicted output
            prediction = network.data_forward(batch_x)

            # get the loss
            loss = network.loss(batch_y, prediction)
            loss_history.append(loss)
            self.log(self.make_log(step, loss))

        self.mean_loss = np.mean(np.array(loss_history))

    @property
    def loss(self):
        return self.mean_loss
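
A possible way to drive the Tester above (the network and test_data objects are hypothetical, and SimpleNamespace is used only because __init__ reads the settings off __dict__; none of this is taken from the diff):

    # Usage sketch; `network` must provide prepare_input, mode, data_forward and loss
    # as used in test() above, and `test_data` is whatever prepare_input expects.
    from types import SimpleNamespace

    test_args = SimpleNamespace(batch_size=32, save_path="./save/")
    tester = Tester(test_args)
    tester.test(network, test_data)
    print(tester.loss)  # mean loss over all test batches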
@@ -0,0 +1,9 @@
Some useful references:

SpaCy "Doc"
https://github.com/explosion/spaCy/blob/75d2a05c2938f412f0fae44748374e4de19cc2be/spacy/tokens/doc.pyx#L80

SpaCy "Vocab"
https://github.com/explosion/spaCy/blob/75d2a05c2938f412f0fae44748374e4de19cc2be/spacy/vocab.pyx#L25

SpaCy "Token"
https://github.com/explosion/spaCy/blob/75d2a05c2938f412f0fae44748374e4de19cc2be/spacy/tokens/token.pyx#L27
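
For quick context on how the three spaCy objects referenced above relate, a minimal usage sketch (assumes spaCy and its en_core_web_sm model are installed; not part of this repository):

    import spacy

    nlp = spacy.load("en_core_web_sm")
    doc = nlp("This is a sentence.")   # Doc: a sequence of Token objects
    for token in doc:                  # Token: one word or punctuation symbol
        print(token.text, token.pos_)
    print(len(nlp.vocab))              # Vocab: shared string store and lexeme table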
@@ -1,15 +1,15 @@
import os
from collections import namedtuple

import numpy as np
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from .model import charLM
from .test import test
from .utilities import *


def preprocess():
@@ -43,7 +43,18 @@ def to_var(x):

def train(net, data, opt):
    """
    Train the network.

    :param net: the PyTorch model
    :param data: numpy array
    :param opt: named tuple of training options

    1. set random seed
    2. define local input
    3. training settings: learning rate, loss, etc.
    4. main loop over epochs
    5. batchify
    6. validation
    7. save model
    """
    torch.manual_seed(1024)

    train_input = torch.from_numpy(data.train_input)
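
The numbered steps in the docstring could translate into a loop roughly like the sketch below; opt.epochs, opt.lr, opt.batch_size, train_label and the make_batches helper are assumed names for illustration, not taken from this diff:

    # Rough sketch of steps 3-7; uses to_var() and train_input from the code above.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=opt.lr)

    for epoch in range(opt.epochs):
        net.train()
        for batch_x, batch_y in make_batches(train_input, train_label, opt.batch_size):
            optimizer.zero_grad()
            output = net(to_var(batch_x))
            loss = criterion(output, to_var(batch_y))
            loss.backward()
            optimizer.step()
        # step 6: run validation here
    # step 7: save the trained model, e.g. torch.save(net.state_dict(), path)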
@@ -0,0 +1,14 @@ | |||||
class BaseSaver(object): | |||||
"""base class for all savers""" | |||||
def __init__(self, save_path): | |||||
self.save_path = save_path | |||||
def save_bytes(self): | |||||
pass | |||||
def save_str(self): | |||||
pass | |||||
def compress(self): | |||||
pass |
@@ -0,0 +1,11 @@
from saver.base_saver import BaseSaver


class Logger(BaseSaver):
    """Logging"""

    def __init__(self, save_path):
        super(Logger, self).__init__(save_path)

    def log(self, string):
        pass
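
Logger.log is still a stub; one plausible implementation, assuming save_path points to a writable text file:

    # Sketch: append one line per call; assumes save_path is a text file path.
    def log(self, string):
        with open(self.save_path, "a", encoding="utf-8") as f:
            f.write(string + "\n")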
@@ -0,0 +1,8 @@
from saver.base_saver import BaseSaver


class ModelSaver(BaseSaver):
    """Save a model"""

    def __init__(self, save_path):
        super(ModelSaver, self).__init__(save_path)
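
ModelSaver does not define a save method yet; a common PyTorch pattern, with a hypothetical method name and assuming the model is a torch.nn.Module, would be:

    # Sketch: persist only the parameters; `model` is assumed to be a torch.nn.Module,
    # and save_pytorch is a hypothetical method name, not taken from this diff.
    import torch

    def save_pytorch(self, model):
        torch.save(model.state_dict(), self.save_path)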