|
@@ -1,9 +1,11 @@ |
|
|
|
|
|
import pickle |
|
|
from collections import namedtuple |
|
|
from collections import namedtuple |
|
|
|
|
|
|
|
|
import numpy as np |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch |
|
|
|
|
|
|
|
|
from fastNLP.action.action import Action |
|
|
from fastNLP.action.action import Action |
|
|
|
|
|
from fastNLP.action.action import RandomSampler, Batchifier |
|
|
from fastNLP.action.tester import Tester |
|
|
from fastNLP.action.tester import Tester |
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -31,8 +33,10 @@ class BaseTrainer(Action): |
|
|
self.validate = train_args.validate |
|
|
self.validate = train_args.validate |
|
|
self.batch_size = train_args.batch_size |
|
|
self.batch_size = train_args.batch_size |
|
|
self.model = None |
|
|
self.model = None |
|
|
|
|
|
self.iterator = None |
|
|
|
|
|
self.loss_func = None |
|
|
|
|
|
|
|
|
def train(self, network, train_data, dev_data=None): |
|
|
|
|
|
|
|
|
def train(self, network): |
|
|
"""General training loop. |
|
|
"""General training loop. |
|
|
:param network: a model |
|
|
:param network: a model |
|
|
:param train_data: raw data for training |
|
|
:param train_data: raw data for training |
|
@@ -50,22 +54,21 @@ class BaseTrainer(Action): |
|
|
Subclasses must implement these methods with a specific framework. |
|
|
Subclasses must implement these methods with a specific framework. |
|
|
""" |
|
|
""" |
|
|
self.model = network |
|
|
self.model = network |
|
|
train_x, train_y = self.prepare_input(train_data) |
|
|
|
|
|
|
|
|
|
|
|
iterations, train_batch_generator = self.batchify(self.batch_size, train_x, train_y) |
|
|
|
|
|
|
|
|
data_train, data_dev, data_test, embedding = self.prepare_input("./save/") |
|
|
|
|
|
|
|
|
test_args = Tester.TestConfig(save_output=True, validate_in_training=True, |
|
|
test_args = Tester.TestConfig(save_output=True, validate_in_training=True, |
|
|
save_dev_input=True, save_loss=True, batch_size=self.batch_size) |
|
|
save_dev_input=True, save_loss=True, batch_size=self.batch_size) |
|
|
evaluator = Tester(test_args) |
|
|
evaluator = Tester(test_args) |
|
|
|
|
|
|
|
|
best_loss = 1e10 |
|
|
best_loss = 1e10 |
|
|
|
|
|
iterations = len(data_train) // self.batch_size |
|
|
|
|
|
|
|
|
for epoch in range(self.n_epochs): |
|
|
for epoch in range(self.n_epochs): |
|
|
self.mode(test=False) # turn on the train mode |
|
|
|
|
|
|
|
|
self.mode(test=False) |
|
|
|
|
|
|
|
|
self.define_optimizer() |
|
|
self.define_optimizer() |
|
|
for step in range(iterations): |
|
|
for step in range(iterations): |
|
|
batch_x, batch_y = train_batch_generator.__next__() |
|
|
|
|
|
|
|
|
batch_x, batch_y = self.batchify(self.batch_size, data_train) |
|
|
|
|
|
|
|
|
prediction = self.data_forward(network, batch_x) |
|
|
prediction = self.data_forward(network, batch_x) |
|
|
|
|
|
|
|
@@ -74,21 +77,23 @@ class BaseTrainer(Action): |
|
|
self.update() |
|
|
self.update() |
|
|
|
|
|
|
|
|
if self.validate: |
|
|
if self.validate: |
|
|
if dev_data is None: |
|
|
|
|
|
|
|
|
if data_dev is None: |
|
|
raise RuntimeError("No validation data provided.") |
|
|
raise RuntimeError("No validation data provided.") |
|
|
evaluator.test(network, dev_data) |
|
|
|
|
|
|
|
|
evaluator.test(network, data_dev) |
|
|
if evaluator.loss < best_loss: |
|
|
if evaluator.loss < best_loss: |
|
|
best_loss = evaluator.loss |
|
|
best_loss = evaluator.loss |
|
|
|
|
|
|
|
|
# finish training |
|
|
# finish training |
|
|
|
|
|
|
|
|
def prepare_input(self, data_path):
    """Load the pickled train/dev/test splits and the embedding.

    Four files are read, each located by appending a fixed file name to
    ``data_path``: ``data_train.pkl``, ``data_dev.pkl``, ``data_test.pkl``
    and ``embedding.pkl``.

    :param data_path: str, prefix prepended verbatim to every file name.
        It is concatenated, not path-joined, so a directory must end with
        a separator (e.g. ``"./save/"``).
    :return: tuple ``(data_train, data_dev, data_test, embedding)``
    :raises OSError: if one of the files cannot be opened
    """

    def _load(file_name):
        # The context manager closes the handle even when unpickling fails.
        with open(data_path + file_name, "rb") as handle:
            # SECURITY: pickle.load can execute arbitrary code; only point
            # data_path at files produced by a trusted preprocessing step.
            return pickle.load(handle)

    data_train = _load("data_train.pkl")
    data_dev = _load("data_dev.pkl")
    data_test = _load("data_test.pkl")
    embedding = _load("embedding.pkl")
    return data_train, data_dev, data_test, embedding
