Browse Source

build classes for saver

tags/v0.1.0
FengZiYjun 6 years ago
parent
commit
4f71d44999
8 changed files with 112 additions and 11 deletions
  1. +12
    -0
      action/action.py
  2. +38
    -2
      action/tester.py
  3. +9
    -0
      model/empty.txt
  4. +20
    -9
      reproduction/Char-aware_NLM/train.py
  5. +14
    -0
      saver/base_saver.py
  6. +0
    -0
      saver/empty.txt
  7. +11
    -0
      saver/logger.py
  8. +8
    -0
      saver/model_saver.py

+ 12
- 0
action/action.py View File

@@ -5,6 +5,7 @@ class Action(object):

def __init__(self):
super(Action, self).__init__()
self.logger = None

def load_config(self, args):
pass
@@ -13,4 +14,15 @@ class Action(object):
pass

def log(self, args):
self.logger.log(args)

"""
Basic operations shared between Trainer and Tester.
"""

def batchify(self, X, Y=None):
# a generator
pass

def make_log(self, *args):
pass

+ 38
- 2
action/tester.py View File

@@ -1,9 +1,45 @@
import numpy as np

from action.action import Action


class Tester(Action):
"""docstring for Tester"""

def __init__(self, arg):
def __init__(self, test_args):
"""
:param test_args: named tuple
"""
super(Tester, self).__init__()
self.arg = arg
self.test_args = test_args
self.args_dict = {name: value for name, value in self.test_args.__dict__.iteritems()}
self.mean_loss = None

def test(self, network, data):
# transform into network input and label
X, Y = network.prepare_input(data)

# split into batches by self.batch_size
iterations, test_batch_generator = self.batchify(X, Y)

loss_history = list()
# turn on the testing mode of the network
network.mode(test=True)

for step in range(iterations):
batch_x, batch_y = test_batch_generator.__next__()

# forward pass from test input to predicted output
prediction = network.data_forward(batch_x)

# get the loss
loss = network.loss(batch_y, prediction)

loss_history.append(loss)
self.log(self.make_log(step, loss))

self.mean_loss = np.mean(np.array(loss_history))

@property
def loss(self):
return self.mean_loss

+ 9
- 0
model/empty.txt View File

@@ -0,0 +1,9 @@
Some useful reference:
SpaCy "Doc"
https://github.com/explosion/spaCy/blob/75d2a05c2938f412f0fae44748374e4de19cc2be/spacy/tokens/doc.pyx#L80

SpaCy "Vocab"
https://github.com/explosion/spaCy/blob/75d2a05c2938f412f0fae44748374e4de19cc2be/spacy/vocab.pyx#L25

SpaCy "Token"
https://github.com/explosion/spaCy/blob/75d2a05c2938f412f0fae44748374e4de19cc2be/spacy/tokens/token.pyx#L27

+ 20
- 9
reproduction/Char-aware_NLM/train.py View File

@@ -1,15 +1,15 @@
import os
from collections import namedtuple

import numpy as np
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import os
from model import charLM
from utilities import *
from collections import namedtuple
from test import test
from torch.autograd import Variable

from .model import charLM
from .test import test
from .utilities import *


def preprocess():
@@ -43,7 +43,18 @@ def to_var(x):


def train(net, data, opt):
"""
:param net: the pytorch model
:param data: numpy array
:param opt: named tuple
1. random seed
2. define local input
3. training settting: learning rate, loss, etc
4. main loop epoch
5. batchify
6. validation
7. save model
"""
torch.manual_seed(1024)

train_input = torch.from_numpy(data.train_input)


+ 14
- 0
saver/base_saver.py View File

@@ -0,0 +1,14 @@
class BaseSaver(object):
"""base class for all savers"""

def __init__(self, save_path):
self.save_path = save_path

def save_bytes(self):
pass

def save_str(self):
pass

def compress(self):
pass

+ 0
- 0
saver/empty.txt View File


+ 11
- 0
saver/logger.py View File

@@ -0,0 +1,11 @@
from saver.base_saver import BaseSaver


class Logger(BaseSaver):
"""Logging"""

def __init__(self, save_path):
super(Logger, self).__init__(save_path)

def log(self, string):
pass

+ 8
- 0
saver/model_saver.py View File

@@ -0,0 +1,8 @@
from saver.base_saver import BaseSaver


class ModelSaver(BaseSaver):
"""Save a model"""

def __init__(self, save_path):
super(ModelSaver, self).__init__(save_path)

Loading…
Cancel
Save