
update trainer: load data with _pickle; add argument comments.

tags/v0.1.0
FengZiYjun 6 years ago
commit 7ea015c0f9
1 changed file with 29 additions and 22 deletions

fastNLP/action/trainer.py  +29 -22

@@ -1,4 +1,4 @@
-import pickle
+import _pickle
 from collections import namedtuple

 import numpy as np
@@ -21,8 +21,7 @@ class BaseTrainer(Action):
         - grad_backward
         - get_loss
     """
-    TrainConfig = namedtuple("config", ["epochs", "validate", "save_when_better",
-                                        "log_per_step", "log_validation", "batch_size"])
+    TrainConfig = namedtuple("config", ["epochs", "validate", "batch_size", "pickle_path"])

     def __init__(self, train_args):
         """
@@ -32,6 +31,7 @@ class BaseTrainer(Action):
         self.n_epochs = train_args.epochs
         self.validate = train_args.validate
         self.batch_size = train_args.batch_size
+        self.pickle_path = train_args.pickle_path
         self.model = None
         self.iterator = None
         self.loss_func = None
@@ -39,8 +39,6 @@ class BaseTrainer(Action):
     def train(self, network):
         """General training loop.
         :param network: a model
-        :param train_data: raw data for training
-        :param dev_data: raw data for validation

         The method is framework independent.
         Work by calling the following methods:
@@ -54,7 +52,7 @@ class BaseTrainer(Action):
         Subclasses must implement these methods with a specific framework.
         """
         self.model = network
-        data_train, data_dev, data_test, embedding = self.prepare_input("./save/")
+        data_train, data_dev, data_test, embedding = self.prepare_input(self.pickle_path)

         test_args = Tester.TestConfig(save_output=True, validate_in_training=True,
                                       save_dev_input=True, save_loss=True, batch_size=self.batch_size)
@@ -89,10 +87,10 @@ class BaseTrainer(Action):
""" """
To do: Load pkl files of train/dev/test and embedding To do: Load pkl files of train/dev/test and embedding
""" """
data_train = pickle.load(open(data_path + "data_train.pkl", "rb"))
data_dev = pickle.load(open(data_path + "data_dev.pkl", "rb"))
data_test = pickle.load(open(data_path + "data_test.pkl", "rb"))
embedding = pickle.load(open(data_path + "embedding.pkl", "rb"))
data_train = _pickle.load(open(data_path + "data_train.pkl", "rb"))
data_dev = _pickle.load(open(data_path + "data_dev.pkl", "rb"))
data_test = _pickle.load(open(data_path + "data_test.pkl", "rb"))
embedding = _pickle.load(open(data_path + "embedding.pkl", "rb"))
return data_train, data_dev, data_test, embedding return data_train, data_dev, data_test, embedding


def mode(self, test=False): def mode(self, test=False):
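
A note on the swap above: `_pickle` is CPython's C implementation of the pickle protocol, and in Python 3 the stdlib `pickle` module already imports its fast paths from `_pickle` when available, so both load the same files. Below is a minimal standalone sketch of the loading step, not fastNLP code: `load_pickles` is a hypothetical helper mirroring `prepare_input()`, using the same four file names; the with-statements are an addition so file handles get closed, which the committed bare `open()` calls leave to the garbage collector.

    import _pickle

    def load_pickles(data_path):
        # Hypothetical helper mirroring prepare_input() above; not fastNLP code.
        names = ["data_train.pkl", "data_dev.pkl", "data_test.pkl", "embedding.pkl"]
        loaded = []
        for name in names:
            # with-statement closes each file; the committed open() calls rely on the GC
            with open(data_path + name, "rb") as f:
                loaded.append(_pickle.load(f))
        return tuple(loaded)  # (data_train, data_dev, data_test, embedding)
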
@@ -147,20 +145,30 @@ class BaseTrainer(Action):
         if hasattr(self.model, "loss"):
             self.loss_func = self.model.loss
         else:
-            self.loss_func = self.define_loss()
+            self.define_loss()
         return self.loss_func(predict, truth)

     def define_loss(self):
+        """
+        Assign an instance of a loss function to self.loss_func.
+        E.g. self.loss_func = nn.CrossEntropyLoss()
+        """
         raise NotImplementedError

     def batchify(self, batch_size, data):
         """
-        Perform batching from data and produce a batch of training data.
-        Add padding.
-        :param batch_size:
-        :param data:
-        :param pad:
-        :return: batch_x, batch_y
+        1. Perform batching from data and produce a batch of training data.
+        2. Add padding.
+        :param batch_size: int, the size of a batch
+        :param data: list. Each entry is a sample, which is also a list of features and label(s).
+            E.g.
+                [
+                    [[feature_1, feature_2, feature_3], [label_1, label_2]],  # sample 1
+                    [[feature_1, feature_2, feature_3], [label_1, label_2]],  # sample 2
+                    ...
+                ]
+        :return batch_x: list. Each entry is a list of features of a sample.
+                batch_y: list. Each entry is a list of labels of a sample.
+        """
         if self.iterator is None:
             self.iterator = iter(Batchifier(RandomSampler(data), batch_size, drop_last=True))
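
The expanded docstring pins down batchify's data contract. The sketch below is illustrative only, not fastNLP's `Batchifier`: `pad_batch` is a hypothetical function showing the padding step the docstring describes, with `pad` as an assumed fill value.

    def pad_batch(batch, pad=0):
        # Illustrative sketch of the "add padding" step only; not fastNLP's Batchifier.
        # Right-pad every feature list to the length of the longest one in the batch.
        max_len = max(len(features) for features, _ in batch)
        batch_x = [features + [pad] * (max_len - len(features)) for features, _ in batch]
        batch_y = [labels for _, labels in batch]
        return batch_x, batch_y

    # Entries follow the documented [[features], [labels]] layout.
    data = [[[1, 2, 3], [0]], [[4, 5], [1]]]
    batch_x, batch_y = pad_batch(data)
    # batch_x == [[1, 2, 3], [4, 5, 0]]; batch_y == [[0], [1]]
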
@@ -306,8 +314,7 @@ class WordSegTrainer(BaseTrainer):




if __name__ == "__name__": if __name__ == "__name__":
Config = namedtuple("config", ["epochs", "validate", "save_when_better", "log_per_step",
"log_validation", "batch_size"])
train_config = Config(epochs=5, validate=True, save_when_better=True, log_per_step=10, log_validation=True,
batch_size=32)
trainer = ToyTrainer(train_config)
train_args = BaseTrainer.TrainConfig(epochs=1, validate=False, batch_size=3, pickle_path="./")
trainer = BaseTrainer(train_args)
data_train = [[[1, 2, 3, 4], [0]] * 10] + [[[1, 3, 5, 2], [1]] * 10]
trainer.batchify(batch_size=3, data=data_train)
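
With `pickle_path` now part of `TrainConfig`, `train()` expects the four pickle files to already exist under that directory. A hedged sketch of producing them with placeholder data (the file names come from `prepare_input` above; the sample and embedding values are stand-ins, not anything the commit ships):

    import _pickle

    pickle_path = "./"
    samples = [[[1, 2, 3, 4], [0]]] * 10      # placeholder data in the documented layout
    dumps = [("data_train.pkl", samples),
             ("data_dev.pkl", samples),
             ("data_test.pkl", samples),
             ("embedding.pkl", [[0.0] * 4])]  # stand-in embedding table
    for name, obj in dumps:
        with open(pickle_path + name, "wb") as f:
            _pickle.dump(obj, f)
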
