- rename Inference to Predictor - rename Trainer.prepare_input to Trainer.load_train_data, load data_train.pkl only - add __contains__ method to config Section class - more code comments - more elegant make_batch & data_iterator: Samplers return batch samples instead of batch indicestags/v0.1.0
| @@ -1,5 +0,0 @@ | |||
| ''' | |||
| use optimizer from Pytorch | |||
| ''' | |||
| from torch.optim import * | |||
| @@ -10,7 +10,7 @@ import torch | |||
| class Action(object): | |||
| """ | |||
| Operations shared by Trainer, Tester, and Inference. | |||
| Operations shared by Trainer, Tester, or Inference. | |||
| This is designed for reducing replicate codes. | |||
| - make_batch: produce a min-batch of data. @staticmethod | |||
| - pad: padding method used in sequence modeling. @staticmethod | |||
| @@ -22,28 +22,24 @@ class Action(object): | |||
| super(Action, self).__init__() | |||
| @staticmethod | |||
| def make_batch(iterator, data, use_cuda, output_length=True, max_len=None): | |||
| def make_batch(iterator, use_cuda, output_length=True, max_len=None): | |||
| """Batch and Pad data. | |||
| :param iterator: an iterator, (object that implements __next__ method) which returns the next sample. | |||
| :param data: list. Each entry is a sample, which is also a list of features and label(s). | |||
| E.g. | |||
| [ | |||
| [[word_11, word_12, word_13], [label_11. label_12]], # sample 1 | |||
| [[word_21, word_22, word_23], [label_21. label_22]], # sample 2 | |||
| ... | |||
| ] | |||
| :param use_cuda: bool | |||
| :param output_length: whether to output the original length of the sequence before padding. | |||
| :param max_len: int, maximum sequence length | |||
| :return (batch_x, seq_len): tuple of two elements, if output_length is true. | |||
| :param use_cuda: bool, whether to use GPU | |||
| :param output_length: bool, whether to output the original length of the sequence before padding. (default: True) | |||
| :param max_len: int, maximum sequence length. Longer sequences will be clipped. (default: None) | |||
| :return | |||
| if output_length is True: | |||
| (batch_x, seq_len): tuple of two elements | |||
| batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len] | |||
| seq_len: list. The length of the pre-padded sequence, if output_length is True. | |||
| batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels] | |||
| batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels] | |||
| return batch_x and batch_y, if output_length is False | |||
| if output_length is False: | |||
| batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len] | |||
| batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels] | |||
| """ | |||
| for indices in iterator: | |||
| batch = [data[idx] for idx in indices] | |||
| for batch in iterator: | |||
| batch_x = [sample[0] for sample in batch] | |||
| batch_y = [sample[1] for sample in batch] | |||
| @@ -68,11 +64,11 @@ class Action(object): | |||
| @staticmethod | |||
| def pad(batch, fill=0): | |||
| """ | |||
| Pad a batch of samples to maximum length of this batch. | |||
| """ Pad a mini-batch of sequence samples to maximum length of this batch. | |||
| :param batch: list of list | |||
| :param fill: word index to pad, default 0. | |||
| :return: a padded batch | |||
| :return batch: a padded mini-batch | |||
| """ | |||
| max_length = max([len(x) for x in batch]) | |||
| for idx, sample in enumerate(batch): | |||
| @@ -95,11 +91,10 @@ class Action(object): | |||
| def convert_to_torch_tensor(data_list, use_cuda): | |||
| """ | |||
| convert lists into (cuda) Tensors | |||
| convert lists into (cuda) Tensors. | |||
| :param data_list: 2-level lists | |||
| :param use_cuda: bool | |||
| :param reqired_grad: bool | |||
| :return: PyTorch Tensor of shape [batch_size, max_seq_len] | |||
| :param use_cuda: bool, whether to use GPU or not | |||
| :return data_list: PyTorch Tensor of shape [batch_size, max_seq_len] | |||
| """ | |||
| data_list = torch.Tensor(data_list).long() | |||
| if torch.cuda.is_available() and use_cuda: | |||
| @@ -171,6 +166,7 @@ class BaseSampler(object): | |||
| def __init__(self, data_set): | |||
| self.data_set_length = len(data_set) | |||
| self.data = data_set | |||
| def __len__(self): | |||
| return self.data_set_length | |||
| @@ -188,7 +184,7 @@ class SequentialSampler(BaseSampler): | |||
| super(SequentialSampler, self).__init__(data_set) | |||
| def __iter__(self): | |||
| return iter(range(self.data_set_length)) | |||
| return iter(self.data) | |||
| class RandomSampler(BaseSampler): | |||
| @@ -198,28 +194,10 @@ class RandomSampler(BaseSampler): | |||
| def __init__(self, data_set): | |||
| super(RandomSampler, self).__init__(data_set) | |||
| self.order = np.random.permutation(self.data_set_length) | |||
| def __iter__(self): | |||
| return iter(np.random.permutation(self.data_set_length)) | |||
| class BucketSampler(BaseSampler): | |||
| """ | |||
| Partition all samples into multiple buckets, each of which contains sentences of approximately the same length. | |||
| In sampling, first random choose a bucket. Then sample data from it. | |||
| The number of buckets is decided dynamically by the variance of sentence lengths. | |||
| """ | |||
| def __init__(self, data_set): | |||
| super(BucketSampler, self).__init__(data_set) | |||
| BUCKETS = ([None] * 20) | |||
| self.length_freq = dict(Counter([len(example) for example in data_set])) | |||
| self.buckets = k_means_bucketing(data_set, BUCKETS) | |||
| def __iter__(self): | |||
| bucket_samples = self.buckets[np.random.randint(0, len(self.buckets))] | |||
| np.random.shuffle(bucket_samples) | |||
| return iter(bucket_samples) | |||
| return iter((self.data[idx] for idx in self.order)) | |||
| class Batchifier(object): | |||
| @@ -235,10 +213,53 @@ class Batchifier(object): | |||
| def __iter__(self): | |||
| batch = [] | |||
| for idx in self.sampler: | |||
| batch.append(idx) | |||
| for example in self.sampler: | |||
| batch.append(example) | |||
| if len(batch) == self.batch_size: | |||
| yield batch | |||
| batch = [] | |||
| if 0 < len(batch) < self.batch_size and self.drop_last is False: | |||
| yield batch | |||
| class BucketBatchifier(Batchifier): | |||
| """ | |||
| Partition all samples into multiple buckets, each of which contains sentences of approximately the same length. | |||
| In sampling, first random choose a bucket. Then sample data from it. | |||
| The number of buckets is decided dynamically by the variance of sentence lengths. | |||
| """ | |||
| def __init__(self, data_set, batch_size, num_buckets, drop_last=True, sampler=None): | |||
| """ | |||
| :param data_set: three-level list, shape [num_samples, 2] | |||
| :param batch_size: int | |||
| :param num_buckets: int, number of buckets for grouping these sequences. | |||
| :param drop_last: bool, useless currently. | |||
| :param sampler: Sampler, useless currently. | |||
| """ | |||
| super(BucketBatchifier, self).__init__(sampler, batch_size, drop_last) | |||
| buckets = ([None] * num_buckets) | |||
| self.data = data_set | |||
| self.batch_size = batch_size | |||
| self.length_freq = dict(Counter([len(example) for example in data_set])) | |||
| self.buckets = k_means_bucketing(data_set, buckets) | |||
| def __iter__(self): | |||
| """Make a min-batch of data.""" | |||
| for _ in range(len(self.data) // self.batch_size): | |||
| bucket_samples = self.buckets[np.random.randint(0, len(self.buckets))] | |||
| np.random.shuffle(bucket_samples) | |||
| yield [self.data[idx] for idx in bucket_samples[:batch_size]] | |||
| if __name__ == "__main__": | |||
| import random | |||
| data = [[[y] * random.randint(0, 50), [y]] for y in range(500)] | |||
| batch_size = 8 | |||
| iterator = iter(BucketBatchifier(data, batch_size, num_buckets=5)) | |||
| for d in iterator: | |||
| print("\nbatch:") | |||
| for dd in d: | |||
| print(len(dd[0]), end=" ") | |||
| @@ -1,62 +1,55 @@ | |||
| """ | |||
| To do: | |||
| 设计评判结果的各种指标。如果涉及向量,使用numpy。 | |||
| 参考http://scikit-learn.org/stable/modules/classes.html#classification-metrics | |||
| 建议是每种metric写成一个函数 (由Tester的evaluate函数调用) | |||
| 参数表里只需考虑基本的参数即可,可以没有像它那么多的参数配置 | |||
| support numpy array and torch tensor | |||
| """ | |||
| import warnings | |||
| import numpy as np | |||
| import torch | |||
| import sklearn.metrics as M | |||
| import warnings | |||
| def _conver_numpy(x): | |||
| ''' | |||
| converte input data to numpy array | |||
| ''' | |||
| if isinstance(x, np.ndarray): | |||
| """ | |||
| convert input data to numpy array | |||
| """ | |||
| if isinstance(x, np.ndarray): | |||
| return x | |||
| elif isinstance(x, torch.Tensor): | |||
| elif isinstance(x, torch.Tensor): | |||
| return x.numpy() | |||
| elif isinstance(x, list): | |||
| elif isinstance(x, list): | |||
| return np.array(x) | |||
| raise TypeError('cannot accept obejct: {}'.format(x)) | |||
| raise TypeError('cannot accept object: {}'.format(x)) | |||
| def _check_same_len(*arrays, axis=0): | |||
| ''' | |||
| """ | |||
| check if input array list has same length for one dimension | |||
| ''' | |||
| """ | |||
| lens = set([x.shape[axis] for x in arrays if x is not None]) | |||
| return len(lens) == 1 | |||
| def _label_types(y): | |||
| ''' | |||
| """ | |||
| determine the type | |||
| "binary" | |||
| "multiclass" | |||
| "multiclass-multioutput" | |||
| "multilabel" | |||
| "unknown" | |||
| ''' | |||
| """ | |||
| # never squeeze the first dimension | |||
| y = np.squeeze(y, list(range(1, len(y.shape)))) | |||
| shape = y.shape | |||
| if len(shape) < 1: | |||
| if len(shape) < 1: | |||
| raise ValueError('cannot accept data: {}'.format(y)) | |||
| if len(shape) == 1: | |||
| return 'multiclass' if np.unique(y).shape[0] > 2 else 'binary', y | |||
| if len(shape) == 2: | |||
| return 'multiclass-multioutput' if np.unique(y).shape[0] > 2 else 'multilabel', y | |||
| return 'unknown', y | |||
| def _check_data(y_true, y_pred): | |||
| ''' | |||
| """ | |||
| check if y_true and y_pred is same type of data e.g both binary or multiclass | |||
| ''' | |||
| """ | |||
| y_true, y_pred = _conver_numpy(y_true), _conver_numpy(y_pred) | |||
| if not _check_same_len(y_true, y_pred): | |||
| raise ValueError('cannot accept data with different shape {0}, {1}'.format(y_true, y_pred)) | |||
| @@ -70,9 +63,9 @@ def _check_data(y_true, y_pred): | |||
| type_set = set(['multiclass-multioutput', 'multilabel']) | |||
| if type_true in type_set and type_pred in type_set: | |||
| return type_true if type_true == type_pred else 'multiclass-multioutput', y_true, y_pred | |||
| raise ValueError('cannot accept data mixed of {0} and {1} target'.format(type_true, type_pred)) | |||
| def _weight_sum(y, normalize=True, sample_weight=None): | |||
| if normalize: | |||
| @@ -119,7 +112,7 @@ def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): | |||
| pos_list = [y_true == i for i in labels] | |||
| pos_sum_list = [pos_i.sum() for pos_i in pos_list] | |||
| return np.array([np.logical_and(y_pred_right, pos_i).sum() / sum_i if sum_i > 0 else 0 \ | |||
| for pos_i, sum_i in zip(pos_list, pos_sum_list)]) | |||
| for pos_i, sum_i in zip(pos_list, pos_sum_list)]) | |||
| elif y_type == 'multilabel': | |||
| y_pred_right = y_true == y_pred | |||
| pos = (y_true == pos_label) | |||
| @@ -130,6 +123,7 @@ def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): | |||
| raise ValueError('not support targets type {}'.format(y_type)) | |||
| raise ValueError('not support for average type {}'.format(average)) | |||
| def precision_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): | |||
| y_type, y_true, y_pred = _check_data(y_true, y_pred) | |||
| if average == 'binary': | |||
| @@ -154,7 +148,7 @@ def precision_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): | |||
| pos_list = [y_true == i for i in labels] | |||
| pos_sum_list = [(y_pred == i).sum() for i in labels] | |||
| return np.array([np.logical_and(y_pred_right, pos_i).sum() / sum_i if sum_i > 0 else 0 \ | |||
| for pos_i, sum_i in zip(pos_list, pos_sum_list)]) | |||
| for pos_i, sum_i in zip(pos_list, pos_sum_list)]) | |||
| elif y_type == 'multilabel': | |||
| y_pred_right = y_true == y_pred | |||
| pos = (y_true == pos_label) | |||
| @@ -165,6 +159,7 @@ def precision_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): | |||
| raise ValueError('not support targets type {}'.format(y_type)) | |||
| raise ValueError('not support for average type {}'.format(average)) | |||
| def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): | |||
| precision = precision_score(y_true, y_pred, labels=labels, pos_label=pos_label, average=average) | |||
| recall = recall_score(y_true, y_pred, labels=labels, pos_label=pos_label, average=average) | |||
| @@ -178,6 +173,7 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary'): | |||
| def classification_report(y_true, y_pred, labels=None, target_names=None, digits=2): | |||
| raise NotImplementedError | |||
| if __name__ == '__main__': | |||
| y = np.array([1,0,1,0,1,1]) | |||
| print(_label_types(y)) | |||
| y = np.array([1, 0, 1, 0, 1, 1]) | |||
| print(_label_types(y)) | |||
| @@ -1,5 +1,3 @@ | |||
| ''' | |||
| """ | |||
| use optimizer from Pytorch | |||
| ''' | |||
| from torch.optim import * | |||
| """ | |||
| @@ -7,9 +7,17 @@ from fastNLP.loader.preprocess import load_pickle, DEFAULT_UNKNOWN_LABEL | |||
| from fastNLP.modules import utils | |||
| def make_batch(iterator, data, use_cuda, output_length=False, max_len=None, min_len=None): | |||
| for indices in iterator: | |||
| batch_x = [data[idx] for idx in indices] | |||
| def make_batch(iterator, use_cuda, output_length=False, max_len=None, min_len=None): | |||
| """Batch and Pad data, only for Inference. | |||
| :param iterator: An iterable object that returns a list of indices representing a mini-batch of samples. | |||
| :param use_cuda: bool, whether to use GPU | |||
| :param output_length: bool, whether to output the original length of the sequence before padding. (default: False) | |||
| :param max_len: int, maximum sequence length. Longer sequences will be clipped. (default: None) | |||
| :param min_len: int, minimum sequence length. Shorter sequences will be padded. (default: None) | |||
| :return: | |||
| """ | |||
| for batch_x in iterator: | |||
| batch_x = pad(batch_x) | |||
| # convert list to tensor | |||
| batch_x = convert_to_torch_tensor(batch_x, use_cuda) | |||
| @@ -29,11 +37,11 @@ def make_batch(iterator, data, use_cuda, output_length=False, max_len=None, min_ | |||
| def pad(batch, fill=0): | |||
| """ | |||
| Pad a batch of samples to maximum length. | |||
| """ Pad a mini-batch of sequence samples to maximum length of this batch. | |||
| :param batch: list of list | |||
| :param fill: word index to pad, default 0. | |||
| :return: a padded batch | |||
| :return batch: a padded mini-batch | |||
| """ | |||
| max_length = max([len(x) for x in batch]) | |||
| for idx, sample in enumerate(batch): | |||
| @@ -42,13 +50,13 @@ def pad(batch, fill=0): | |||
| return batch | |||
| class Inference(object): | |||
| """ | |||
| This is an interface focusing on predicting output based on trained models. | |||
| class Predictor(object): | |||
| """An interface for predicting outputs based on trained models. | |||
| It does not care about evaluations of the model, which is different from Tester. | |||
| This is a high-level model wrapper to be called by FastNLP. | |||
| This class does not share any operations with Trainer and Tester. | |||
| Currently, Inference does not support GPU. | |||
| Currently, Predictor does not support GPU. | |||
| """ | |||
| def __init__(self, pickle_path): | |||
| @@ -60,11 +68,11 @@ class Inference(object): | |||
| self.word2index = load_pickle(self.pickle_path, "word2id.pkl") | |||
| def predict(self, network, data): | |||
| """ | |||
| Perform inference. | |||
| :param network: | |||
| :param data: two-level lists of strings | |||
| :return result: the model outputs | |||
| """Perform inference using the trained model. | |||
| :param network: a PyTorch model | |||
| :param data: list of list of strings | |||
| :return: list of list of strings, [num_examples, tag_seq_length] | |||
| """ | |||
| # transform strings into indices | |||
| data = self.prepare_input(data) | |||
| @@ -73,9 +81,9 @@ class Inference(object): | |||
| self.mode(network, test=True) | |||
| self.batch_output.clear() | |||
| iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False)) | |||
| data_iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False)) | |||
| for batch_x in self.make_batch(iterator, data, use_cuda=False): | |||
| for batch_x in self.make_batch(data_iterator, use_cuda=False): | |||
| prediction = self.data_forward(network, batch_x) | |||
| @@ -90,20 +98,22 @@ class Inference(object): | |||
| network.train() | |||
| def data_forward(self, network, x): | |||
| """Forward through network.""" | |||
| raise NotImplementedError | |||
| def make_batch(self, iterator, data, use_cuda): | |||
| def make_batch(self, iterator, use_cuda): | |||
| raise NotImplementedError | |||
| def prepare_input(self, data): | |||
| """ | |||
| Transform two-level list of strings into that of index. | |||
| """Transform two-level list of strings into that of index. | |||
| :param data: | |||
| [ | |||
| [word_11, word_12, ...], | |||
| [word_21, word_22, ...], | |||
| ... | |||
| ] | |||
| [ | |||
| [word_11, word_12, ...], | |||
| [word_21, word_22, ...], | |||
| ... | |||
| ] | |||
| :return data_index: list of list of int. | |||
| """ | |||
| assert isinstance(data, list) | |||
| data_index = [] | |||
| @@ -113,10 +123,11 @@ class Inference(object): | |||
| return data_index | |||
| def prepare_output(self, data): | |||
| """Transform list of batch outputs into strings.""" | |||
| raise NotImplementedError | |||
| class SeqLabelInfer(Inference): | |||
| class SeqLabelInfer(Predictor): | |||
| """ | |||
| Inference on sequence labeling models. | |||
| """ | |||
| @@ -127,12 +138,15 @@ class SeqLabelInfer(Inference): | |||
| def data_forward(self, network, inputs): | |||
| """ | |||
| This is only for sequence labeling with CRF decoder. | |||
| :param network: | |||
| :param inputs: | |||
| :return: Tensor | |||
| :param network: a PyTorch model | |||
| :param inputs: tuple of (x, seq_len) | |||
| x: Tensor of shape [batch_size, max_len], where max_len is the maximum length of the mini-batch | |||
| after padding. | |||
| seq_len: list of int, the lengths of sequences before padding. | |||
| :return prediction: Tensor of shape [batch_size, max_len] | |||
| """ | |||
| if not isinstance(inputs[1], list) and isinstance(inputs[0], list): | |||
| raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.") | |||
| raise RuntimeError("output_length must be true for sequence modeling.") | |||
| # unpack the returned value from make_batch | |||
| x, seq_len = inputs[0], inputs[1] | |||
| batch_size, max_len = x.size(0), x.size(1) | |||
| @@ -142,14 +156,14 @@ class SeqLabelInfer(Inference): | |||
| prediction = network.prediction(y, mask) | |||
| return torch.Tensor(prediction) | |||
| def make_batch(self, iterator, data, use_cuda): | |||
| return make_batch(iterator, data, use_cuda, output_length=True) | |||
| def make_batch(self, iterator, use_cuda): | |||
| return make_batch(iterator, use_cuda, output_length=True) | |||
| def prepare_output(self, batch_outputs): | |||
| """ | |||
| Transform list of batch outputs into strings. | |||
| :param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, tag_seq_length]. | |||
| :return results: 2-D list of strings | |||
| """Transform list of batch outputs into strings. | |||
| :param batch_outputs: list of 2-D Tensor, shape [num_batch, batch-size, tag_seq_length]. | |||
| :return results: 2-D list of strings, shape [num_examples, tag_seq_length] | |||
| """ | |||
| results = [] | |||
| for batch in batch_outputs: | |||
| @@ -158,7 +172,7 @@ class SeqLabelInfer(Inference): | |||
| return results | |||
| class ClassificationInfer(Inference): | |||
| class ClassificationInfer(Predictor): | |||
| """ | |||
| Inference on Classification models. | |||
| """ | |||
| @@ -171,8 +185,8 @@ class ClassificationInfer(Inference): | |||
| logits = network(x) | |||
| return logits | |||
| def make_batch(self, iterator, data, use_cuda): | |||
| return make_batch(iterator, data, use_cuda, output_length=False, min_len=5) | |||
| def make_batch(self, iterator, use_cuda): | |||
| return make_batch(iterator, use_cuda, output_length=False, min_len=5) | |||
| def prepare_output(self, batch_outputs): | |||
| """ | |||
| @@ -9,7 +9,7 @@ from fastNLP.modules import utils | |||
| class BaseTester(object): | |||
| """docstring for Tester""" | |||
| """An collection of model inference and evaluation of performance, used over validation/dev set and test set. """ | |||
| def __init__(self, test_args): | |||
| """ | |||
| @@ -62,8 +62,8 @@ class BaseTester(object): | |||
| step += 1 | |||
| def prepare_input(self, data_path): | |||
| """ | |||
| Save the dev data once it is loaded. Can return directly next time. | |||
| """Save the dev data once it is loaded. Can return directly next time. | |||
| :param data_path: str, the path to the pickle data for dev | |||
| :return save_dev_data: list. Each entry is a sample, which is also a list of features and label(s). | |||
| """ | |||
| @@ -73,21 +73,29 @@ class BaseTester(object): | |||
| return self.save_dev_data | |||
| def mode(self, model, test): | |||
| """Train mode or Test mode. This is for PyTorch currently. | |||
| :param model: a PyTorch model | |||
| :param test: bool, whether in test mode. | |||
| """ | |||
| Action.mode(model, test) | |||
| def data_forward(self, network, x): | |||
| """A forward pass of the model. """ | |||
| raise NotImplementedError | |||
| def evaluate(self, predict, truth): | |||
| """Compute evaluation metrics for the model. """ | |||
| raise NotImplementedError | |||
| @property | |||
| def metrics(self): | |||
| """Return a list of metrics. """ | |||
| raise NotImplementedError | |||
| def show_matrices(self): | |||
| """ | |||
| This is called by Trainer to print evaluation on dev set. | |||
| """This is called by Trainer to print evaluation results on dev set during training. | |||
| :return print_str: str | |||
| """ | |||
| raise NotImplementedError | |||
| @@ -112,8 +120,17 @@ class SeqLabelTester(BaseTester): | |||
| self.batch_result = None | |||
| def data_forward(self, network, inputs): | |||
| """This is only for sequence labeling with CRF decoder. | |||
| :param network: a PyTorch model | |||
| :param inputs: tuple of (x, seq_len) | |||
| x: Tensor of shape [batch_size, max_len], where max_len is the maximum length of the mini-batch | |||
| after padding. | |||
| seq_len: list of int, the lengths of sequences before padding. | |||
| :return y: Tensor of shape [batch_size, max_len] | |||
| """ | |||
| if not isinstance(inputs, tuple): | |||
| raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.") | |||
| raise RuntimeError("output_length must be true for sequence modeling.") | |||
| # unpack the returned value from make_batch | |||
| x, seq_len = inputs[0], inputs[1] | |||
| batch_size, max_len = x.size(0), x.size(1) | |||
| @@ -127,6 +144,12 @@ class SeqLabelTester(BaseTester): | |||
| return y | |||
| def evaluate(self, predict, truth): | |||
| """Compute metrics (or loss). | |||
| :param predict: Tensor, [batch_size, max_len, tag_size] | |||
| :param truth: Tensor, [batch_size, max_len] | |||
| :return: | |||
| """ | |||
| batch_size, max_len = predict.size(0), predict.size(1) | |||
| loss = self.model.loss(predict, truth, self.mask) / batch_size | |||
| @@ -151,7 +174,7 @@ class SeqLabelTester(BaseTester): | |||
| return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy) | |||
| def make_batch(self, iterator, data): | |||
| return Action.make_batch(iterator, data, use_cuda=self.use_cuda, output_length=True) | |||
| return Action.make_batch(iterator, use_cuda=self.use_cuda, output_length=True) | |||
| class ClassificationTester(BaseTester): | |||
| @@ -171,7 +194,7 @@ class ClassificationTester(BaseTester): | |||
| self.iterator = None | |||
| def make_batch(self, iterator, data, max_len=None): | |||
| return Action.make_batch(iterator, data, use_cuda=self.use_cuda, max_len=max_len) | |||
| return Action.make_batch(iterator, use_cuda=self.use_cuda, max_len=max_len) | |||
| def data_forward(self, network, x): | |||
| """Forward through network.""" | |||
| @@ -1,5 +1,6 @@ | |||
| import _pickle | |||
| import os | |||
| import time | |||
| from datetime import timedelta | |||
| from time import time | |||
| @@ -13,10 +14,11 @@ from fastNLP.core.tester import SeqLabelTester, ClassificationTester | |||
| from fastNLP.modules import utils | |||
| from fastNLP.saver.model_saver import ModelSaver | |||
| DEFAULT_QUEUE_SIZE = 300 | |||
| class BaseTrainer(object): | |||
| """Base trainer for all trainers. | |||
| Trainer receives a model and data, and then performs training. | |||
| """Operations to train a model, including data loading, SGD, and validation. | |||
| Subclasses must implement the following abstract methods: | |||
| - define_optimizer | |||
| @@ -70,7 +72,7 @@ class BaseTrainer(object): | |||
| else: | |||
| self.model = network | |||
| data_train, data_dev, data_test, embedding = self.prepare_input(self.pickle_path) | |||
| data_train = self.load_train_data(self.pickle_path) | |||
| # define tester over dev data | |||
| if self.validate: | |||
| @@ -82,33 +84,19 @@ class BaseTrainer(object): | |||
| self.define_optimizer() | |||
| # main training epochs | |||
| start = time() | |||
| start = time.time() | |||
| n_samples = len(data_train) | |||
| n_batches = n_samples // self.batch_size | |||
| n_print = 1 | |||
| for epoch in range(1, self.n_epochs + 1): | |||
| # turn on network training mode; prepare batch iterator | |||
| # turn on network training mode | |||
| self.mode(network, test=False) | |||
| iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=False)) | |||
| # training iterations in one epoch | |||
| step = 0 | |||
| for batch_x, batch_y in self.make_batch(iterator, data_train): | |||
| prediction = self.data_forward(network, batch_x) | |||
| # prepare mini-batch iterator | |||
| data_iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=False)) | |||
| loss = self.get_loss(prediction, batch_y) | |||
| self.grad_backward(loss) | |||
| self.update() | |||
| if step % n_print == 0: | |||
| end = time() | |||
| diff = timedelta(seconds=round(end - start)) | |||
| print("[epoch: {:>3} step: {:>4}] train loss: {:>4.2} time: {}".format( | |||
| epoch, step, loss.data, diff)) | |||
| step += 1 | |||
| self._train_step(data_iterator, network, start=start, n_print=n_print, epoch=epoch) | |||
| if self.validate: | |||
| validator.test(network) | |||
| @@ -120,27 +108,39 @@ class BaseTrainer(object): | |||
| print("[epoch {}]".format(epoch), end=" ") | |||
| print(validator.show_matrices()) | |||
| def prepare_input(self, pickle_path): | |||
| def _train_step(self, data_iterator, network, **kwargs): | |||
| """Training process in one epoch.""" | |||
| step = 0 | |||
| for batch_x, batch_y in self.make_batch(data_iterator): | |||
| prediction = self.data_forward(network, batch_x) | |||
| loss = self.get_loss(prediction, batch_y) | |||
| self.grad_backward(loss) | |||
| self.update() | |||
| if step % kwargs["n_print"] == 0: | |||
| end = time.time() | |||
| diff = timedelta(seconds=round(end - kwargs["start"])) | |||
| print("[epoch: {:>3} step: {:>4}] train loss: {:>4.2} time: {}".format( | |||
| kwargs["epoch"], step, loss.data, diff)) | |||
| step += 1 | |||
| def load_train_data(self, pickle_path): | |||
| """ | |||
| For task-specific processing. | |||
| :param pickle_path: | |||
| :return data_train, data_dev, data_test, embedding: | |||
| :return data_train | |||
| """ | |||
| names = [ | |||
| "data_train.pkl", "data_dev.pkl", | |||
| "data_test.pkl", "embedding.pkl"] | |||
| files = [] | |||
| for name in names: | |||
| file_path = os.path.join(pickle_path, name) | |||
| if os.path.exists(file_path): | |||
| with open(file_path, 'rb') as f: | |||
| data = _pickle.load(f) | |||
| else: | |||
| data = [] | |||
| files.append(data) | |||
| return tuple(files) | |||
| file_path = os.path.join(pickle_path, "data_train.pkl") | |||
| if os.path.exists(file_path): | |||
| with open(file_path, 'rb') as f: | |||
| data = _pickle.load(f) | |||
| else: | |||
| raise RuntimeError("cannot find training data {}".format(file_path)) | |||
| return data | |||
| def make_batch(self, iterator, data): | |||
| def make_batch(self, iterator): | |||
| raise NotImplementedError | |||
| def mode(self, network, test): | |||
| @@ -219,7 +219,7 @@ class ToyTrainer(BaseTrainer): | |||
| def __init__(self, training_args): | |||
| super(ToyTrainer, self).__init__(training_args) | |||
| def prepare_input(self, data_path): | |||
| def load_train_data(self, data_path): | |||
| data_train = _pickle.load(open(data_path + "/data_train.pkl", "rb")) | |||
| data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb")) | |||
| return data_train, data_dev, 0, 1 | |||
| @@ -267,7 +267,7 @@ class SeqLabelTrainer(BaseTrainer): | |||
| def data_forward(self, network, inputs): | |||
| if not isinstance(inputs, tuple): | |||
| raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.") | |||
| raise RuntimeError("output_length must be true for sequence modeling. Receive {}".format(type(inputs[0]))) | |||
| # unpack the returned value from make_batch | |||
| x, seq_len = inputs[0], inputs[1] | |||
| @@ -303,8 +303,8 @@ class SeqLabelTrainer(BaseTrainer): | |||
| else: | |||
| return False | |||
| def make_batch(self, iterator, data): | |||
| return Action.make_batch(iterator, data, output_length=True, use_cuda=self.use_cuda) | |||
| def make_batch(self, iterator): | |||
| return Action.make_batch(iterator, output_length=True, use_cuda=self.use_cuda) | |||
| def _create_validator(self, valid_args): | |||
| return SeqLabelTester(valid_args) | |||
| @@ -349,8 +349,8 @@ class ClassificationTrainer(BaseTrainer): | |||
| """Apply gradient.""" | |||
| self.optimizer.step() | |||
| def make_batch(self, iterator, data): | |||
| return Action.make_batch(iterator, data, output_length=False, use_cuda=self.use_cuda) | |||
| def make_batch(self, iterator): | |||
| return Action.make_batch(iterator, output_length=False, use_cuda=self.use_cuda) | |||
| def get_acc(self, y_logit, y_true): | |||
| """Compute accuracy.""" | |||
| @@ -1,4 +1,4 @@ | |||
| from fastNLP.core.inference import SeqLabelInfer, ClassificationInfer | |||
| from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer | |||
| from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||
| from fastNLP.loader.model_loader import ModelLoader | |||
| @@ -91,6 +91,9 @@ class ConfigSection(object): | |||
| (key, str(type(getattr(self, key))), str(type(value)))) | |||
| setattr(self, key, value) | |||
| def __contains__(self, item): | |||
| return item in self.__dict__.keys() | |||
| if __name__ == "__main__": | |||
| config = ConfigLoader('configLoader', 'there is no data') | |||
| @@ -1,4 +1,4 @@ | |||
| from loader.base_loader import BaseLoader | |||
| from fastNLP.loader.base_loader import BaseLoader | |||
| class EmbedLoader(BaseLoader): | |||
| @@ -1,3 +1,9 @@ | |||
| from collections import defaultdict | |||
| import numpy as np | |||
| import torch | |||
| def mask_softmax(matrix, mask): | |||
| if mask is None: | |||
| result = torch.nn.functional.softmax(matrix, dim=-1) | |||
| @@ -15,10 +21,6 @@ def seq_mask(seq_len, max_len): | |||
| """ | |||
| Codes from FudanParser. Not tested. Do not use !!! | |||
| """ | |||
| from collections import defaultdict | |||
| import numpy as np | |||
| import torch | |||
| def expand_gt(gt): | |||
| @@ -10,7 +10,7 @@ from fastNLP.saver.model_saver import ModelSaver | |||
| from fastNLP.loader.model_loader import ModelLoader | |||
| from fastNLP.core.tester import SeqLabelTester | |||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||
| from fastNLP.core.inference import Inference | |||
| from fastNLP.core.predictor import Predictor | |||
| data_name = "pku_training.utf8" | |||
| cws_data_path = "/home/zyfeng/data/pku_training.utf8" | |||
| @@ -41,7 +41,7 @@ def infer(): | |||
| infer_data = raw_data_loader.load_lines() | |||
| # Inference interface | |||
| infer = Inference(pickle_path) | |||
| infer = Predictor(pickle_path) | |||
| results = infer.predict(model, infer_data) | |||
| print(results) | |||
| @@ -1 +1,3 @@ | |||
| import fastNLP | |||
| __all__ = ["fastNLP"] | |||
| @@ -3,7 +3,7 @@ import os | |||
| import torch | |||
| from fastNLP.core.inference import SeqLabelInfer | |||
| from fastNLP.core.predictor import SeqLabelInfer | |||
| from fastNLP.core.trainer import SeqLabelTrainer | |||
| from fastNLP.loader.model_loader import ModelLoader | |||
| from fastNLP.models.sequence_modeling import AdvSeqLabel | |||
| @@ -10,7 +10,7 @@ from fastNLP.saver.model_saver import ModelSaver | |||
| from fastNLP.loader.model_loader import ModelLoader | |||
| from fastNLP.core.tester import SeqLabelTester | |||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||
| from fastNLP.core.inference import SeqLabelInfer | |||
| from fastNLP.core.predictor import SeqLabelInfer | |||
| data_name = "people.txt" | |||
| data_path = "data_for_tests/people.txt" | |||
| @@ -112,5 +112,5 @@ def train_and_test(): | |||
| if __name__ == "__main__": | |||
| train_and_test() | |||
| # infer() | |||
| # train_and_test() | |||
| infer() | |||
| @@ -10,7 +10,7 @@ from fastNLP.saver.model_saver import ModelSaver | |||
| from fastNLP.loader.model_loader import ModelLoader | |||
| from fastNLP.core.tester import SeqLabelTester | |||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||
| from fastNLP.core.inference import Inference | |||
| from fastNLP.core.predictor import Predictor | |||
| data_name = "pku_training.utf8" | |||
| # cws_data_path = "/home/zyfeng/Desktop/data/pku_training.utf8" | |||
| @@ -51,7 +51,7 @@ def infer(): | |||
| """ | |||
| # Inference interface | |||
| infer = Inference(pickle_path) | |||
| infer = Predictor(pickle_path) | |||
| results = infer.predict(model, infer_data) | |||
| print(results) | |||
| @@ -2,8 +2,10 @@ | |||
| # encoding: utf-8 | |||
| import os | |||
| import sys | |||
| from fastNLP.core.inference import ClassificationInfer | |||
| sys.path.append("..") | |||
| from fastNLP.core.predictor import ClassificationInfer | |||
| from fastNLP.core.trainer import ClassificationTrainer | |||
| from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||
| from fastNLP.loader.dataset_loader import ClassDatasetLoader | |||