From 4bbeaebe96e63bc7067bcf53c3445b1b0a001372 Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Wed, 15 Aug 2018 20:12:20 +0800 Subject: [PATCH] Updates to cores, action, loader: - rename Inference to Predictor - rename Trainer.prepare_input to Trainer.load_train_data, load data_train.pkl only - add __contains__ method to config Section class - more code comments - more elegant make_batch & data_iterator: Samplers return batch samples instead of batch indices --- fastNLP/action/optimizer.py | 5 - fastNLP/core/action.py | 115 ++++++++++++-------- fastNLP/core/metrics.py | 62 +++++------ fastNLP/core/optimizer.py | 6 +- fastNLP/core/{inference.py => predictor.py} | 92 +++++++++------- fastNLP/core/tester.py | 39 +++++-- fastNLP/core/trainer.py | 88 +++++++-------- fastNLP/fastnlp.py | 2 +- fastNLP/loader/config_loader.py | 3 + fastNLP/loader/embed_loader.py | 2 +- fastNLP/modules/utils.py | 10 +- reproduction/chinese_word_seg/cws_train.py | 4 +- test/__init__.py | 2 + test/ner_decode.py | 2 +- test/seq_labeling.py | 6 +- test/test_cws.py | 4 +- test/text_classify.py | 4 +- 17 files changed, 251 insertions(+), 195 deletions(-) delete mode 100644 fastNLP/action/optimizer.py rename fastNLP/core/{inference.py => predictor.py} (62%) diff --git a/fastNLP/action/optimizer.py b/fastNLP/action/optimizer.py deleted file mode 100644 index b493e3f0..00000000 --- a/fastNLP/action/optimizer.py +++ /dev/null @@ -1,5 +0,0 @@ -''' -use optimizer from Pytorch -''' - -from torch.optim import * \ No newline at end of file diff --git a/fastNLP/core/action.py b/fastNLP/core/action.py index 560fa42e..358db499 100644 --- a/fastNLP/core/action.py +++ b/fastNLP/core/action.py @@ -10,7 +10,7 @@ import torch class Action(object): """ - Operations shared by Trainer, Tester, and Inference. + Operations shared by Trainer, Tester, or Inference. This is designed for reducing replicate codes. - make_batch: produce a min-batch of data. @staticmethod - pad: padding method used in sequence modeling. @staticmethod @@ -22,28 +22,24 @@ class Action(object): super(Action, self).__init__() @staticmethod - def make_batch(iterator, data, use_cuda, output_length=True, max_len=None): + def make_batch(iterator, use_cuda, output_length=True, max_len=None): """Batch and Pad data. :param iterator: an iterator, (object that implements __next__ method) which returns the next sample. - :param data: list. Each entry is a sample, which is also a list of features and label(s). - E.g. - [ - [[word_11, word_12, word_13], [label_11. label_12]], # sample 1 - [[word_21, word_22, word_23], [label_21. label_22]], # sample 2 - ... - ] - :param use_cuda: bool - :param output_length: whether to output the original length of the sequence before padding. - :param max_len: int, maximum sequence length - :return (batch_x, seq_len): tuple of two elements, if output_length is true. + :param use_cuda: bool, whether to use GPU + :param output_length: bool, whether to output the original length of the sequence before padding. (default: True) + :param max_len: int, maximum sequence length. Longer sequences will be clipped. (default: None) + :return + if output_length is True: + (batch_x, seq_len): tuple of two elements batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len] seq_len: list. The length of the pre-padded sequence, if output_length is True. - batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels] + batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels] - return batch_x and batch_y, if output_length is False + if output_length is False: + batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len] + batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels] """ - for indices in iterator: - batch = [data[idx] for idx in indices] + for batch in iterator: batch_x = [sample[0] for sample in batch] batch_y = [sample[1] for sample in batch] @@ -68,11 +64,11 @@ class Action(object): @staticmethod def pad(batch, fill=0): - """ - Pad a batch of samples to maximum length of this batch. + """ Pad a mini-batch of sequence samples to maximum length of this batch. + :param batch: list of list :param fill: word index to pad, default 0. - :return: a padded batch + :return batch: a padded mini-batch """ max_length = max([len(x) for x in batch]) for idx, sample in enumerate(batch): @@ -95,11 +91,10 @@ class Action(object): def convert_to_torch_tensor(data_list, use_cuda): """ - convert lists into (cuda) Tensors + convert lists into (cuda) Tensors. :param data_list: 2-level lists - :param use_cuda: bool - :param reqired_grad: bool - :return: PyTorch Tensor of shape [batch_size, max_seq_len] + :param use_cuda: bool, whether to use GPU or not + :return data_list: PyTorch Tensor of shape [batch_size, max_seq_len] """ data_list = torch.Tensor(data_list).long() if torch.cuda.is_available() and use_cuda: @@ -171,6 +166,7 @@ class BaseSampler(object): def __init__(self, data_set): self.data_set_length = len(data_set) + self.data = data_set def __len__(self): return self.data_set_length @@ -188,7 +184,7 @@ class SequentialSampler(BaseSampler): super(SequentialSampler, self).__init__(data_set) def __iter__(self): - return iter(range(self.data_set_length)) + return iter(self.data) class RandomSampler(BaseSampler): @@ -198,28 +194,10 @@ class RandomSampler(BaseSampler): def __init__(self, data_set): super(RandomSampler, self).__init__(data_set) + self.order = np.random.permutation(self.data_set_length) def __iter__(self): - return iter(np.random.permutation(self.data_set_length)) - - -class BucketSampler(BaseSampler): - """ - Partition all samples into multiple buckets, each of which contains sentences of approximately the same length. - In sampling, first random choose a bucket. Then sample data from it. - The number of buckets is decided dynamically by the variance of sentence lengths. - """ - - def __init__(self, data_set): - super(BucketSampler, self).__init__(data_set) - BUCKETS = ([None] * 20) - self.length_freq = dict(Counter([len(example) for example in data_set])) - self.buckets = k_means_bucketing(data_set, BUCKETS) - - def __iter__(self): - bucket_samples = self.buckets[np.random.randint(0, len(self.buckets))] - np.random.shuffle(bucket_samples) - return iter(bucket_samples) + return iter((self.data[idx] for idx in self.order)) class Batchifier(object): @@ -235,10 +213,53 @@ class Batchifier(object): def __iter__(self): batch = [] - for idx in self.sampler: - batch.append(idx) + for example in self.sampler: + batch.append(example) if len(batch) == self.batch_size: yield batch batch = [] if 0 < len(batch) < self.batch_size and self.drop_last is False: yield batch + + +class BucketBatchifier(Batchifier): + """ + Partition all samples into multiple buckets, each of which contains sentences of approximately the same length. + In sampling, first random choose a bucket. Then sample data from it. + The number of buckets is decided dynamically by the variance of sentence lengths. + """ + + def __init__(self, data_set, batch_size, num_buckets, drop_last=True, sampler=None): + """ + + :param data_set: three-level list, shape [num_samples, 2] + :param batch_size: int + :param num_buckets: int, number of buckets for grouping these sequences. + :param drop_last: bool, useless currently. + :param sampler: Sampler, useless currently. + """ + super(BucketBatchifier, self).__init__(sampler, batch_size, drop_last) + buckets = ([None] * num_buckets) + self.data = data_set + self.batch_size = batch_size + self.length_freq = dict(Counter([len(example) for example in data_set])) + self.buckets = k_means_bucketing(data_set, buckets) + + def __iter__(self): + """Make a min-batch of data.""" + for _ in range(len(self.data) // self.batch_size): + bucket_samples = self.buckets[np.random.randint(0, len(self.buckets))] + np.random.shuffle(bucket_samples) + yield [self.data[idx] for idx in bucket_samples[:batch_size]] + + +if __name__ == "__main__": + import random + + data = [[[y] * random.randint(0, 50), [y]] for y in range(500)] + batch_size = 8 + iterator = iter(BucketBatchifier(data, batch_size, num_buckets=5)) + for d in iterator: + print("\nbatch:") + for dd in d: + print(len(dd[0]), end=" ") diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index ad22eed5..b2180b10 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -1,62 +1,55 @@ -""" -To do: - 设计评判结果的各种指标。如果涉及向量,使用numpy。 - 参考http://scikit-learn.org/stable/modules/classes.html#classification-metrics - 建议是每种metric写成一个函数 (由Tester的evaluate函数调用) - 参数表里只需考虑基本的参数即可,可以没有像它那么多的参数配置 - - support numpy array and torch tensor -""" +import warnings + import numpy as np import torch -import sklearn.metrics as M -import warnings + def _conver_numpy(x): - ''' - converte input data to numpy array - ''' - if isinstance(x, np.ndarray): + """ + convert input data to numpy array + """ + if isinstance(x, np.ndarray): return x - elif isinstance(x, torch.Tensor): + elif isinstance(x, torch.Tensor): return x.numpy() - elif isinstance(x, list): + elif isinstance(x, list): return np.array(x) - raise TypeError('cannot accept obejct: {}'.format(x)) + raise TypeError('cannot accept object: {}'.format(x)) + def _check_same_len(*arrays, axis=0): - ''' + """ check if input array list has same length for one dimension - ''' + """ lens = set([x.shape[axis] for x in arrays if x is not None]) return len(lens) == 1 - + def _label_types(y): - ''' + """ determine the type "binary" "multiclass" "multiclass-multioutput" "multilabel" "unknown" - ''' + """ # never squeeze the first dimension y = np.squeeze(y, list(range(1, len(y.shape)))) shape = y.shape - if len(shape) < 1: + if len(shape) < 1: raise ValueError('cannot accept data: {}'.format(y)) if len(shape) == 1: return 'multiclass' if np.unique(y).shape[0] > 2 else 'binary', y if len(shape) == 2: return 'multiclass-multioutput' if np.unique(y).shape[0] > 2 else 'multilabel', y return 'unknown', y - + def _check_data(y_true, y_pred): - ''' + """ check if y_true and y_pred is same type of data e.g both binary or multiclass - ''' + """ y_true, y_pred = _conver_numpy(y_true), _conver_numpy(y_pred) if not _check_same_len(y_true, y_pred): raise ValueError('cannot accept data with different shape {0}, {1}'.format(y_true, y_pred)) @@ -70,9 +63,9 @@ def _check_data(y_true, y_pred): type_set = set(['multiclass-multioutput', 'multilabel']) if type_true in type_set and type_pred in type_set: return type_true if type_true == type_pred else 'multiclass-multioutput', y_true, y_pred - + raise ValueError('cannot accept data mixed of {0} and {1} target'.format(type_true, type_pred)) - + def _weight_sum(y, normalize=True, sample_weight=None): if normalize: @@ -119,7 +112,7 @@ def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): pos_list = [y_true == i for i in labels] pos_sum_list = [pos_i.sum() for pos_i in pos_list] return np.array([np.logical_and(y_pred_right, pos_i).sum() / sum_i if sum_i > 0 else 0 \ - for pos_i, sum_i in zip(pos_list, pos_sum_list)]) + for pos_i, sum_i in zip(pos_list, pos_sum_list)]) elif y_type == 'multilabel': y_pred_right = y_true == y_pred pos = (y_true == pos_label) @@ -130,6 +123,7 @@ def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): raise ValueError('not support targets type {}'.format(y_type)) raise ValueError('not support for average type {}'.format(average)) + def precision_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): y_type, y_true, y_pred = _check_data(y_true, y_pred) if average == 'binary': @@ -154,7 +148,7 @@ def precision_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): pos_list = [y_true == i for i in labels] pos_sum_list = [(y_pred == i).sum() for i in labels] return np.array([np.logical_and(y_pred_right, pos_i).sum() / sum_i if sum_i > 0 else 0 \ - for pos_i, sum_i in zip(pos_list, pos_sum_list)]) + for pos_i, sum_i in zip(pos_list, pos_sum_list)]) elif y_type == 'multilabel': y_pred_right = y_true == y_pred pos = (y_true == pos_label) @@ -165,6 +159,7 @@ def precision_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): raise ValueError('not support targets type {}'.format(y_type)) raise ValueError('not support for average type {}'.format(average)) + def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): precision = precision_score(y_true, y_pred, labels=labels, pos_label=pos_label, average=average) recall = recall_score(y_true, y_pred, labels=labels, pos_label=pos_label, average=average) @@ -178,6 +173,7 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): def classification_report(y_true, y_pred, labels=None, target_names=None, digits=2): raise NotImplementedError + if __name__ == '__main__': - y = np.array([1,0,1,0,1,1]) - print(_label_types(y)) \ No newline at end of file + y = np.array([1, 0, 1, 0, 1, 1]) + print(_label_types(y)) diff --git a/fastNLP/core/optimizer.py b/fastNLP/core/optimizer.py index b493e3f0..fbef289a 100644 --- a/fastNLP/core/optimizer.py +++ b/fastNLP/core/optimizer.py @@ -1,5 +1,3 @@ -''' +""" use optimizer from Pytorch -''' - -from torch.optim import * \ No newline at end of file +""" diff --git a/fastNLP/core/inference.py b/fastNLP/core/predictor.py similarity index 62% rename from fastNLP/core/inference.py rename to fastNLP/core/predictor.py index 3937e3f4..5da00337 100644 --- a/fastNLP/core/inference.py +++ b/fastNLP/core/predictor.py @@ -7,9 +7,17 @@ from fastNLP.loader.preprocess import load_pickle, DEFAULT_UNKNOWN_LABEL from fastNLP.modules import utils -def make_batch(iterator, data, use_cuda, output_length=False, max_len=None, min_len=None): - for indices in iterator: - batch_x = [data[idx] for idx in indices] +def make_batch(iterator, use_cuda, output_length=False, max_len=None, min_len=None): + """Batch and Pad data, only for Inference. + + :param iterator: An iterable object that returns a list of indices representing a mini-batch of samples. + :param use_cuda: bool, whether to use GPU + :param output_length: bool, whether to output the original length of the sequence before padding. (default: False) + :param max_len: int, maximum sequence length. Longer sequences will be clipped. (default: None) + :param min_len: int, minimum sequence length. Shorter sequences will be padded. (default: None) + :return: + """ + for batch_x in iterator: batch_x = pad(batch_x) # convert list to tensor batch_x = convert_to_torch_tensor(batch_x, use_cuda) @@ -29,11 +37,11 @@ def make_batch(iterator, data, use_cuda, output_length=False, max_len=None, min_ def pad(batch, fill=0): - """ - Pad a batch of samples to maximum length. + """ Pad a mini-batch of sequence samples to maximum length of this batch. + :param batch: list of list :param fill: word index to pad, default 0. - :return: a padded batch + :return batch: a padded mini-batch """ max_length = max([len(x) for x in batch]) for idx, sample in enumerate(batch): @@ -42,13 +50,13 @@ def pad(batch, fill=0): return batch -class Inference(object): - """ - This is an interface focusing on predicting output based on trained models. +class Predictor(object): + """An interface for predicting outputs based on trained models. + It does not care about evaluations of the model, which is different from Tester. This is a high-level model wrapper to be called by FastNLP. This class does not share any operations with Trainer and Tester. - Currently, Inference does not support GPU. + Currently, Predictor does not support GPU. """ def __init__(self, pickle_path): @@ -60,11 +68,11 @@ class Inference(object): self.word2index = load_pickle(self.pickle_path, "word2id.pkl") def predict(self, network, data): - """ - Perform inference. - :param network: - :param data: two-level lists of strings - :return result: the model outputs + """Perform inference using the trained model. + + :param network: a PyTorch model + :param data: list of list of strings + :return: list of list of strings, [num_examples, tag_seq_length] """ # transform strings into indices data = self.prepare_input(data) @@ -73,9 +81,9 @@ class Inference(object): self.mode(network, test=True) self.batch_output.clear() - iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False)) + data_iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False)) - for batch_x in self.make_batch(iterator, data, use_cuda=False): + for batch_x in self.make_batch(data_iterator, use_cuda=False): prediction = self.data_forward(network, batch_x) @@ -90,20 +98,22 @@ class Inference(object): network.train() def data_forward(self, network, x): + """Forward through network.""" raise NotImplementedError - def make_batch(self, iterator, data, use_cuda): + def make_batch(self, iterator, use_cuda): raise NotImplementedError def prepare_input(self, data): - """ - Transform two-level list of strings into that of index. + """Transform two-level list of strings into that of index. + :param data: - [ - [word_11, word_12, ...], - [word_21, word_22, ...], - ... - ] + [ + [word_11, word_12, ...], + [word_21, word_22, ...], + ... + ] + :return data_index: list of list of int. """ assert isinstance(data, list) data_index = [] @@ -113,10 +123,11 @@ class Inference(object): return data_index def prepare_output(self, data): + """Transform list of batch outputs into strings.""" raise NotImplementedError -class SeqLabelInfer(Inference): +class SeqLabelInfer(Predictor): """ Inference on sequence labeling models. """ @@ -127,12 +138,15 @@ class SeqLabelInfer(Inference): def data_forward(self, network, inputs): """ This is only for sequence labeling with CRF decoder. - :param network: - :param inputs: - :return: Tensor + :param network: a PyTorch model + :param inputs: tuple of (x, seq_len) + x: Tensor of shape [batch_size, max_len], where max_len is the maximum length of the mini-batch + after padding. + seq_len: list of int, the lengths of sequences before padding. + :return prediction: Tensor of shape [batch_size, max_len] """ if not isinstance(inputs[1], list) and isinstance(inputs[0], list): - raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.") + raise RuntimeError("output_length must be true for sequence modeling.") # unpack the returned value from make_batch x, seq_len = inputs[0], inputs[1] batch_size, max_len = x.size(0), x.size(1) @@ -142,14 +156,14 @@ class SeqLabelInfer(Inference): prediction = network.prediction(y, mask) return torch.Tensor(prediction) - def make_batch(self, iterator, data, use_cuda): - return make_batch(iterator, data, use_cuda, output_length=True) + def make_batch(self, iterator, use_cuda): + return make_batch(iterator, use_cuda, output_length=True) def prepare_output(self, batch_outputs): - """ - Transform list of batch outputs into strings. - :param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, tag_seq_length]. - :return results: 2-D list of strings + """Transform list of batch outputs into strings. + + :param batch_outputs: list of 2-D Tensor, shape [num_batch, batch-size, tag_seq_length]. + :return results: 2-D list of strings, shape [num_examples, tag_seq_length] """ results = [] for batch in batch_outputs: @@ -158,7 +172,7 @@ class SeqLabelInfer(Inference): return results -class ClassificationInfer(Inference): +class ClassificationInfer(Predictor): """ Inference on Classification models. """ @@ -171,8 +185,8 @@ class ClassificationInfer(Inference): logits = network(x) return logits - def make_batch(self, iterator, data, use_cuda): - return make_batch(iterator, data, use_cuda, output_length=False, min_len=5) + def make_batch(self, iterator, use_cuda): + return make_batch(iterator, use_cuda, output_length=False, min_len=5) def prepare_output(self, batch_outputs): """ diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index 3799eed1..77592af8 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -9,7 +9,7 @@ from fastNLP.modules import utils class BaseTester(object): - """docstring for Tester""" + """An collection of model inference and evaluation of performance, used over validation/dev set and test set. """ def __init__(self, test_args): """ @@ -62,8 +62,8 @@ class BaseTester(object): step += 1 def prepare_input(self, data_path): - """ - Save the dev data once it is loaded. Can return directly next time. + """Save the dev data once it is loaded. Can return directly next time. + :param data_path: str, the path to the pickle data for dev :return save_dev_data: list. Each entry is a sample, which is also a list of features and label(s). """ @@ -73,21 +73,29 @@ class BaseTester(object): return self.save_dev_data def mode(self, model, test): + """Train mode or Test mode. This is for PyTorch currently. + + :param model: a PyTorch model + :param test: bool, whether in test mode. + """ Action.mode(model, test) def data_forward(self, network, x): + """A forward pass of the model. """ raise NotImplementedError def evaluate(self, predict, truth): + """Compute evaluation metrics for the model. """ raise NotImplementedError @property def metrics(self): + """Return a list of metrics. """ raise NotImplementedError def show_matrices(self): - """ - This is called by Trainer to print evaluation on dev set. + """This is called by Trainer to print evaluation results on dev set during training. + :return print_str: str """ raise NotImplementedError @@ -112,8 +120,17 @@ class SeqLabelTester(BaseTester): self.batch_result = None def data_forward(self, network, inputs): + """This is only for sequence labeling with CRF decoder. + + :param network: a PyTorch model + :param inputs: tuple of (x, seq_len) + x: Tensor of shape [batch_size, max_len], where max_len is the maximum length of the mini-batch + after padding. + seq_len: list of int, the lengths of sequences before padding. + :return y: Tensor of shape [batch_size, max_len] + """ if not isinstance(inputs, tuple): - raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.") + raise RuntimeError("output_length must be true for sequence modeling.") # unpack the returned value from make_batch x, seq_len = inputs[0], inputs[1] batch_size, max_len = x.size(0), x.size(1) @@ -127,6 +144,12 @@ class SeqLabelTester(BaseTester): return y def evaluate(self, predict, truth): + """Compute metrics (or loss). + + :param predict: Tensor, [batch_size, max_len, tag_size] + :param truth: Tensor, [batch_size, max_len] + :return: + """ batch_size, max_len = predict.size(0), predict.size(1) loss = self.model.loss(predict, truth, self.mask) / batch_size @@ -151,7 +174,7 @@ class SeqLabelTester(BaseTester): return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy) def make_batch(self, iterator, data): - return Action.make_batch(iterator, data, use_cuda=self.use_cuda, output_length=True) + return Action.make_batch(iterator, use_cuda=self.use_cuda, output_length=True) class ClassificationTester(BaseTester): @@ -171,7 +194,7 @@ class ClassificationTester(BaseTester): self.iterator = None def make_batch(self, iterator, data, max_len=None): - return Action.make_batch(iterator, data, use_cuda=self.use_cuda, max_len=max_len) + return Action.make_batch(iterator, use_cuda=self.use_cuda, max_len=max_len) def data_forward(self, network, x): """Forward through network.""" diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 8fcdc692..77bb0757 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -1,5 +1,6 @@ import _pickle import os +import time from datetime import timedelta from time import time @@ -13,10 +14,11 @@ from fastNLP.core.tester import SeqLabelTester, ClassificationTester from fastNLP.modules import utils from fastNLP.saver.model_saver import ModelSaver +DEFAULT_QUEUE_SIZE = 300 + class BaseTrainer(object): - """Base trainer for all trainers. - Trainer receives a model and data, and then performs training. + """Operations to train a model, including data loading, SGD, and validation. Subclasses must implement the following abstract methods: - define_optimizer @@ -70,7 +72,7 @@ class BaseTrainer(object): else: self.model = network - data_train, data_dev, data_test, embedding = self.prepare_input(self.pickle_path) + data_train = self.load_train_data(self.pickle_path) # define tester over dev data if self.validate: @@ -82,33 +84,19 @@ class BaseTrainer(object): self.define_optimizer() # main training epochs - start = time() + start = time.time() n_samples = len(data_train) n_batches = n_samples // self.batch_size n_print = 1 for epoch in range(1, self.n_epochs + 1): - # turn on network training mode; prepare batch iterator + # turn on network training mode self.mode(network, test=False) - iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=False)) - - # training iterations in one epoch - step = 0 - for batch_x, batch_y in self.make_batch(iterator, data_train): - - prediction = self.data_forward(network, batch_x) + # prepare mini-batch iterator + data_iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=False)) - loss = self.get_loss(prediction, batch_y) - self.grad_backward(loss) - self.update() - - if step % n_print == 0: - end = time() - diff = timedelta(seconds=round(end - start)) - print("[epoch: {:>3} step: {:>4}] train loss: {:>4.2} time: {}".format( - epoch, step, loss.data, diff)) - step += 1 + self._train_step(data_iterator, network, start=start, n_print=n_print, epoch=epoch) if self.validate: validator.test(network) @@ -120,27 +108,39 @@ class BaseTrainer(object): print("[epoch {}]".format(epoch), end=" ") print(validator.show_matrices()) - def prepare_input(self, pickle_path): + def _train_step(self, data_iterator, network, **kwargs): + """Training process in one epoch.""" + step = 0 + for batch_x, batch_y in self.make_batch(data_iterator): + + prediction = self.data_forward(network, batch_x) + + loss = self.get_loss(prediction, batch_y) + self.grad_backward(loss) + self.update() + + if step % kwargs["n_print"] == 0: + end = time.time() + diff = timedelta(seconds=round(end - kwargs["start"])) + print("[epoch: {:>3} step: {:>4}] train loss: {:>4.2} time: {}".format( + kwargs["epoch"], step, loss.data, diff)) + step += 1 + + def load_train_data(self, pickle_path): """ For task-specific processing. :param pickle_path: - :return data_train, data_dev, data_test, embedding: + :return data_train """ - names = [ - "data_train.pkl", "data_dev.pkl", - "data_test.pkl", "embedding.pkl"] - files = [] - for name in names: - file_path = os.path.join(pickle_path, name) - if os.path.exists(file_path): - with open(file_path, 'rb') as f: - data = _pickle.load(f) - else: - data = [] - files.append(data) - return tuple(files) + file_path = os.path.join(pickle_path, "data_train.pkl") + if os.path.exists(file_path): + with open(file_path, 'rb') as f: + data = _pickle.load(f) + else: + raise RuntimeError("cannot find training data {}".format(file_path)) + return data - def make_batch(self, iterator, data): + def make_batch(self, iterator): raise NotImplementedError def mode(self, network, test): @@ -219,7 +219,7 @@ class ToyTrainer(BaseTrainer): def __init__(self, training_args): super(ToyTrainer, self).__init__(training_args) - def prepare_input(self, data_path): + def load_train_data(self, data_path): data_train = _pickle.load(open(data_path + "/data_train.pkl", "rb")) data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb")) return data_train, data_dev, 0, 1 @@ -267,7 +267,7 @@ class SeqLabelTrainer(BaseTrainer): def data_forward(self, network, inputs): if not isinstance(inputs, tuple): - raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.") + raise RuntimeError("output_length must be true for sequence modeling. Receive {}".format(type(inputs[0]))) # unpack the returned value from make_batch x, seq_len = inputs[0], inputs[1] @@ -303,8 +303,8 @@ class SeqLabelTrainer(BaseTrainer): else: return False - def make_batch(self, iterator, data): - return Action.make_batch(iterator, data, output_length=True, use_cuda=self.use_cuda) + def make_batch(self, iterator): + return Action.make_batch(iterator, output_length=True, use_cuda=self.use_cuda) def _create_validator(self, valid_args): return SeqLabelTester(valid_args) @@ -349,8 +349,8 @@ class ClassificationTrainer(BaseTrainer): """Apply gradient.""" self.optimizer.step() - def make_batch(self, iterator, data): - return Action.make_batch(iterator, data, output_length=False, use_cuda=self.use_cuda) + def make_batch(self, iterator): + return Action.make_batch(iterator, output_length=False, use_cuda=self.use_cuda) def get_acc(self, y_logit, y_true): """Compute accuracy.""" diff --git a/fastNLP/fastnlp.py b/fastNLP/fastnlp.py index e67fc63b..6339c11a 100644 --- a/fastNLP/fastnlp.py +++ b/fastNLP/fastnlp.py @@ -1,4 +1,4 @@ -from fastNLP.core.inference import SeqLabelInfer, ClassificationInfer +from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer from fastNLP.loader.config_loader import ConfigLoader, ConfigSection from fastNLP.loader.model_loader import ModelLoader diff --git a/fastNLP/loader/config_loader.py b/fastNLP/loader/config_loader.py index 01bffa2b..079755e2 100644 --- a/fastNLP/loader/config_loader.py +++ b/fastNLP/loader/config_loader.py @@ -91,6 +91,9 @@ class ConfigSection(object): (key, str(type(getattr(self, key))), str(type(value)))) setattr(self, key, value) + def __contains__(self, item): + return item in self.__dict__.keys() + if __name__ == "__main__": config = ConfigLoader('configLoader', 'there is no data') diff --git a/fastNLP/loader/embed_loader.py b/fastNLP/loader/embed_loader.py index 9610ca2d..4b70dd0b 100644 --- a/fastNLP/loader/embed_loader.py +++ b/fastNLP/loader/embed_loader.py @@ -1,4 +1,4 @@ -from loader.base_loader import BaseLoader +from fastNLP.loader.base_loader import BaseLoader class EmbedLoader(BaseLoader): diff --git a/fastNLP/modules/utils.py b/fastNLP/modules/utils.py index 78ebfb1a..442944e7 100644 --- a/fastNLP/modules/utils.py +++ b/fastNLP/modules/utils.py @@ -1,3 +1,9 @@ +from collections import defaultdict + +import numpy as np +import torch + + def mask_softmax(matrix, mask): if mask is None: result = torch.nn.functional.softmax(matrix, dim=-1) @@ -15,10 +21,6 @@ def seq_mask(seq_len, max_len): """ Codes from FudanParser. Not tested. Do not use !!! """ -from collections import defaultdict - -import numpy as np -import torch def expand_gt(gt): diff --git a/reproduction/chinese_word_seg/cws_train.py b/reproduction/chinese_word_seg/cws_train.py index 0a235be0..afb0ec7e 100644 --- a/reproduction/chinese_word_seg/cws_train.py +++ b/reproduction/chinese_word_seg/cws_train.py @@ -10,7 +10,7 @@ from fastNLP.saver.model_saver import ModelSaver from fastNLP.loader.model_loader import ModelLoader from fastNLP.core.tester import SeqLabelTester from fastNLP.models.sequence_modeling import SeqLabeling -from fastNLP.core.inference import Inference +from fastNLP.core.predictor import Predictor data_name = "pku_training.utf8" cws_data_path = "/home/zyfeng/data/pku_training.utf8" @@ -41,7 +41,7 @@ def infer(): infer_data = raw_data_loader.load_lines() # Inference interface - infer = Inference(pickle_path) + infer = Predictor(pickle_path) results = infer.predict(model, infer_data) print(results) diff --git a/test/__init__.py b/test/__init__.py index 8b137891..c7a5f082 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -1 +1,3 @@ +import fastNLP +__all__ = ["fastNLP"] diff --git a/test/ner_decode.py b/test/ner_decode.py index a319a20e..5c09cbd2 100644 --- a/test/ner_decode.py +++ b/test/ner_decode.py @@ -3,7 +3,7 @@ import os import torch -from fastNLP.core.inference import SeqLabelInfer +from fastNLP.core.predictor import SeqLabelInfer from fastNLP.core.trainer import SeqLabelTrainer from fastNLP.loader.model_loader import ModelLoader from fastNLP.models.sequence_modeling import AdvSeqLabel diff --git a/test/seq_labeling.py b/test/seq_labeling.py index adc686df..a90dc75e 100644 --- a/test/seq_labeling.py +++ b/test/seq_labeling.py @@ -10,7 +10,7 @@ from fastNLP.saver.model_saver import ModelSaver from fastNLP.loader.model_loader import ModelLoader from fastNLP.core.tester import SeqLabelTester from fastNLP.models.sequence_modeling import SeqLabeling -from fastNLP.core.inference import SeqLabelInfer +from fastNLP.core.predictor import SeqLabelInfer data_name = "people.txt" data_path = "data_for_tests/people.txt" @@ -112,5 +112,5 @@ def train_and_test(): if __name__ == "__main__": - train_and_test() - # infer() + # train_and_test() + infer() diff --git a/test/test_cws.py b/test/test_cws.py index f293aefd..74451e24 100644 --- a/test/test_cws.py +++ b/test/test_cws.py @@ -10,7 +10,7 @@ from fastNLP.saver.model_saver import ModelSaver from fastNLP.loader.model_loader import ModelLoader from fastNLP.core.tester import SeqLabelTester from fastNLP.models.sequence_modeling import SeqLabeling -from fastNLP.core.inference import Inference +from fastNLP.core.predictor import Predictor data_name = "pku_training.utf8" # cws_data_path = "/home/zyfeng/Desktop/data/pku_training.utf8" @@ -51,7 +51,7 @@ def infer(): """ # Inference interface - infer = Inference(pickle_path) + infer = Predictor(pickle_path) results = infer.predict(model, infer_data) print(results) diff --git a/test/text_classify.py b/test/text_classify.py index 7400b1da..f8353f27 100644 --- a/test/text_classify.py +++ b/test/text_classify.py @@ -2,8 +2,10 @@ # encoding: utf-8 import os +import sys -from fastNLP.core.inference import ClassificationInfer +sys.path.append("..") +from fastNLP.core.predictor import ClassificationInfer from fastNLP.core.trainer import ClassificationTrainer from fastNLP.loader.config_loader import ConfigLoader, ConfigSection from fastNLP.loader.dataset_loader import ClassDatasetLoader