|
|
|
|
|
|
|
|
def mode(self, test=False): |
|
|
def mode(self, test=False): |
|
|
""" |
|
|
""" |
|
@@ -138,8 +143,48 @@ class BaseTrainer(Action): |
|
|
:param truth: ground truth label vector |
|
|
:param truth: ground truth label vector |
|
|
:return: a scalar |
|
|
:return: a scalar |
|
|
""" |
|
|
""" |
|
|
|
|
|
if self.loss_func is None: |
|
|
|
|
|
if hasattr(self.model, "loss"): |
|
|
|
|
|
self.loss_func = self.model.loss |
|
|
|
|
|
else: |
|
|
|
|
|
self.loss_func = self.define_loss() |
|
|
|
|
|
return self.loss_func(predict, truth) |
|
|
|
|
|
|
|
|
|
|
|
def define_loss(self):
    """Return the loss callable to use when the model provides none.

    The base class has no default loss, so this always raises; a subclass
    is expected to override it and return a callable taking
    ``(predict, truth)``.

    :raises NotImplementedError: always, in the base class
    """
    raise NotImplementedError(
        "define_loss() must be overridden by a subclass when the model "
        "has no 'loss' attribute."
    )
|
|
|
|
|
|
|
|
|
|
|
def batchify(self, batch_size, data):
    """Produce the next padded batch of training data.

    A batch-index iterator is created lazily on the first call and cached
    on ``self.iterator``. When the cached iterator is exhausted it is
    rebuilt, so the method can be called across several epochs.

    :param batch_size: int, number of samples per batch
    :param data: indexable collection of samples; ``sample[0]`` is taken
        as the input and ``sample[1]`` as the label
    :return: batch_x, batch_y -- the inputs padded via ``self.pad`` and
        the list of labels
    :raises StopIteration: if a freshly built iterator yields no batch
    """

    def _new_iterator():
        # NOTE(review): assumes Batchifier is a finite iterable that yields
        # lists of indices into ``data`` -- confirm against its definition.
        return iter(Batchifier(RandomSampler(data), batch_size, drop_last=True))

    if self.iterator is None:
        self.iterator = _new_iterator()
    try:
        indices = next(self.iterator)
    except StopIteration:
        # The previous pass over the data is finished: start a new one
        # instead of leaking StopIteration into the training loop.
        self.iterator = _new_iterator()
        indices = next(self.iterator)

    batch = [data[idx] for idx in indices]
    batch_x = self.pad([sample[0] for sample in batch])
    batch_y = [sample[1] for sample in batch]
    return batch_x, batch_y
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
def pad(batch, fill=0):
    """Pad every sample of a batch to the length of the longest sample.

    Short samples are replaced in ``batch`` by new, right-padded lists;
    samples that already have the maximum length are left untouched.

    :param batch: list of list
    :param fill: word index to pad, default 0.
    :return: a padded batch (the same list object that was passed in)
    """
    if not batch:
        # max() of an empty sequence raises; nothing to pad anyway.
        return batch
    max_length = max(len(sample) for sample in batch)
    for idx, sample in enumerate(batch):
        missing = max_length - len(sample)
        if missing > 0:
            # Append `missing` copies of `fill`, not one element holding
            # the product fill * missing.
            batch[idx] = list(sample) + [fill] * missing
    return batch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ToyTrainer(BaseTrainer): |
|
|
class ToyTrainer(BaseTrainer): |
|
|
"""A simple trainer for a PyTorch model.""" |
|
|
"""A simple trainer for a PyTorch model.""" |
|
|