|
@@ -1,5 +1,4 @@ |
|
|
import _pickle |
|
|
import _pickle |
|
|
from collections import namedtuple |
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch |
|
@@ -22,18 +21,22 @@ class BaseTrainer(Action): |
|
|
- grad_backward |
|
|
- grad_backward |
|
|
- get_loss |
|
|
- get_loss |
|
|
""" |
|
|
""" |
|
|
TrainConfig = namedtuple("config", ["epochs", "validate", "batch_size", "pickle_path"]) |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, train_args): |
|
|
def __init__(self, train_args): |
|
|
""" |
|
|
""" |
|
|
training parameters |
|
|
|
|
|
|
|
|
:param train_args: dict of (key, value) |
|
|
|
|
|
|
|
|
|
|
|
The base trainer requires the following keys: |
|
|
|
|
|
- epochs: int, the number of epochs in training |
|
|
|
|
|
- validate: bool, whether or not to validate on dev set |
|
|
|
|
|
- batch_size: int |
|
|
|
|
|
- pickle_path: str, the path to pickle files for pre-processing |
|
|
""" |
|
|
""" |
|
|
super(BaseTrainer, self).__init__() |
|
|
super(BaseTrainer, self).__init__() |
|
|
self.train_args = train_args |
|
|
|
|
|
self.n_epochs = train_args.epochs |
|
|
|
|
|
# self.validate = train_args.validate |
|
|
|
|
|
self.batch_size = train_args.batch_size |
|
|
|
|
|
self.pickle_path = train_args.pickle_path |
|
|
|
|
|
|
|
|
self.n_epochs = train_args["epochs"] |
|
|
|
|
|
self.validate = train_args["validate"] |
|
|
|
|
|
self.batch_size = train_args["batch_size"] |
|
|
|
|
|
self.pickle_path = train_args["pickle_path"] |
|
|
self.model = None |
|
|
self.model = None |
|
|
self.iterator = None |
|
|
self.iterator = None |
|
|
self.loss_func = None |
|
|
self.loss_func = None |
|
@@ -66,8 +69,9 @@ class BaseTrainer(Action): |
|
|
|
|
|
|
|
|
for epoch in range(self.n_epochs): |
|
|
for epoch in range(self.n_epochs): |
|
|
self.mode(test=False) |
|
|
self.mode(test=False) |
|
|
|
|
|
|
|
|
self.define_optimizer() |
|
|
self.define_optimizer() |
|
|
|
|
|
self.iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=True)) |
|
|
|
|
|
|
|
|
for step in range(iterations): |
|
|
for step in range(iterations): |
|
|
batch_x, batch_y = self.batchify(self.batch_size, data_train) |
|
|
batch_x, batch_y = self.batchify(self.batch_size, data_train) |
|
|
|
|
|
|
|
@@ -173,8 +177,6 @@ class BaseTrainer(Action): |
|
|
:return batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len] |
|
|
:return batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len] |
|
|
batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels] |
|
|
batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels] |
|
|
""" |
|
|
""" |
|
|
if self.iterator is None: |
|
|
|
|
|
self.iterator = iter(Batchifier(RandomSampler(data), batch_size, drop_last=True)) |
|
|
|
|
|
indices = next(self.iterator) |
|
|
indices = next(self.iterator) |
|
|
batch = [data[idx] for idx in indices] |
|
|
batch = [data[idx] for idx in indices] |
|
|
batch_x = [sample[0] for sample in batch] |
|
|
batch_x = [sample[0] for sample in batch] |
|
@@ -304,6 +306,7 @@ class WordSegTrainer(BaseTrainer): |
|
|
self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.85) |
|
|
self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.85) |
|
|
|
|
|
|
|
|
def get_loss(self, predict, truth): |
|
|
def get_loss(self, predict, truth): |
|
|
|
|
|
truth = torch.Tensor(truth) |
|
|
self._loss = torch.nn.CrossEntropyLoss(predict, truth) |
|
|
self._loss = torch.nn.CrossEntropyLoss(predict, truth) |
|
|
return self._loss |
|
|
return self._loss |
|
|
|
|
|
|
|
@@ -316,13 +319,16 @@ class WordSegTrainer(BaseTrainer): |
|
|
self.optimizer.step() |
|
|
self.optimizer.step() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class POSTrainer(BaseTrainer): |
|
|
class POSTrainer(BaseTrainer): |
|
|
TrainConfig = namedtuple("config", ["epochs", "batch_size", "pickle_path", "num_classes", "vocab_size"]) |
|
|
|
|
|
|
|
|
""" |
|
|
|
|
|
Trainer for Sequence Modeling |
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
def __init__(self, train_args): |
|
|
def __init__(self, train_args): |
|
|
super(POSTrainer, self).__init__(train_args) |
|
|
super(POSTrainer, self).__init__(train_args) |
|
|
self.vocab_size = train_args.vocab_size |
|
|
|
|
|
self.num_classes = train_args.num_classes |
|
|
|
|
|
|
|
|
self.vocab_size = train_args["vocab_size"] |
|
|
|
|
|
self.num_classes = train_args["num_classes"] |
|
|
self.max_len = None |
|
|
self.max_len = None |
|
|
self.mask = None |
|
|
self.mask = None |
|
|
|
|
|
|
|
@@ -357,6 +363,13 @@ class POSTrainer(BaseTrainer): |
|
|
def define_optimizer(self): |
|
|
def define_optimizer(self): |
|
|
self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9) |
|
|
self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9) |
|
|
|
|
|
|
|
|
|
|
|
def grad_backward(self, loss): |
|
|
|
|
|
self.model.zero_grad() |
|
|
|
|
|
loss.backward() |
|
|
|
|
|
|
|
|
|
|
|
def update(self): |
|
|
|
|
|
self.optimizer.step() |
|
|
|
|
|
|
|
|
def get_loss(self, predict, truth): |
|
|
def get_loss(self, predict, truth): |
|
|
""" |
|
|
""" |
|
|
Compute loss given prediction and ground truth. |
|
|
Compute loss given prediction and ground truth. |
|
@@ -364,16 +377,18 @@ class POSTrainer(BaseTrainer): |
|
|
:param truth: ground truth label vector, [batch_size, max_len] |
|
|
:param truth: ground truth label vector, [batch_size, max_len] |
|
|
:return: a scalar |
|
|
:return: a scalar |
|
|
""" |
|
|
""" |
|
|
|
|
|
truth = torch.Tensor(truth) |
|
|
if self.loss_func is None: |
|
|
if self.loss_func is None: |
|
|
if hasattr(self.model, "loss"): |
|
|
if hasattr(self.model, "loss"): |
|
|
self.loss_func = self.model.loss |
|
|
self.loss_func = self.model.loss |
|
|
else: |
|
|
else: |
|
|
self.define_loss() |
|
|
self.define_loss() |
|
|
return self.loss_func(predict, truth, self.mask, self.batch_size, self.max_len) |
|
|
|
|
|
|
|
|
loss, prediction = self.loss_func(predict, truth, self.mask, self.batch_size, self.max_len) |
|
|
|
|
|
return loss |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__name__": |
|
|
if __name__ == "__name__": |
|
|
train_args = BaseTrainer.TrainConfig(epochs=1, validate=False, batch_size=3, pickle_path="./") |
|
|
|
|
|
|
|
|
train_args = {"epochs": 1, "validate": False, "batch_size": 3, "pickle_path": "./"} |
|
|
trainer = BaseTrainer(train_args) |
|
|
trainer = BaseTrainer(train_args) |
|
|
data_train = [[[1, 2, 3, 4], [0]] * 10] + [[[1, 3, 5, 2], [1]] * 10] |
|
|
data_train = [[[1, 2, 3, 4], [0]] * 10] + [[[1, 3, 5, 2], [1]] * 10] |
|
|
trainer.batchify(batch_size=3, data=data_train) |
|
|
trainer.batchify(batch_size=3, data=data_train) |