|
|
@@ -31,12 +31,13 @@ class BaseTrainer(Action): |
|
|
|
super(BaseTrainer, self).__init__() |
|
|
|
self.train_args = train_args |
|
|
|
self.n_epochs = train_args.epochs |
|
|
|
self.validate = train_args.validate |
|
|
|
# self.validate = train_args.validate |
|
|
|
self.batch_size = train_args.batch_size |
|
|
|
self.pickle_path = train_args.pickle_path |
|
|
|
self.model = None |
|
|
|
self.iterator = None |
|
|
|
self.loss_func = None |
|
|
|
self.optimizer = None |
|
|
|
|
|
|
|
def train(self, network): |
|
|
|
"""General training loop. |
|
|
@@ -316,6 +317,8 @@ class WordSegTrainer(BaseTrainer): |
|
|
|
|
|
|
|
|
|
|
|
class POSTrainer(BaseTrainer): |
|
|
|
TrainConfig = namedtuple("config", ["epochs", "batch_size", "pickle_path", "num_classes", "vocab_size"]) |
|
|
|
|
|
|
|
def __init__(self, train_args): |
|
|
|
super(POSTrainer, self).__init__(train_args) |
|
|
|
self.vocab_size = train_args.vocab_size |
|
|
@@ -328,9 +331,9 @@ class POSTrainer(BaseTrainer): |
|
|
|
""" |
|
|
|
To do: Load pkl files of train/dev/test and embedding |
|
|
|
""" |
|
|
|
data_train = _pickle.load(open(data_path + "data_train.pkl", "rb")) |
|
|
|
data_dev = _pickle.load(open(data_path + "data_dev.pkl", "rb")) |
|
|
|
return data_train, data_dev |
|
|
|
data_train = _pickle.load(open(data_path + "/data_train.pkl", "rb")) |
|
|
|
data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb")) |
|
|
|
return data_train, data_dev, 0, 1 |
|
|
|
|
|
|
|
def data_forward(self, network, x): |
|
|
|
seq_len = [len(seq) for seq in x] |
|
|
@@ -342,10 +345,28 @@ class POSTrainer(BaseTrainer): |
|
|
|
self.batch_x = x |
|
|
|
return x |
|
|
|
|
|
|
|
def mode(self, test=False):
    """
    Put ``self.model`` into evaluation or training mode.

    :param test: when truthy, ``self.model.eval()`` is called;
        otherwise ``self.model.train()`` is called.
    :return: None
    """
    if test:
        self.model.eval()
    else:
        self.model.train()
|
|
|
|
|
|
|
def define_optimizer(self, lr=0.01, momentum=0.9):
    """
    Build an SGD optimizer over ``self.model.parameters()`` and store it
    on ``self.optimizer``.

    :param lr: learning rate handed to ``torch.optim.SGD``; defaults to
        0.01, the value that used to be hard-coded here.
    :param momentum: momentum factor handed to ``torch.optim.SGD``;
        defaults to 0.9, the value that used to be hard-coded here.
    :return: None
    """
    self.optimizer = torch.optim.SGD(
        self.model.parameters(), lr=lr, momentum=momentum
    )
|
|
|
|
|
|
|
def get_loss(self, predict, truth):
    """
    Compute loss given prediction and ground truth.

    If ``self.loss_func`` has not been set yet it is resolved first:
    ``self.model.loss`` is used when the model has such an attribute,
    otherwise ``self.define_loss()`` is called to set it up.

    :param predict: prediction label vector
    :param truth: ground truth label vector. Kept for interface
        compatibility; it is not forwarded to the loss callable, which
        is invoked with ``self.batch_x``, ``predict``, ``self.mask``,
        ``self.batch_size`` and ``self.max_len``.
    :return: the loss produced by the loss callable; when that callable
        returns a ``(loss, prediction)`` pair, only ``loss`` is returned.
    """
    if self.loss_func is None:
        if hasattr(self.model, "loss"):
            self.loss_func = self.model.loss
        else:
            self.define_loss()

    result = self.loss_func(
        self.batch_x, predict, self.mask, self.batch_size, self.max_len
    )
    # Both call shapes appear for this method: one unpacks
    # ``loss, prediction`` and one returns the callable's result directly.
    # Accept either so the caller always receives just the loss.
    if isinstance(result, (tuple, list)):
        return result[0]
    return result
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__name__": |
|
|
|