From c98d5924b585a7bfdc127e017d8cc2ff444d7e25 Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Tue, 10 Jul 2018 20:46:35 +0800 Subject: [PATCH] sequence labeling ready to Train! --- fastNLP/action/trainer.py | 47 ++++++++++++------- ...encce_modeling.py => sequence_modeling.py} | 2 +- requirements.txt | 4 +- test/test_POS_pipeline.py | 9 ++-- 4 files changed, 39 insertions(+), 23 deletions(-) rename fastNLP/models/{sequencce_modeling.py => sequence_modeling.py} (98%) diff --git a/fastNLP/action/trainer.py b/fastNLP/action/trainer.py index 1f22ef28..6f51435a 100644 --- a/fastNLP/action/trainer.py +++ b/fastNLP/action/trainer.py @@ -1,5 +1,4 @@ import _pickle -from collections import namedtuple import numpy as np import torch @@ -22,18 +21,22 @@ class BaseTrainer(Action): - grad_backward - get_loss """ - TrainConfig = namedtuple("config", ["epochs", "validate", "batch_size", "pickle_path"]) def __init__(self, train_args): """ - training parameters + :param train_args: dict of (key, value) + + The base trainer requires the following keys: + - epochs: int, the number of epochs in training + - validate: bool, whether or not to validate on dev set + - batch_size: int + - pickle_path: str, the path to pickle files for pre-processing """ super(BaseTrainer, self).__init__() - self.train_args = train_args - self.n_epochs = train_args.epochs - # self.validate = train_args.validate - self.batch_size = train_args.batch_size - self.pickle_path = train_args.pickle_path + self.n_epochs = train_args["epochs"] + self.validate = train_args["validate"] + self.batch_size = train_args["batch_size"] + self.pickle_path = train_args["pickle_path"] self.model = None self.iterator = None self.loss_func = None @@ -66,8 +69,9 @@ class BaseTrainer(Action): for epoch in range(self.n_epochs): self.mode(test=False) - self.define_optimizer() + self.iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=True)) + for step in range(iterations): batch_x, batch_y = self.batchify(self.batch_size, data_train) @@ -173,8 +177,6 @@ class BaseTrainer(Action): :return batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len] batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels] """ - if self.iterator is None: - self.iterator = iter(Batchifier(RandomSampler(data), batch_size, drop_last=True)) indices = next(self.iterator) batch = [data[idx] for idx in indices] batch_x = [sample[0] for sample in batch] @@ -304,6 +306,7 @@ class WordSegTrainer(BaseTrainer): self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.85) def get_loss(self, predict, truth): + truth = torch.Tensor(truth) self._loss = torch.nn.CrossEntropyLoss(predict, truth) return self._loss @@ -316,13 +319,16 @@ class WordSegTrainer(BaseTrainer): self.optimizer.step() + class POSTrainer(BaseTrainer): - TrainConfig = namedtuple("config", ["epochs", "batch_size", "pickle_path", "num_classes", "vocab_size"]) + """ + Trainer for Sequence Modeling + """ def __init__(self, train_args): super(POSTrainer, self).__init__(train_args) - self.vocab_size = train_args.vocab_size - self.num_classes = train_args.num_classes + self.vocab_size = train_args["vocab_size"] + self.num_classes = train_args["num_classes"] self.max_len = None self.mask = None @@ -357,6 +363,13 @@ class POSTrainer(BaseTrainer): def define_optimizer(self): self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9) + def grad_backward(self, loss): + self.model.zero_grad() + loss.backward() + + def update(self): + self.optimizer.step() + def get_loss(self, predict, truth): """ Compute loss given prediction and ground truth. @@ -364,16 +377,18 @@ class POSTrainer(BaseTrainer): :param truth: ground truth label vector, [batch_size, max_len] :return: a scalar """ + truth = torch.Tensor(truth) if self.loss_func is None: if hasattr(self.model, "loss"): self.loss_func = self.model.loss else: self.define_loss() - return self.loss_func(predict, truth, self.mask, self.batch_size, self.max_len) + loss, prediction = self.loss_func(predict, truth, self.mask, self.batch_size, self.max_len) + return loss if __name__ == "__name__": - train_args = BaseTrainer.TrainConfig(epochs=1, validate=False, batch_size=3, pickle_path="./") + train_args = {"epochs": 1, "validate": False, "batch_size": 3, "pickle_path": "./"} trainer = BaseTrainer(train_args) data_train = [[[1, 2, 3, 4], [0]] * 10] + [[[1, 3, 5, 2], [1]] * 10] trainer.batchify(batch_size=3, data=data_train) diff --git a/fastNLP/models/sequencce_modeling.py b/fastNLP/models/sequence_modeling.py similarity index 98% rename from fastNLP/models/sequencce_modeling.py rename to fastNLP/models/sequence_modeling.py index 96f09f80..80d13cf3 100644 --- a/fastNLP/models/sequencce_modeling.py +++ b/fastNLP/models/sequence_modeling.py @@ -81,7 +81,7 @@ class SeqLabeling(BaseModel): x = x.float() y = y.long() mask = mask.byte() - print(x.shape, y.shape, mask.shape) + # print(x.shape, y.shape, mask.shape) if self.use_crf: total_loss = self.crf(x, y, mask) diff --git a/requirements.txt b/requirements.txt index 0fc94538..d961dd92 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -numpy==1.14.2 +numpy>=1.14.2 torch==0.4.0 -torchvision==0.1.8 +torchvision>=0.1.8 diff --git a/test/test_POS_pipeline.py b/test/test_POS_pipeline.py index 66e418c6..c6e3fd83 100644 --- a/test/test_POS_pipeline.py +++ b/test/test_POS_pipeline.py @@ -5,7 +5,7 @@ sys.path.append("..") from fastNLP.action.trainer import POSTrainer from fastNLP.loader.dataset_loader import POSDatasetLoader from fastNLP.loader.preprocess import POSPreprocess -from fastNLP.models.sequencce_modeling import SeqLabeling +from fastNLP.models.sequence_modeling import SeqLabeling data_name = "people.txt" data_path = "data_for_tests/people.txt" @@ -22,13 +22,14 @@ if __name__ == "__main__": num_classes = p.num_classes # Trainer - train_args = POSTrainer.TrainConfig(epochs=20, batch_size=1, num_classes=num_classes, - vocab_size=vocab_size, pickle_path=pickle_path) + train_args = {"epochs": 20, "batch_size": 1, "num_classes": num_classes, + "vocab_size": vocab_size, "pickle_path": pickle_path, "validate": False} trainer = POSTrainer(train_args) # Model model = SeqLabeling(100, 1, num_classes, vocab_size, bi_direction=True) - # Start training. + # Start training trainer.train(model) + print("Training finished!")