diff --git a/action/action.py b/action/action.py
index 836445e0..40ec3142 100644
--- a/action/action.py
+++ b/action/action.py
@@ -5,6 +5,7 @@ class Action(object):
 
     def __init__(self):
         super(Action, self).__init__()
+        self.logger = None
 
     def load_config(self, args):
         pass
@@ -13,4 +14,15 @@ class Action(object):
         pass
 
     def log(self, args):
+        self.logger.log(args)
+
+    """
+    Basic operations shared between Trainer and Tester.
+    """
+
+    def batchify(self, X, Y=None):
+        # a generator
         pass
+
+    def make_log(self, *args):
+        pass
\ No newline at end of file
diff --git a/action/tester.py b/action/tester.py
index 91bf28aa..0b78a782 100644
--- a/action/tester.py
+++ b/action/tester.py
@@ -1,9 +1,45 @@
+import numpy as np
+
 from action.action import Action
 
 
 class Tester(Action):
     """docstring for Tester"""
 
-    def __init__(self, arg):
+    def __init__(self, test_args):
+        """
+        :param test_args: named tuple
+        """
         super(Tester, self).__init__()
-        self.arg = arg
+        self.test_args = test_args
+        self.args_dict = {name: value for name, value in self.test_args.__dict__.items()}
+        self.mean_loss = None
+
+    def test(self, network, data):
+        # transform into network input and label
+        X, Y = network.prepare_input(data)
+
+        # split into batches by self.batch_size
+        iterations, test_batch_generator = self.batchify(X, Y)
+
+        loss_history = list()
+        # turn on the testing mode of the network
+        network.mode(test=True)
+
+        for step in range(iterations):
+            batch_x, batch_y = test_batch_generator.__next__()
+
+            # forward pass from test input to predicted output
+            prediction = network.data_forward(batch_x)
+
+            # get the loss
+            loss = network.loss(batch_y, prediction)
+
+            loss_history.append(loss)
+            self.log(self.make_log(step, loss))
+
+        self.mean_loss = np.mean(np.array(loss_history))
+
+    @property
+    def loss(self):
+        return self.mean_loss
diff --git a/model/empty.txt b/model/empty.txt
index e69de29b..942340d2 100644
--- a/model/empty.txt
+++ b/model/empty.txt
@@ -0,0 +1,9 @@
+Some useful reference:
+SpaCy "Doc"
+https://github.com/explosion/spaCy/blob/75d2a05c2938f412f0fae44748374e4de19cc2be/spacy/tokens/doc.pyx#L80
+
+SpaCy "Vocab"
+https://github.com/explosion/spaCy/blob/75d2a05c2938f412f0fae44748374e4de19cc2be/spacy/vocab.pyx#L25
+
+SpaCy "Token"
+https://github.com/explosion/spaCy/blob/75d2a05c2938f412f0fae44748374e4de19cc2be/spacy/tokens/token.pyx#L27
diff --git a/reproduction/Char-aware_NLM/train.py b/reproduction/Char-aware_NLM/train.py
index caab0adf..99edb3c6 100644
--- a/reproduction/Char-aware_NLM/train.py
+++ b/reproduction/Char-aware_NLM/train.py
@@ -1,15 +1,15 @@
+import os
+from collections import namedtuple
+
+import numpy as np
 import torch
-from torch.autograd import Variable
 import torch.nn as nn
-import torch.nn.functional as F
 import torch.optim as optim
-import numpy as np
-import os
-
-from model import charLM
-from utilities import *
-from collections import namedtuple
-from test import test
+from torch.autograd import Variable
+
+from .model import charLM
+from .test import test
+from .utilities import *
 
 
 def preprocess():
@@ -43,7 +43,18 @@ def to_var(x):
 
 
 def train(net, data, opt):
-
+    """
+    :param net: the pytorch model
+    :param data: numpy array
+    :param opt: named tuple
+    1. random seed
+    2. define local input
+    3. training setting: learning rate, loss, etc.
+    4. main loop over epochs
+    5. batchify
+    6. validation
+    7. save model
+    """
     torch.manual_seed(1024)
 
     train_input = torch.from_numpy(data.train_input)
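
The body of train() is truncated in this diff, so the seven steps listed in its docstring are only named, not shown. Purely as an illustration of that flow, here is a minimal, self-contained sketch: the linear model, random tensors, batch size, and learning rate are placeholders rather than anything from the Char-aware NLM reproduction, and it is written against the current PyTorch API instead of the Variable-based one imported above.

# Illustrative sketch only: placeholder model and data, not the Char-aware NLM code.
import torch
import torch.nn as nn
import torch.optim as optim

# 1. random seed
torch.manual_seed(1024)

# 2. define local input (random tensors standing in for data.train_input etc.)
train_input = torch.randn(100, 10)
train_label = torch.randint(0, 2, (100,))
valid_input = torch.randn(20, 10)
valid_label = torch.randint(0, 2, (20,))

# 3. training setting: model, loss, learning rate
net = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1)
batch_size = 20

# 4. main loop over epochs
for epoch in range(5):
    net.train()
    # 5. batchify: slice the training tensors into mini-batches
    for start in range(0, train_input.size(0), batch_size):
        batch_x = train_input[start:start + batch_size]
        batch_y = train_label[start:start + batch_size]
        optimizer.zero_grad()
        loss = criterion(net(batch_x), batch_y)
        loss.backward()
        optimizer.step()

    # 6. validation
    net.eval()
    with torch.no_grad():
        valid_loss = criterion(net(valid_input), valid_label).item()
    print("epoch %d: validation loss %.4f" % (epoch, valid_loss))

# 7. save model
torch.save(net.state_dict(), "net.pt")
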
diff --git a/saver/base_saver.py b/saver/base_saver.py
new file mode 100644
index 00000000..d89e0935
--- /dev/null
+++ b/saver/base_saver.py
@@ -0,0 +1,14 @@
+class BaseSaver(object):
+    """base class for all savers"""
+
+    def __init__(self, save_path):
+        self.save_path = save_path
+
+    def save_bytes(self):
+        pass
+
+    def save_str(self):
+        pass
+
+    def compress(self):
+        pass
diff --git a/saver/empty.txt b/saver/empty.txt
deleted file mode 100644
index e69de29b..00000000
diff --git a/saver/logger.py b/saver/logger.py
new file mode 100644
index 00000000..9ff66866
--- /dev/null
+++ b/saver/logger.py
@@ -0,0 +1,11 @@
+from saver.base_saver import BaseSaver
+
+
+class Logger(BaseSaver):
+    """Logging"""
+
+    def __init__(self, save_path):
+        super(Logger, self).__init__(save_path)
+
+    def log(self, string):
+        pass
diff --git a/saver/model_saver.py b/saver/model_saver.py
new file mode 100644
index 00000000..4bc6ea34
--- /dev/null
+++ b/saver/model_saver.py
@@ -0,0 +1,8 @@
+from saver.base_saver import BaseSaver
+
+
+class ModelSaver(BaseSaver):
+    """Save a model"""
+
+    def __init__(self, save_path):
+        super(ModelSaver, self).__init__(save_path)
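
Action.batchify is only a stub in this diff, but Tester.test() already fixes its contract: it must return a pair (iterations, generator), and the generator must yield (batch_x, batch_y) batches. One possible implementation could look like the sketch below; the standalone function and the explicit batch_size argument are assumptions made for illustration, since inside the class the batch size would presumably come from the action's args.

import numpy as np


def batchify(X, Y=None, batch_size=32):
    """Split X (and optionally Y) into mini-batches.

    Returns the number of iterations and a generator over the batches,
    matching the way Tester.test() consumes self.batchify(X, Y).
    """
    n_samples = len(X)
    iterations = (n_samples + batch_size - 1) // batch_size  # ceiling division

    def generate():
        for start in range(0, n_samples, batch_size):
            batch_x = X[start:start + batch_size]
            if Y is None:
                yield batch_x
            else:
                yield batch_x, Y[start:start + batch_size]

    return iterations, generate()


# usage mirroring the loop in Tester.test()
X = np.random.randn(10, 4)
Y = np.random.randn(10)
iterations, batch_generator = batchify(X, Y, batch_size=3)
for step in range(iterations):
    batch_x, batch_y = next(batch_generator)
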