- rename Inference to Predictor
- rename Trainer.prepare_input to Trainer.load_train_data, which loads data_train.pkl only
- add a __contains__ method to the config ConfigSection class
- more code comments
- more elegant make_batch & data_iterator: Samplers return batch samples instead of batch indices

(tag: v0.1.0)
@@ -1,5 +0,0 @@
-'''
-use optimizer from Pytorch
-'''
-
-from torch.optim import *
@@ -10,7 +10,7 @@ import torch


 class Action(object):
     """
-    Operations shared by Trainer, Tester, and Inference.
+    Operations shared by Trainer, Tester, or Inference.
     This is designed for reducing duplicated code.
     - make_batch: produce a mini-batch of data. @staticmethod
     - pad: padding method used in sequence modeling. @staticmethod
@@ -22,28 +22,24 @@ class Action(object):
         super(Action, self).__init__()

     @staticmethod
-    def make_batch(iterator, data, use_cuda, output_length=True, max_len=None):
+    def make_batch(iterator, use_cuda, output_length=True, max_len=None):
         """Batch and Pad data.

         :param iterator: an iterator (an object that implements the __next__ method), which returns the next sample.
-        :param data: list. Each entry is a sample, which is also a list of features and label(s).
-            E.g.
-                [
-                    [[word_11, word_12, word_13], [label_11, label_12]],  # sample 1
-                    [[word_21, word_22, word_23], [label_21, label_22]],  # sample 2
-                    ...
-                ]
-        :param use_cuda: bool
-        :param output_length: whether to output the original length of the sequence before padding.
-        :param max_len: int, maximum sequence length
-        :return (batch_x, seq_len): tuple of two elements, if output_length is true.
+        :param use_cuda: bool, whether to use GPU
+        :param output_length: bool, whether to output the original length of the sequence before padding. (default: True)
+        :param max_len: int, maximum sequence length. Longer sequences will be clipped. (default: None)
+        :return:
+            if output_length is True:
+                (batch_x, seq_len): tuple of two elements
                 batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
                 seq_len: list. The length of the pre-padded sequence, if output_length is True.
-            batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
-            return batch_x and batch_y, if output_length is False
+                batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
+            if output_length is False:
+                batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
+                batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
         """
-        for indices in iterator:
-            batch = [data[idx] for idx in indices]
+        for batch in iterator:
             batch_x = [sample[0] for sample in batch]
             batch_y = [sample[1] for sample in batch]
@@ -68,11 +64,11 @@ class Action(object):

     @staticmethod
     def pad(batch, fill=0):
-        """
-        Pad a batch of samples to maximum length of this batch.
+        """Pad a mini-batch of sequence samples to the maximum length of this batch.

         :param batch: list of list
         :param fill: word index to pad, default 0.
-        :return: a padded batch
+        :return batch: a padded mini-batch
         """
         max_length = max([len(x) for x in batch])
         for idx, sample in enumerate(batch):
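
The padding loop body is truncated in this hunk; assuming pad right-pads each sample with `fill` up to the longest sample in the batch (which is what the docstring implies), usage would look like:

    batch = [[4, 5, 6], [7]]
    padded = Action.pad(batch, fill=0)
    # expected: [[4, 5, 6], [7, 0, 0]]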
@@ -95,11 +91,10 @@ class Action(object):
     def convert_to_torch_tensor(data_list, use_cuda):
         """
-        convert lists into (cuda) Tensors
+        convert lists into (cuda) Tensors.

         :param data_list: 2-level lists
-        :param use_cuda: bool
-        :param reqired_grad: bool
-        :return: PyTorch Tensor of shape [batch_size, max_seq_len]
+        :param use_cuda: bool, whether to use GPU or not
+        :return data_list: PyTorch Tensor of shape [batch_size, max_seq_len]
         """
         data_list = torch.Tensor(data_list).long()
         if torch.cuda.is_available() and use_cuda:
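
From the body shown, the helper always casts to a long tensor and moves it to GPU only when CUDA is available and use_cuda is set:

    x = Action.convert_to_torch_tensor([[1, 2], [3, 4]], use_cuda=False)
    # x is a torch.LongTensor of shape [2, 2], on CPU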
@@ -171,6 +166,7 @@ class BaseSampler(object):

     def __init__(self, data_set):
         self.data_set_length = len(data_set)
+        self.data = data_set

     def __len__(self):
         return self.data_set_length

@@ -188,7 +184,7 @@ class SequentialSampler(BaseSampler):
         super(SequentialSampler, self).__init__(data_set)

     def __iter__(self):
-        return iter(range(self.data_set_length))
+        return iter(self.data)


 class RandomSampler(BaseSampler):
@@ -198,28 +194,10 @@ class RandomSampler(BaseSampler):

     def __init__(self, data_set):
         super(RandomSampler, self).__init__(data_set)
+        self.order = np.random.permutation(self.data_set_length)

     def __iter__(self):
-        return iter(np.random.permutation(self.data_set_length))
-
-
-class BucketSampler(BaseSampler):
-    """
-    Partition all samples into multiple buckets, each of which contains sentences of approximately the same length.
-    In sampling, first random choose a bucket. Then sample data from it.
-    The number of buckets is decided dynamically by the variance of sentence lengths.
-    """
-
-    def __init__(self, data_set):
-        super(BucketSampler, self).__init__(data_set)
-        BUCKETS = ([None] * 20)
-        self.length_freq = dict(Counter([len(example) for example in data_set]))
-        self.buckets = k_means_bucketing(data_set, BUCKETS)
-
-    def __iter__(self):
-        bucket_samples = self.buckets[np.random.randint(0, len(self.buckets))]
-        np.random.shuffle(bucket_samples)
-        return iter(bucket_samples)
+        return iter((self.data[idx] for idx in self.order))


 class Batchifier(object):
@@ -235,10 +213,53 @@ class Batchifier(object):

     def __iter__(self):
         batch = []
-        for idx in self.sampler:
-            batch.append(idx)
+        for example in self.sampler:
+            batch.append(example)
             if len(batch) == self.batch_size:
                 yield batch
                 batch = []
         if 0 < len(batch) < self.batch_size and self.drop_last is False:
             yield batch
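
With these changes, samplers yield samples and Batchifier groups them into lists, so make_batch no longer needs the raw data set. A minimal sketch of the resulting pipeline (the module path and the exact yield shape of make_batch are assumptions, inferred from how Trainer consumes it below):

    from fastNLP.core.action import Action, Batchifier, RandomSampler  # assumed path

    data = [[[1, 2, 3], [0]], [[4, 5], [1]], [[6], [0]]]  # each sample: [word indices, label(s)]
    data_iterator = iter(Batchifier(RandomSampler(data), batch_size=2, drop_last=False))
    for batch_x, batch_y in Action.make_batch(data_iterator, use_cuda=False, output_length=True):
        x, seq_len = batch_x  # padded LongTensor [batch_size, max_len], plus original lengths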
+
+
+class BucketBatchifier(Batchifier):
+    """
+    Partition all samples into multiple buckets, each of which contains sentences of approximately the same length.
+    In sampling, first randomly choose a bucket, then sample data from it.
+    The number of buckets is decided dynamically by the variance of sentence lengths.
+    """
+
+    def __init__(self, data_set, batch_size, num_buckets, drop_last=True, sampler=None):
+        """
+        :param data_set: three-level list, shape [num_samples, 2]
+        :param batch_size: int
+        :param num_buckets: int, number of buckets for grouping these sequences.
+        :param drop_last: bool, useless currently.
+        :param sampler: Sampler, useless currently.
+        """
+        super(BucketBatchifier, self).__init__(sampler, batch_size, drop_last)
+        buckets = ([None] * num_buckets)
+        self.data = data_set
+        self.batch_size = batch_size
+        self.length_freq = dict(Counter([len(example) for example in data_set]))
+        self.buckets = k_means_bucketing(data_set, buckets)
+
+    def __iter__(self):
+        """Make a mini-batch of data."""
+        for _ in range(len(self.data) // self.batch_size):
+            bucket_samples = self.buckets[np.random.randint(0, len(self.buckets))]
+            np.random.shuffle(bucket_samples)
+            yield [self.data[idx] for idx in bucket_samples[:self.batch_size]]
+
+
+if __name__ == "__main__":
+    import random
+
+    data = [[[y] * random.randint(0, 50), [y]] for y in range(500)]
+    batch_size = 8
+    iterator = iter(BucketBatchifier(data, batch_size, num_buckets=5))
+    for d in iterator:
+        print("\nbatch:")
+        for dd in d:
+            print(len(dd[0]), end=" ")
@@ -1,62 +1,55 @@
-"""
-To do:
-Design various metrics for evaluating prediction results. Use numpy for anything vector-related.
-Reference: http://scikit-learn.org/stable/modules/classes.html#classification-metrics
-Suggestion: implement each metric as a function (called by Tester's evaluate method).
-Only basic parameters need to be supported; there is no need for as many configuration options as sklearn has.
-support numpy array and torch tensor
-"""
+import warnings
+
 import numpy as np
 import torch
-import sklearn.metrics as M
-import warnings


 def _conver_numpy(x):
-    '''
-    converte input data to numpy array
-    '''
+    """
+    convert input data to numpy array
+    """
     if isinstance(x, np.ndarray):
         return x
     elif isinstance(x, torch.Tensor):
         return x.numpy()
     elif isinstance(x, list):
         return np.array(x)
-    raise TypeError('cannot accept obejct: {}'.format(x))
+    raise TypeError('cannot accept object: {}'.format(x))
 def _check_same_len(*arrays, axis=0):
-    '''
+    """
     check if input array list has same length for one dimension
-    '''
+    """
     lens = set([x.shape[axis] for x in arrays if x is not None])
     return len(lens) == 1


 def _label_types(y):
-    '''
+    """
     determine the type
     "binary"
     "multiclass"
     "multiclass-multioutput"
     "multilabel"
     "unknown"
-    '''
+    """
     # never squeeze the first dimension
     y = np.squeeze(y, list(range(1, len(y.shape))))
     shape = y.shape
     if len(shape) < 1:
         raise ValueError('cannot accept data: {}'.format(y))
     if len(shape) == 1:
         return 'multiclass' if np.unique(y).shape[0] > 2 else 'binary', y
     if len(shape) == 2:
         return 'multiclass-multioutput' if np.unique(y).shape[0] > 2 else 'multilabel', y
     return 'unknown', y
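
Concretely, for 1-D labels (the case exercised in the __main__ block at the end of this file): more than two distinct values means 'multiclass', otherwise 'binary'; trailing singleton dimensions are squeezed away first, while the first dimension is always kept:

    _label_types(np.array([1, 0, 1, 0, 1, 1]))  # ('binary', array([1, 0, 1, 0, 1, 1]))
    _label_types(np.array([0, 1, 2, 1]))        # ('multiclass', array([0, 1, 2, 1]))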
 def _check_data(y_true, y_pred):
-    '''
+    """
     check if y_true and y_pred are the same type of data, e.g. both binary or multiclass
-    '''
+    """
     y_true, y_pred = _conver_numpy(y_true), _conver_numpy(y_pred)
     if not _check_same_len(y_true, y_pred):
         raise ValueError('cannot accept data with different shape {0}, {1}'.format(y_true, y_pred))
@@ -70,9 +63,9 @@ def _check_data(y_true, y_pred):
     type_set = set(['multiclass-multioutput', 'multilabel'])
     if type_true in type_set and type_pred in type_set:
         return type_true if type_true == type_pred else 'multiclass-multioutput', y_true, y_pred

     raise ValueError('cannot accept data mixed of {0} and {1} target'.format(type_true, type_pred))


 def _weight_sum(y, normalize=True, sample_weight=None):
     if normalize:
@@ -119,7 +112,7 @@ def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
         pos_list = [y_true == i for i in labels]
         pos_sum_list = [pos_i.sum() for pos_i in pos_list]
         return np.array([np.logical_and(y_pred_right, pos_i).sum() / sum_i if sum_i > 0 else 0 \
-                for pos_i, sum_i in zip(pos_list, pos_sum_list)])
+                         for pos_i, sum_i in zip(pos_list, pos_sum_list)])
     elif y_type == 'multilabel':
         y_pred_right = y_true == y_pred
         pos = (y_true == pos_label)

@@ -130,6 +123,7 @@ def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
         raise ValueError('not support targets type {}'.format(y_type))
     raise ValueError('not support for average type {}'.format(average))


 def precision_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
     y_type, y_true, y_pred = _check_data(y_true, y_pred)
     if average == 'binary':

@@ -154,7 +148,7 @@ def precision_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
         pos_list = [y_true == i for i in labels]
         pos_sum_list = [(y_pred == i).sum() for i in labels]
         return np.array([np.logical_and(y_pred_right, pos_i).sum() / sum_i if sum_i > 0 else 0 \
-                for pos_i, sum_i in zip(pos_list, pos_sum_list)])
+                         for pos_i, sum_i in zip(pos_list, pos_sum_list)])
     elif y_type == 'multilabel':
         y_pred_right = y_true == y_pred
         pos = (y_true == pos_label)
@@ -165,6 +159,7 @@ def precision_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
         raise ValueError('not support targets type {}'.format(y_type))
     raise ValueError('not support for average type {}'.format(average))


 def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
     precision = precision_score(y_true, y_pred, labels=labels, pos_label=pos_label, average=average)
     recall = recall_score(y_true, y_pred, labels=labels, pos_label=pos_label, average=average)

@@ -178,6 +173,7 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
 def classification_report(y_true, y_pred, labels=None, target_names=None, digits=2):
     raise NotImplementedError


 if __name__ == '__main__':
-    y = np.array([1,0,1,0,1,1])
-    print(_label_types(y))
+    y = np.array([1, 0, 1, 0, 1, 1])
+    print(_label_types(y))
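
Since f1_score is built from the precision_score and recall_score above, a quick sanity check for the binary case (assuming the usual F1 = 2PR / (P + R) combination, whose computation is elided from this hunk):

    y_true = np.array([1, 0, 1])
    y_pred = np.array([1, 1, 1])
    # precision = 2/3, recall = 1.0, so F1 = 2 * (2/3 * 1.0) / (2/3 + 1.0) = 0.8
    print(f1_score(y_true, y_pred, average='binary'))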
@@ -1,5 +1,3 @@
-'''
+"""
 use optimizer from Pytorch
-'''
-from torch.optim import *
+"""
@@ -7,9 +7,17 @@ from fastNLP.loader.preprocess import load_pickle, DEFAULT_UNKNOWN_LABEL
 from fastNLP.modules import utils


-def make_batch(iterator, data, use_cuda, output_length=False, max_len=None, min_len=None):
-    for indices in iterator:
-        batch_x = [data[idx] for idx in indices]
+def make_batch(iterator, use_cuda, output_length=False, max_len=None, min_len=None):
+    """Batch and Pad data, only for Inference.
+
+    :param iterator: An iterable object that yields one mini-batch of samples at a time.
+    :param use_cuda: bool, whether to use GPU
+    :param output_length: bool, whether to output the original length of the sequence before padding. (default: False)
+    :param max_len: int, maximum sequence length. Longer sequences will be clipped. (default: None)
+    :param min_len: int, minimum sequence length. Shorter sequences will be padded. (default: None)
+    :return: a generator over padded mini-batches; yields (batch_x, seq_len) if output_length is True, else batch_x.
+    """
+    for batch_x in iterator:
         batch_x = pad(batch_x)
         # convert list to tensor
         batch_x = convert_to_torch_tensor(batch_x, use_cuda)
@@ -29,11 +37,11 @@ def make_batch(iterator, data, use_cuda, output_length=False, max_len=None, min_len=None):


 def pad(batch, fill=0):
-    """
-    Pad a batch of samples to maximum length.
+    """Pad a mini-batch of sequence samples to the maximum length of this batch.

     :param batch: list of list
     :param fill: word index to pad, default 0.
-    :return: a padded batch
+    :return batch: a padded mini-batch
     """
     max_length = max([len(x) for x in batch])
     for idx, sample in enumerate(batch):
@@ -42,13 +50,13 @@ def pad(batch, fill=0):
     return batch


-class Inference(object):
-    """
-    This is an interface focusing on predicting output based on trained models.
+class Predictor(object):
+    """An interface for predicting outputs based on trained models.

     It does not care about evaluations of the model, which is different from Tester.
     This is a high-level model wrapper to be called by FastNLP.
     This class does not share any operations with Trainer and Tester.
-    Currently, Inference does not support GPU.
+    Currently, Predictor does not support GPU.
     """

     def __init__(self, pickle_path):
@@ -60,11 +68,11 @@ class Inference(object):
         self.word2index = load_pickle(self.pickle_path, "word2id.pkl")

     def predict(self, network, data):
-        """
-        Perform inference.
-        :param network:
-        :param data: two-level lists of strings
-        :return result: the model outputs
+        """Perform inference using the trained model.
+
+        :param network: a PyTorch model
+        :param data: list of list of strings
+        :return: list of list of strings, [num_examples, tag_seq_length]
         """
         # transform strings into indices
         data = self.prepare_input(data)
@@ -73,9 +81,9 @@ class Inference(object):
         self.mode(network, test=True)
         self.batch_output.clear()

-        iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False))
+        data_iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False))

-        for batch_x in self.make_batch(iterator, data, use_cuda=False):
+        for batch_x in self.make_batch(data_iterator, use_cuda=False):
             prediction = self.data_forward(network, batch_x)
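
End to end, inference now reads as below. The sketch uses SeqLabelInfer, since the base Predictor leaves data_forward and make_batch abstract, and assumes the subclass keeps the base constructor:

    from fastNLP.core.predictor import SeqLabelInfer

    pickle_path = "./pickles/"  # hypothetical: the directory containing word2id.pkl etc.
    infer = SeqLabelInfer(pickle_path)
    results = infer.predict(model, infer_data)  # model: trained network; infer_data: list of lists of words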
@@ -90,20 +98,22 @@ class Inference(object):
         network.train()

     def data_forward(self, network, x):
+        """Forward through network."""
         raise NotImplementedError

-    def make_batch(self, iterator, data, use_cuda):
+    def make_batch(self, iterator, use_cuda):
         raise NotImplementedError

     def prepare_input(self, data):
-        """
-        Transform two-level list of strings into that of index.
+        """Transform two-level list of strings into that of index.
+
         :param data:
-            [
-                [word_11, word_12, ...],
-                [word_21, word_22, ...],
-                ...
-            ]
+                [
+                    [word_11, word_12, ...],
+                    [word_21, word_22, ...],
+                    ...
+                ]
+        :return data_index: list of list of int.
         """
         assert isinstance(data, list)
         data_index = []
@@ -113,10 +123,11 @@ class Inference(object):
         return data_index

     def prepare_output(self, data):
+        """Transform list of batch outputs into strings."""
         raise NotImplementedError


-class SeqLabelInfer(Inference):
+class SeqLabelInfer(Predictor):
     """
     Inference on sequence labeling models.
     """
@@ -127,12 +138,15 @@ class SeqLabelInfer(Inference):
     def data_forward(self, network, inputs):
         """
         This is only for sequence labeling with CRF decoder.
-        :param network:
-        :param inputs:
-        :return: Tensor
+
+        :param network: a PyTorch model
+        :param inputs: tuple of (x, seq_len)
+            x: Tensor of shape [batch_size, max_len], where max_len is the maximum length of the mini-batch
+                after padding.
+            seq_len: list of int, the lengths of sequences before padding.
+        :return prediction: Tensor of shape [batch_size, max_len]
         """
         if not isinstance(inputs[1], list) and isinstance(inputs[0], list):
-            raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.")
+            raise RuntimeError("output_length must be true for sequence modeling.")
         # unpack the returned value from make_batch
         x, seq_len = inputs[0], inputs[1]
         batch_size, max_len = x.size(0), x.size(1)
@@ -142,14 +156,14 @@ class SeqLabelInfer(Inference):
         prediction = network.prediction(y, mask)
         return torch.Tensor(prediction)

-    def make_batch(self, iterator, data, use_cuda):
-        return make_batch(iterator, data, use_cuda, output_length=True)
+    def make_batch(self, iterator, use_cuda):
+        return make_batch(iterator, use_cuda, output_length=True)
     def prepare_output(self, batch_outputs):
-        """
-        Transform list of batch outputs into strings.
-        :param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch_size, tag_seq_length].
-        :return results: 2-D list of strings
+        """Transform list of batch outputs into strings.
+
+        :param batch_outputs: list of 2-D Tensor, shape [num_batch, batch_size, tag_seq_length].
+        :return results: 2-D list of strings, shape [num_examples, tag_seq_length]
         """
         results = []
         for batch in batch_outputs:

@@ -158,7 +172,7 @@ class SeqLabelInfer(Inference):
         return results

-class ClassificationInfer(Inference):
+class ClassificationInfer(Predictor):
     """
     Inference on Classification models.
     """

@@ -171,8 +185,8 @@ class ClassificationInfer(Inference):
         logits = network(x)
         return logits

-    def make_batch(self, iterator, data, use_cuda):
-        return make_batch(iterator, data, use_cuda, output_length=False, min_len=5)
+    def make_batch(self, iterator, use_cuda):
+        return make_batch(iterator, use_cuda, output_length=False, min_len=5)

     def prepare_output(self, batch_outputs):
         """
@@ -9,7 +9,7 @@ from fastNLP.modules import utils


 class BaseTester(object):
-    """docstring for Tester"""
+    """A collection of model inference and performance evaluation, used on the validation/dev set and the test set."""

     def __init__(self, test_args):
         """

@@ -62,8 +62,8 @@ class BaseTester(object):
             step += 1

     def prepare_input(self, data_path):
-        """
-        Save the dev data once it is loaded. Can return directly next time.
+        """Save the dev data once it is loaded. Can return directly next time.

         :param data_path: str, the path to the pickle data for dev
         :return save_dev_data: list. Each entry is a sample, which is also a list of features and label(s).
         """
@@ -73,21 +73,29 @@ class BaseTester(object):
         return self.save_dev_data

     def mode(self, model, test):
+        """Train mode or Test mode. This is for PyTorch currently.
+
+        :param model: a PyTorch model
+        :param test: bool, whether in test mode.
+        """
         Action.mode(model, test)

     def data_forward(self, network, x):
+        """A forward pass of the model."""
         raise NotImplementedError

     def evaluate(self, predict, truth):
+        """Compute evaluation metrics for the model."""
         raise NotImplementedError

     @property
     def metrics(self):
+        """Return a list of metrics."""
         raise NotImplementedError

     def show_matrices(self):
-        """
-        This is called by Trainer to print evaluation on dev set.
+        """This is called by Trainer to print evaluation results on dev set during training.

         :return print_str: str
         """
         raise NotImplementedError
@@ -112,8 +120,17 @@ class SeqLabelTester(BaseTester):
         self.batch_result = None

     def data_forward(self, network, inputs):
+        """This is only for sequence labeling with CRF decoder.
+
+        :param network: a PyTorch model
+        :param inputs: tuple of (x, seq_len)
+            x: Tensor of shape [batch_size, max_len], where max_len is the maximum length of the mini-batch
+                after padding.
+            seq_len: list of int, the lengths of sequences before padding.
+        :return y: Tensor of shape [batch_size, max_len]
+        """
         if not isinstance(inputs, tuple):
-            raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.")
+            raise RuntimeError("output_length must be true for sequence modeling.")
         # unpack the returned value from make_batch
         x, seq_len = inputs[0], inputs[1]
         batch_size, max_len = x.size(0), x.size(1)
@@ -127,6 +144,12 @@ class SeqLabelTester(BaseTester):
         return y

     def evaluate(self, predict, truth):
+        """Compute metrics (or loss).
+
+        :param predict: Tensor, [batch_size, max_len, tag_size]
+        :param truth: Tensor, [batch_size, max_len]
+        :return:
+        """
         batch_size, max_len = predict.size(0), predict.size(1)
         loss = self.model.loss(predict, truth, self.mask) / batch_size

@@ -151,7 +174,7 @@ class SeqLabelTester(BaseTester):
         return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy)

     def make_batch(self, iterator, data):
-        return Action.make_batch(iterator, data, use_cuda=self.use_cuda, output_length=True)
+        return Action.make_batch(iterator, use_cuda=self.use_cuda, output_length=True)

 class ClassificationTester(BaseTester):

@@ -171,7 +194,7 @@ class ClassificationTester(BaseTester):
         self.iterator = None

     def make_batch(self, iterator, data, max_len=None):
-        return Action.make_batch(iterator, data, use_cuda=self.use_cuda, max_len=max_len)
+        return Action.make_batch(iterator, use_cuda=self.use_cuda, max_len=max_len)

     def data_forward(self, network, x):
         """Forward through network."""
@@ -1,5 +1,6 @@
 import _pickle
 import os
+import time
 from datetime import timedelta
 from time import time

@@ -13,10 +14,11 @@ from fastNLP.core.tester import SeqLabelTester, ClassificationTester
 from fastNLP.modules import utils
 from fastNLP.saver.model_saver import ModelSaver

-DEFAULT_QUEUE_SIZE = 300

 class BaseTrainer(object):
-    """Base trainer for all trainers.
-    Trainer receives a model and data, and then performs training.
+    """Operations to train a model, including data loading, SGD, and validation.

     Subclasses must implement the following abstract methods:
     - define_optimizer
@@ -70,7 +72,7 @@ class BaseTrainer(object):
         else:
             self.model = network

-        data_train, data_dev, data_test, embedding = self.prepare_input(self.pickle_path)
+        data_train = self.load_train_data(self.pickle_path)

         # define tester over dev data
         if self.validate:
@@ -82,33 +84,19 @@ class BaseTrainer(object):
         self.define_optimizer()

         # main training epochs
-        start = time()
+        start = time.time()
         n_samples = len(data_train)
         n_batches = n_samples // self.batch_size
         n_print = 1

         for epoch in range(1, self.n_epochs + 1):
-            # turn on network training mode; prepare batch iterator
+            # turn on network training mode
             self.mode(network, test=False)
-            iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=False))
-
-            # training iterations in one epoch
-            step = 0
-            for batch_x, batch_y in self.make_batch(iterator, data_train):
-                prediction = self.data_forward(network, batch_x)
-                loss = self.get_loss(prediction, batch_y)
-                self.grad_backward(loss)
-                self.update()
-
-                if step % n_print == 0:
-                    end = time()
-                    diff = timedelta(seconds=round(end - start))
-                    print("[epoch: {:>3} step: {:>4}] train loss: {:>4.2} time: {}".format(
-                        epoch, step, loss.data, diff))
-                step += 1
+            # prepare mini-batch iterator
+            data_iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=False))
+
+            self._train_step(data_iterator, network, start=start, n_print=n_print, epoch=epoch)

             if self.validate:
                 validator.test(network)
@@ -120,27 +108,39 @@ class BaseTrainer(object):
                 print("[epoch {}]".format(epoch), end=" ")
                 print(validator.show_matrices())

-    def prepare_input(self, pickle_path):
+    def _train_step(self, data_iterator, network, **kwargs):
+        """Training process in one epoch."""
+        step = 0
+        for batch_x, batch_y in self.make_batch(data_iterator):
+            prediction = self.data_forward(network, batch_x)
+            loss = self.get_loss(prediction, batch_y)
+            self.grad_backward(loss)
+            self.update()
+
+            if step % kwargs["n_print"] == 0:
+                end = time.time()
+                diff = timedelta(seconds=round(end - kwargs["start"]))
+                print("[epoch: {:>3} step: {:>4}] train loss: {:>4.2} time: {}".format(
+                    kwargs["epoch"], step, loss.data, diff))
+            step += 1
+
+    def load_train_data(self, pickle_path):
         """
         For task-specific processing.
         :param pickle_path:
-        :return data_train, data_dev, data_test, embedding:
+        :return data_train:
         """
-        names = [
-            "data_train.pkl", "data_dev.pkl",
-            "data_test.pkl", "embedding.pkl"]
-        files = []
-        for name in names:
-            file_path = os.path.join(pickle_path, name)
-            if os.path.exists(file_path):
-                with open(file_path, 'rb') as f:
-                    data = _pickle.load(f)
-            else:
-                data = []
-            files.append(data)
-        return tuple(files)
+        file_path = os.path.join(pickle_path, "data_train.pkl")
+        if os.path.exists(file_path):
+            with open(file_path, 'rb') as f:
+                data = _pickle.load(f)
+        else:
+            raise RuntimeError("cannot find training data {}".format(file_path))
+        return data

-    def make_batch(self, iterator, data):
+    def make_batch(self, iterator):
         raise NotImplementedError

     def mode(self, network, test):
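
load_train_data now expects a single data_train.pkl under pickle_path. Judging by Action.make_batch, it should unpickle to a list of [features, label(s)] samples; whether the features are raw words or indices depends on the preprocessing step, which is not shown here. A sketch of producing such a file:

    import _pickle
    import os

    pickle_path = "./pickles/"  # hypothetical directory, matching Trainer's pickle_path
    data_train = [[[4, 5, 6], [1]],
                  [[7, 8], [0]]]  # assumed: already-indexed [features, label(s)] samples
    with open(os.path.join(pickle_path, "data_train.pkl"), "wb") as f:
        _pickle.dump(data_train, f)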
@@ -219,7 +219,7 @@ class ToyTrainer(BaseTrainer):

     def __init__(self, training_args):
         super(ToyTrainer, self).__init__(training_args)

-    def prepare_input(self, data_path):
+    def load_train_data(self, data_path):
         data_train = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
         data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
         return data_train, data_dev, 0, 1
@@ -267,7 +267,7 @@ class SeqLabelTrainer(BaseTrainer):

     def data_forward(self, network, inputs):
         if not isinstance(inputs, tuple):
-            raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.")
+            raise RuntimeError("output_length must be true for sequence modeling. Received {}".format(type(inputs[0])))
         # unpack the returned value from make_batch
         x, seq_len = inputs[0], inputs[1]
else: | else: | ||||
return False | return False | ||||
def make_batch(self, iterator, data): | |||||
return Action.make_batch(iterator, data, output_length=True, use_cuda=self.use_cuda) | |||||
def make_batch(self, iterator): | |||||
return Action.make_batch(iterator, output_length=True, use_cuda=self.use_cuda) | |||||
def _create_validator(self, valid_args): | def _create_validator(self, valid_args): | ||||
return SeqLabelTester(valid_args) | return SeqLabelTester(valid_args) | ||||
@@ -349,8 +349,8 @@ class ClassificationTrainer(BaseTrainer):
         """Apply gradient."""
         self.optimizer.step()

-    def make_batch(self, iterator, data):
-        return Action.make_batch(iterator, data, output_length=False, use_cuda=self.use_cuda)
+    def make_batch(self, iterator):
+        return Action.make_batch(iterator, output_length=False, use_cuda=self.use_cuda)

     def get_acc(self, y_logit, y_true):
         """Compute accuracy."""
@@ -1,4 +1,4 @@
-from fastNLP.core.inference import SeqLabelInfer, ClassificationInfer
+from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer
 from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
 from fastNLP.loader.model_loader import ModelLoader
@@ -91,6 +91,9 @@ class ConfigSection(object):
                                 (key, str(type(getattr(self, key))), str(type(value))))
             setattr(self, key, value)

+    def __contains__(self, item):
+        return item in self.__dict__.keys()
+

 if __name__ == "__main__":
     config = ConfigLoader('configLoader', 'there is no data')
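
The new __contains__ makes membership tests work directly on a ConfigSection, since loaded arguments are stored as attributes via the setattr call above (the no-argument constructor is assumed from usage elsewhere):

    args = ConfigSection()
    args.epochs = 20                # stored in args.__dict__
    assert "epochs" in args
    assert "batch_size" not in args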
@@ -1,4 +1,4 @@
-from loader.base_loader import BaseLoader
+from fastNLP.loader.base_loader import BaseLoader


 class EmbedLoader(BaseLoader):
@@ -1,3 +1,9 @@
+from collections import defaultdict
+
+import numpy as np
+import torch
+
+
 def mask_softmax(matrix, mask):
     if mask is None:
         result = torch.nn.functional.softmax(matrix, dim=-1)

@@ -15,10 +21,6 @@ def seq_mask(seq_len, max_len):
 """
 Codes from FudanParser. Not tested. Do not use !!!
 """
-from collections import defaultdict
-
-import numpy as np
-import torch


 def expand_gt(gt):
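
With the imports moved to module level, fastNLP.modules.utils is usable on its own; for instance, the mask_softmax shown above degrades to a plain softmax when no mask is given:

    import torch
    from fastNLP.modules.utils import mask_softmax

    matrix = torch.randn(2, 4)
    probs = mask_softmax(matrix, mask=None)  # same as torch.nn.functional.softmax(matrix, dim=-1)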
@@ -10,7 +10,7 @@ from fastNLP.saver.model_saver import ModelSaver
 from fastNLP.loader.model_loader import ModelLoader
 from fastNLP.core.tester import SeqLabelTester
 from fastNLP.models.sequence_modeling import SeqLabeling
-from fastNLP.core.inference import Inference
+from fastNLP.core.predictor import Predictor

 data_name = "pku_training.utf8"
 cws_data_path = "/home/zyfeng/data/pku_training.utf8"

@@ -41,7 +41,7 @@ def infer():
     infer_data = raw_data_loader.load_lines()

     # Inference interface
-    infer = Inference(pickle_path)
+    infer = Predictor(pickle_path)
     results = infer.predict(model, infer_data)

     print(results)
@@ -1 +1,3 @@
+import fastNLP
+
+__all__ = ["fastNLP"]
@@ -3,7 +3,7 @@ import os

 import torch

-from fastNLP.core.inference import SeqLabelInfer
+from fastNLP.core.predictor import SeqLabelInfer
 from fastNLP.core.trainer import SeqLabelTrainer
 from fastNLP.loader.model_loader import ModelLoader
 from fastNLP.models.sequence_modeling import AdvSeqLabel
@@ -10,7 +10,7 @@ from fastNLP.saver.model_saver import ModelSaver
 from fastNLP.loader.model_loader import ModelLoader
 from fastNLP.core.tester import SeqLabelTester
 from fastNLP.models.sequence_modeling import SeqLabeling
-from fastNLP.core.inference import SeqLabelInfer
+from fastNLP.core.predictor import SeqLabelInfer

 data_name = "people.txt"
 data_path = "data_for_tests/people.txt"

@@ -112,5 +112,5 @@ def train_and_test():

 if __name__ == "__main__":
-    train_and_test()
-    # infer()
+    # train_and_test()
+    infer()
@@ -10,7 +10,7 @@ from fastNLP.saver.model_saver import ModelSaver
 from fastNLP.loader.model_loader import ModelLoader
 from fastNLP.core.tester import SeqLabelTester
 from fastNLP.models.sequence_modeling import SeqLabeling
-from fastNLP.core.inference import Inference
+from fastNLP.core.predictor import Predictor

 data_name = "pku_training.utf8"
 # cws_data_path = "/home/zyfeng/Desktop/data/pku_training.utf8"

@@ -51,7 +51,7 @@ def infer():
     """

     # Inference interface
-    infer = Inference(pickle_path)
+    infer = Predictor(pickle_path)
     results = infer.predict(model, infer_data)

     print(results)
@@ -2,8 +2,10 @@
 # encoding: utf-8

 import os
+import sys

-from fastNLP.core.inference import ClassificationInfer
+sys.path.append("..")
+
+from fastNLP.core.predictor import ClassificationInfer
 from fastNLP.core.trainer import ClassificationTrainer
 from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
 from fastNLP.loader.dataset_loader import ClassDatasetLoader