diff --git a/fastNLP/core/action.py b/fastNLP/core/action.py index d35a787c..d0dce5a6 100644 --- a/fastNLP/core/action.py +++ b/fastNLP/core/action.py @@ -4,20 +4,16 @@ """ from collections import Counter -import torch import numpy as np -import _pickle class Action(object): """ Operations shared by Trainer, Tester, and Inference. This is designed for reducing replicate codes. - - prepare_input: data preparation before a forward pass. - make_batch: produce a min-batch of data. @staticmethod - pad: padding method used in sequence modeling. @staticmethod - mode: change network mode for either train or test. (for PyTorch) @staticmethod - - data_forward: a forward pass of the network. The base Action shall define operations shared by as much task-specific Actions as possible. """ @@ -83,47 +79,6 @@ class Action(object): else: model.train() - def data_forward(self, network, x): - """ - Forward pass of the data. - :param network: a model - :param x: input feature matrix and label vector - :return: output by the models - - For PyTorch, just do "network(*x)" - """ - raise NotImplementedError - - -class SeqLabelAction(Action): - def __init__(self, action_args): - """ - Define task-specific member variables. - :param action_args: - """ - super(SeqLabelAction, self).__init__() - self.max_len = None - self.mask = None - self.best_accuracy = 0.0 - self.use_cuda = action_args["use_cuda"] - self.seq_len = None - self.batch_size = None - - def data_forward(self, network, inputs): - # unpack the returned value from make_batch - if isinstance(inputs, tuple): - x = inputs[0] - self.seq_len = inputs[1] - else: - x = inputs - x = torch.Tensor(x).long() - if torch.cuda.is_available() and self.use_cuda: - x = x.cuda() - self.batch_size = x.size(0) - self.max_len = x.size(1) - y = network(x) - return y - def k_means_1d(x, k, max_iter=100): """ diff --git a/fastNLP/core/inference.py b/fastNLP/core/inference.py index 7545a826..1bbcaf3a 100644 --- a/fastNLP/core/inference.py +++ b/fastNLP/core/inference.py @@ -1,7 +1,9 @@ +import numpy as np import torch from fastNLP.core.action import Batchifier, SequentialSampler from fastNLP.loader.preprocess import load_pickle, DEFAULT_UNKNOWN_LABEL +from fastNLP.modules import utils class Inference(object): @@ -32,13 +34,14 @@ class Inference(object): # turn on the testing mode; clean up the history self.mode(network, test=True) + self.batch_output.clear() - self.iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False)) + iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False)) num_iter = len(data) // self.batch_size for step in range(num_iter): - batch_x = self.make_batch(data) + batch_x = self.make_batch(iterator, data) prediction = self.data_forward(network, batch_x) @@ -54,26 +57,18 @@ class Inference(object): self.batch_output.clear() def data_forward(self, network, x): - """ - This is only for sequence labeling with CRF decoder. TODO: more general ? - :param network: - :param x: - :return: - """ - seq_len = [len(seq) for seq in x] - x = torch.Tensor(x).long() - y = network(x) - prediction = network.prediction(y, seq_len) - # To do: hide framework - results = torch.Tensor(prediction).view(-1, ) - return list(results.data) + raise NotImplementedError - def make_batch(self, data): - indices = next(self.iterator) + @staticmethod + def make_batch(iterator, data, output_length=True): + indices = next(iterator) batch_x = [data[idx] for idx in indices] - if self.batch_size > 1: - batch_x = self.pad(batch_x) - return batch_x + batch_x_pad = Inference.pad(batch_x) + if output_length: + seq_len = [len(x) for x in batch_x] + return [batch_x_pad, seq_len] + else: + return batch_x_pad @staticmethod def pad(batch, fill=0): @@ -86,7 +81,7 @@ class Inference(object): max_length = max([len(x) for x in batch]) for idx, sample in enumerate(batch): if len(sample) < max_length: - batch[idx] = sample + [fill * (max_length - len(sample))] + batch[idx] = sample + ([fill] * (max_length - len(sample))) return batch def prepare_input(self, data): @@ -109,10 +104,39 @@ class Inference(object): def prepare_output(self, batch_outputs): """ Transform list of batch outputs into strings. - :param batch_outputs: list of list, of shape [num_batch, tag_seq_length]. Element type is Tensor. + :param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, tag_seq_length]. :return: """ results = [] for batch in batch_outputs: - results.append([self.index2label[int(x.data)] for x in batch]) + for example in np.array(batch): + results.append([self.index2label[int(x)] for x in example]) return results + + +class SeqLabelInfer(Inference): + """ + Inference on sequence labeling models. + """ + + def __init__(self, pickle_path): + super(SeqLabelInfer, self).__init__(pickle_path) + + def data_forward(self, network, inputs): + """ + This is only for sequence labeling with CRF decoder. + :param network: + :param inputs: + :return: Tensor + """ + if not isinstance(inputs[1], list) and isinstance(inputs[0], list): + raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.") + # unpack the returned value from make_batch + x, seq_len = inputs[0], inputs[1] + x = torch.Tensor(x).long() + batch_size, max_len = x.size(0), x.size(1) + mask = utils.seq_mask(seq_len, max_len) + mask = mask.byte().view(batch_size, max_len) + y = network(x) + prediction = network.prediction(y, mask) + return torch.Tensor(prediction) diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index 8ee2ded6..27225cdb 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -6,17 +6,18 @@ import torch from fastNLP.core.action import Action from fastNLP.core.action import RandomSampler, Batchifier +from fastNLP.modules import utils class BaseTester(Action): """docstring for Tester""" - def __init__(self, test_args, action): + def __init__(self, test_args, action=None): """ :param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]" """ super(BaseTester, self).__init__() - self.action = action + self.action = action if action is not None else Action() self.validate_in_training = test_args["validate_in_training"] self.save_dev_data = None self.save_output = test_args["save_output"] @@ -52,7 +53,7 @@ class BaseTester(Action): for step in range(num_iter): batch_x, batch_y = self.action.make_batch(iterator, dev_data) - prediction = self.action.data_forward(network, batch_x) + prediction = self.data_forward(network, batch_x) eval_results = self.evaluate(prediction, batch_y) @@ -72,6 +73,9 @@ class BaseTester(Action): self.save_dev_data = data_dev return self.save_dev_data + def data_forward(self, network, x): + raise NotImplementedError + def evaluate(self, predict, truth): raise NotImplementedError @@ -92,7 +96,7 @@ class POSTester(BaseTester): Tester for sequence labeling. """ - def __init__(self, test_args, action): + def __init__(self, test_args, action=None): """ :param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]" """ @@ -101,17 +105,37 @@ class POSTester(BaseTester): self.mask = None self.batch_result = None + def data_forward(self, network, inputs): + if not isinstance(inputs, tuple): + raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.") + # unpack the returned value from make_batch + x, seq_len = inputs[0], inputs[1] + x = torch.Tensor(x).long() + batch_size, max_len = x.size(0), x.size(1) + mask = utils.seq_mask(seq_len, max_len) + mask = mask.byte().view(batch_size, max_len) + + if torch.cuda.is_available() and self.use_cuda: + x = x.cuda() + mask = mask.cuda() + self.mask = mask + + y = network(x) + return y + def evaluate(self, predict, truth): truth = torch.Tensor(truth) if torch.cuda.is_available() and self.use_cuda: truth = truth.cuda() - loss = self.model.loss(predict, truth, self.action.seq_len) / self.batch_size - prediction = self.model.prediction(predict, self.action.seq_len) + batch_size, max_len = predict.size(0), predict.size(1) + loss = self.model.loss(predict, truth, self.mask) / batch_size + + prediction = self.model.prediction(predict, self.mask) results = torch.Tensor(prediction).view(-1,) - if torch.cuda.is_available() and self.use_cuda: - results = results.cuda() - accuracy = float(torch.sum(results == truth.view((-1,)))) / results.shape[0] - return [loss.data, accuracy] + # make sure "results" is in the same device as "truth" + results = results.to(truth) + accuracy = torch.sum(results == truth.view((-1,))) / results.shape[0] + return [loss.data, accuracy.data] def metrics(self): batch_loss = np.mean([x[0] for x in self.eval_history]) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index df848d7d..d941536b 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -8,8 +8,9 @@ import torch import torch.nn as nn from fastNLP.core.action import Action -from fastNLP.core.action import RandomSampler, Batchifier, BucketSampler +from fastNLP.core.action import RandomSampler, Batchifier from fastNLP.core.tester import POSTester +from fastNLP.modules import utils from fastNLP.saver.model_saver import ModelSaver @@ -23,10 +24,10 @@ class BaseTrainer(Action): - get_loss """ - def __init__(self, train_args, action): + def __init__(self, train_args, action=None): """ :param train_args: dict of (key, value), or dict-like object. key is str. - :param action: an Action object that wrap most operations shared by Trainer, Tester, and Inference. + :param action: (optional) an Action object that wrap most operations shared by Trainer, Tester, and Inference. The base trainer requires the following keys: - epochs: int, the number of epochs in training @@ -35,7 +36,7 @@ class BaseTrainer(Action): - pickle_path: str, the path to pickle files for pre-processing """ super(BaseTrainer, self).__init__() - self.action = action + self.action = action if action is not None else Action() self.n_epochs = train_args["epochs"] self.batch_size = train_args["batch_size"] self.pickle_path = train_args["pickle_path"] @@ -94,7 +95,7 @@ class BaseTrainer(Action): for step in range(iterations): batch_x, batch_y = self.action.make_batch(iterator, data_train) - prediction = self.action.data_forward(network, batch_x) + prediction = self.data_forward(network, batch_x) loss = self.get_loss(prediction, batch_y) self.grad_backward(loss) @@ -137,6 +138,9 @@ class BaseTrainer(Action): """ raise NotImplementedError + def data_forward(self, network, x): + raise NotImplementedError + def grad_backward(self, loss): """ Compute gradient with link rules. @@ -223,7 +227,8 @@ class POSTrainer(BaseTrainer): Trainer for Sequence Modeling """ - def __init__(self, train_args, action): + + def __init__(self, train_args, action=None): super(POSTrainer, self).__init__(train_args, action) self.vocab_size = train_args["vocab_size"] self.num_classes = train_args["num_classes"] @@ -241,6 +246,24 @@ class POSTrainer(BaseTrainer): def update(self): self.optimizer.step() + def data_forward(self, network, inputs): + if not isinstance(inputs, tuple): + raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.") + # unpack the returned value from make_batch + x, seq_len = inputs[0], inputs[1] + batch_size, max_len = x.size(0), x.size(1) + mask = utils.seq_mask(seq_len, max_len) + mask = mask.byte().view(batch_size, max_len) + + x = torch.Tensor(x).long() + if torch.cuda.is_available() and self.use_cuda: + x = x.cuda() + mask = mask.cuda() + self.mask = mask + + y = network(x) + return y + def get_loss(self, predict, truth): """ Compute loss given prediction and ground truth. @@ -251,13 +274,10 @@ class POSTrainer(BaseTrainer): truth = torch.Tensor(truth) if torch.cuda.is_available() and self.use_cuda: truth = truth.cuda() - assert truth.shape == (self.batch_size, self.action.max_len) - if self.loss_func is None: - if hasattr(self.model, "loss"): - self.loss_func = self.model.loss - else: - self.define_loss() - loss = self.loss_func(predict, truth, self.action.seq_len) + batch_size, max_len = predict.size(0), predict.size(1) + assert truth.shape == (batch_size, max_len) + + loss = self.model.loss(predict, truth, self.mask) return loss def best_eval_result(self, validator): diff --git a/fastNLP/models/sequence_modeling.py b/fastNLP/models/sequence_modeling.py index 77a1f1d2..b28ef604 100644 --- a/fastNLP/models/sequence_modeling.py +++ b/fastNLP/models/sequence_modeling.py @@ -1,7 +1,7 @@ import torch from fastNLP.models.base_model import BaseModel -from fastNLP.modules import decoder, encoder, utils +from fastNLP.modules import decoder, encoder class SeqLabeling(BaseModel): @@ -34,46 +34,25 @@ class SeqLabeling(BaseModel): # [batch_size, max_len, num_classes] return x - def loss(self, x, y, seq_length): + def loss(self, x, y, mask): """ Negative log likelihood loss. - :param x: FloatTensor, [batch_size, max_len, tag_size] - :param y: LongTensor, [batch_size, max_len] - :param seq_length: list of int. [batch_size] + :param x: Tensor, [batch_size, max_len, tag_size] + :param y: Tensor, [batch_size, max_len] + :param mask: ByteTensor, [batch_size, ,max_len] :return loss: a scalar Tensor """ x = x.float() y = y.long() - - batch_size = x.size(0) - max_len = x.size(1) - - mask = utils.seq_mask(seq_length, max_len) - mask = mask.byte().view(batch_size, max_len) - - # TODO: remove - if torch.cuda.is_available(): - mask = mask.cuda() - # mask = x.new(batch_size, max_len) - total_loss = self.Crf(x, y, mask) - return torch.mean(total_loss) - def prediction(self, x, seq_length): + def prediction(self, x, mask): """ :param x: FloatTensor, [batch_size, max_len, tag_size] - :param seq_length: int - :return prediction: list of tuple of (decode path(list), best score) + :param mask: ByteTensor, [batch_size, max_len] + :return prediction: list of [decode path(list)] """ - x = x.float() - max_len = x.size(1) - - mask = utils.seq_mask(seq_length, max_len) - # hack: make sure mask has the same device as x - mask = mask.to(x).byte() - tag_seq = self.Crf.viterbi_decode(x, mask) - return tag_seq diff --git a/fastNLP/modules/decoder/CRF.py b/fastNLP/modules/decoder/CRF.py index 5d8ce852..e6327ec0 100644 --- a/fastNLP/modules/decoder/CRF.py +++ b/fastNLP/modules/decoder/CRF.py @@ -132,6 +132,7 @@ class ConditionalRandomField(nn.Module): Given a feats matrix, return best decode path and best score. :param feats: :param masks: + :param get_score: bool, whether to output the decode score. :return:List[Tuple(List, float)], """ batch_size, max_len, tag_size = feats.size() diff --git a/test/seq_labeling.py b/test/seq_labeling.py index 10b9f986..ce31f0e8 100644 --- a/test/seq_labeling.py +++ b/test/seq_labeling.py @@ -2,7 +2,6 @@ import sys sys.path.append("..") -from fastNLP.core.action import SeqLabelAction from fastNLP.loader.config_loader import ConfigLoader, ConfigSection from fastNLP.core.trainer import POSTrainer from fastNLP.loader.dataset_loader import POSDatasetLoader, BaseLoader @@ -11,7 +10,7 @@ from fastNLP.saver.model_saver import ModelSaver from fastNLP.loader.model_loader import ModelLoader from fastNLP.core.tester import POSTester from fastNLP.models.sequence_modeling import SeqLabeling -from fastNLP.core.inference import Inference +from fastNLP.core.inference import SeqLabelInfer data_name = "people.txt" data_path = "data_for_tests/people.txt" @@ -51,10 +50,11 @@ def infer(): """ # Inference interface - infer = Inference(pickle_path) + infer = SeqLabelInfer(pickle_path) results = infer.predict(model, infer_data) - print(results) + for res in results: + print(res) print("Inference finished!") @@ -72,10 +72,8 @@ def train_and_test(): train_args["vocab_size"] = p.vocab_size train_args["num_classes"] = p.num_classes - action = SeqLabelAction(train_args) - # Trainer - trainer = POSTrainer(train_args, action) + trainer = POSTrainer(train_args) # Model model = SeqLabeling(train_args) @@ -103,7 +101,7 @@ def train_and_test(): ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) # Tester - tester = POSTester(test_args, action) + tester = POSTester(test_args) # Start testing tester.test(model) @@ -114,5 +112,5 @@ def train_and_test(): if __name__ == "__main__": - train_and_test() - + # train_and_test() + infer()