- rename "POSTrainer", "POSTester" to "SeqLabelTrainer", "SeqLabelTester" - Trainer & Tester have NO relation with Action - Inference owns independent "make_batch" & "data_forward" - Conversion to Tensor & go into cuda are done in "make_batch" - "make_batch" support maximum/minimum lengthtags/v0.1.0
@@ -5,6 +5,7 @@ | |||||
from collections import Counter | from collections import Counter | ||||
import numpy as np | import numpy as np | ||||
import torch | |||||
class Action(object): | class Action(object): | ||||
@@ -21,7 +22,7 @@ class Action(object): | |||||
super(Action, self).__init__() | super(Action, self).__init__() | ||||
@staticmethod | @staticmethod | ||||
def make_batch(iterator, data, output_length=True): | |||||
def make_batch(iterator, data, use_cuda, output_length=True, max_len=None): | |||||
"""Batch and Pad data. | """Batch and Pad data. | ||||
:param iterator: an iterator, (object that implements __next__ method) which returns the next sample. | :param iterator: an iterator, (object that implements __next__ method) which returns the next sample. | ||||
:param data: list. Each entry is a sample, which is also a list of features and label(s). | :param data: list. Each entry is a sample, which is also a list of features and label(s). | ||||
@@ -31,7 +32,9 @@ class Action(object): | |||||
[[word_21, word_22, word_23], [label_21. label_22]], # sample 2 | [[word_21, word_22, word_23], [label_21. label_22]], # sample 2 | ||||
... | ... | ||||
] | ] | ||||
:param use_cuda: bool | |||||
:param output_length: whether to output the original length of the sequence before padding. | :param output_length: whether to output the original length of the sequence before padding. | ||||
:param max_len: int, maximum sequence length | |||||
:return (batch_x, seq_len): tuple of two elements, if output_length is true. | :return (batch_x, seq_len): tuple of two elements, if output_length is true. | ||||
batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len] | batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len] | ||||
seq_len: list. The length of the pre-padded sequence, if output_length is True. | seq_len: list. The length of the pre-padded sequence, if output_length is True. | ||||
@@ -43,13 +46,25 @@ class Action(object): | |||||
batch = [data[idx] for idx in indices] | batch = [data[idx] for idx in indices] | ||||
batch_x = [sample[0] for sample in batch] | batch_x = [sample[0] for sample in batch] | ||||
batch_y = [sample[1] for sample in batch] | batch_y = [sample[1] for sample in batch] | ||||
batch_x_pad = Action.pad(batch_x) | |||||
batch_y_pad = Action.pad(batch_y) | |||||
batch_x = Action.pad(batch_x) | |||||
# pad batch_y only if it is a 2-level list | |||||
if len(batch_y) > 0 and isinstance(batch_y[0], list): | |||||
batch_y = Action.pad(batch_y) | |||||
# convert list to tensor | |||||
batch_x = convert_to_torch_tensor(batch_x, use_cuda) | |||||
batch_y = convert_to_torch_tensor(batch_y, use_cuda) | |||||
# trim data to max_len | |||||
if max_len is not None and batch_x.size(1) > max_len: | |||||
batch_x = batch_x[:, :max_len] | |||||
if output_length: | if output_length: | ||||
seq_len = [len(x) for x in batch_x] | seq_len = [len(x) for x in batch_x] | ||||
yield (batch_x_pad, seq_len), batch_y_pad | |||||
yield (batch_x, seq_len), batch_y | |||||
else: | else: | ||||
yield batch_x_pad, batch_y_pad | |||||
yield batch_x, batch_y | |||||
@staticmethod | @staticmethod | ||||
def pad(batch, fill=0): | def pad(batch, fill=0): | ||||
@@ -78,6 +93,20 @@ class Action(object): | |||||
model.train() | model.train() | ||||
def convert_to_torch_tensor(data_list, use_cuda): | |||||
""" | |||||
convert lists into (cuda) Tensors | |||||
:param data_list: 2-level lists | |||||
:param use_cuda: bool | |||||
:param reqired_grad: bool | |||||
:return: PyTorch Tensor of shape [batch_size, max_seq_len] | |||||
""" | |||||
data_list = torch.Tensor(data_list).long() | |||||
if torch.cuda.is_available() and use_cuda: | |||||
data_list = data_list.cuda() | |||||
return data_list | |||||
def k_means_1d(x, k, max_iter=100): | def k_means_1d(x, k, max_iter=100): | ||||
""" | """ | ||||
Perform k-means on 1-D data. | Perform k-means on 1-D data. | ||||
@@ -2,16 +2,53 @@ import numpy as np | |||||
import torch | import torch | ||||
from fastNLP.core.action import Batchifier, SequentialSampler | from fastNLP.core.action import Batchifier, SequentialSampler | ||||
from fastNLP.core.action import convert_to_torch_tensor | |||||
from fastNLP.loader.preprocess import load_pickle, DEFAULT_UNKNOWN_LABEL | from fastNLP.loader.preprocess import load_pickle, DEFAULT_UNKNOWN_LABEL | ||||
from fastNLP.modules import utils | from fastNLP.modules import utils | ||||
def make_batch(iterator, data, use_cuda, output_length=False, max_len=None, min_len=None): | |||||
for indices in iterator: | |||||
batch_x = [data[idx] for idx in indices] | |||||
batch_x = pad(batch_x) | |||||
# convert list to tensor | |||||
batch_x = convert_to_torch_tensor(batch_x, use_cuda) | |||||
# trim data to max_len | |||||
if max_len is not None and batch_x.size(1) > max_len: | |||||
batch_x = batch_x[:, :max_len] | |||||
if min_len is not None and batch_x.size(1) < min_len: | |||||
pad_tensor = torch.zeros(batch_x.size(0), min_len - batch_x.size(1)).to(batch_x) | |||||
batch_x = torch.cat((batch_x, pad_tensor), 1) | |||||
if output_length: | |||||
seq_len = [len(x) for x in batch_x] | |||||
yield tuple([batch_x, seq_len]) | |||||
else: | |||||
yield batch_x | |||||
def pad(batch, fill=0): | |||||
""" | |||||
Pad a batch of samples to maximum length. | |||||
:param batch: list of list | |||||
:param fill: word index to pad, default 0. | |||||
:return: a padded batch | |||||
""" | |||||
max_length = max([len(x) for x in batch]) | |||||
for idx, sample in enumerate(batch): | |||||
if len(sample) < max_length: | |||||
batch[idx] = sample + ([fill] * (max_length - len(sample))) | |||||
return batch | |||||
class Inference(object): | class Inference(object): | ||||
""" | """ | ||||
This is an interface focusing on predicting output based on trained models. | This is an interface focusing on predicting output based on trained models. | ||||
It does not care about evaluations of the model, which is different from Tester. | It does not care about evaluations of the model, which is different from Tester. | ||||
This is a high-level model wrapper to be called by FastNLP. | This is a high-level model wrapper to be called by FastNLP. | ||||
This class does not share any operations with Trainer and Tester. | This class does not share any operations with Trainer and Tester. | ||||
Currently, Inference does not support GPU. | |||||
""" | """ | ||||
def __init__(self, pickle_path): | def __init__(self, pickle_path): | ||||
@@ -38,10 +75,7 @@ class Inference(object): | |||||
iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False)) | iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False)) | ||||
num_iter = len(data) // self.batch_size | |||||
for step in range(num_iter): | |||||
batch_x = self.make_batch(iterator, data) | |||||
for batch_x in self.make_batch(iterator, data, use_cuda=False): | |||||
prediction = self.data_forward(network, batch_x) | prediction = self.data_forward(network, batch_x) | ||||
@@ -54,35 +88,12 @@ class Inference(object): | |||||
network.eval() | network.eval() | ||||
else: | else: | ||||
network.train() | network.train() | ||||
self.batch_output.clear() | |||||
def data_forward(self, network, x): | def data_forward(self, network, x): | ||||
raise NotImplementedError | raise NotImplementedError | ||||
@staticmethod | |||||
def make_batch(iterator, data, output_length=True): | |||||
indices = next(iterator) | |||||
batch_x = [data[idx] for idx in indices] | |||||
batch_x_pad = Inference.pad(batch_x) | |||||
if output_length: | |||||
seq_len = [len(x) for x in batch_x] | |||||
return [batch_x_pad, seq_len] | |||||
else: | |||||
return batch_x_pad | |||||
@staticmethod | |||||
def pad(batch, fill=0): | |||||
""" | |||||
Pad a batch of samples to maximum length. | |||||
:param batch: list of list | |||||
:param fill: word index to pad, default 0. | |||||
:return: a padded batch | |||||
""" | |||||
max_length = max([len(x) for x in batch]) | |||||
for idx, sample in enumerate(batch): | |||||
if len(sample) < max_length: | |||||
batch[idx] = sample + ([fill] * (max_length - len(sample))) | |||||
return batch | |||||
def make_batch(self, iterator, data, use_cuda): | |||||
raise NotImplementedError | |||||
def prepare_input(self, data): | def prepare_input(self, data): | ||||
""" | """ | ||||
@@ -101,17 +112,8 @@ class Inference(object): | |||||
data_index.append([self.word2index.get(w, default_unknown_index) for w in example]) | data_index.append([self.word2index.get(w, default_unknown_index) for w in example]) | ||||
return data_index | return data_index | ||||
def prepare_output(self, batch_outputs): | |||||
""" | |||||
Transform list of batch outputs into strings. | |||||
:param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, tag_seq_length]. | |||||
:return: | |||||
""" | |||||
results = [] | |||||
for batch in batch_outputs: | |||||
for example in np.array(batch): | |||||
results.append([self.index2label[int(x)] for x in example]) | |||||
return results | |||||
def prepare_output(self, data): | |||||
raise NotImplementedError | |||||
class SeqLabelInfer(Inference): | class SeqLabelInfer(Inference): | ||||
@@ -133,10 +135,53 @@ class SeqLabelInfer(Inference): | |||||
raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.") | raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.") | ||||
# unpack the returned value from make_batch | # unpack the returned value from make_batch | ||||
x, seq_len = inputs[0], inputs[1] | x, seq_len = inputs[0], inputs[1] | ||||
x = torch.Tensor(x).long() | |||||
batch_size, max_len = x.size(0), x.size(1) | batch_size, max_len = x.size(0), x.size(1) | ||||
mask = utils.seq_mask(seq_len, max_len) | mask = utils.seq_mask(seq_len, max_len) | ||||
mask = mask.byte().view(batch_size, max_len) | mask = mask.byte().view(batch_size, max_len) | ||||
y = network(x) | y = network(x) | ||||
prediction = network.prediction(y, mask) | prediction = network.prediction(y, mask) | ||||
return torch.Tensor(prediction) | |||||
return torch.Tensor(prediction, required_grad=False) | |||||
def make_batch(self, iterator, data, use_cuda): | |||||
return make_batch(iterator, data, use_cuda, output_length=True) | |||||
def prepare_output(self, batch_outputs): | |||||
""" | |||||
Transform list of batch outputs into strings. | |||||
:param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, tag_seq_length]. | |||||
:return: | |||||
""" | |||||
results = [] | |||||
for batch in batch_outputs: | |||||
for example in np.array(batch): | |||||
results.append([self.index2label[int(x)] for x in example]) | |||||
return results | |||||
class ClassificationInfer(Inference): | |||||
""" | |||||
Inference on Classification models. | |||||
""" | |||||
def __init__(self, pickle_path): | |||||
super(ClassificationInfer, self).__init__(pickle_path) | |||||
def data_forward(self, network, x): | |||||
"""Forward through network.""" | |||||
logits = network(x) | |||||
return logits | |||||
def make_batch(self, iterator, data, use_cuda): | |||||
return make_batch(iterator, data, use_cuda, output_length=False, min_len=5) | |||||
def prepare_output(self, batch_outputs): | |||||
""" | |||||
Transform list of batch outputs into strings. | |||||
:param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, num_classes]. | |||||
:return: | |||||
""" | |||||
results = [] | |||||
for batch_out in batch_outputs: | |||||
idx = np.argmax(batch_out.detach().numpy()) | |||||
results.append(self.index2label[idx]) | |||||
return results |
@@ -1,5 +1,4 @@ | |||||
import _pickle | import _pickle | ||||
import os | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
@@ -9,15 +8,14 @@ from fastNLP.core.action import RandomSampler, Batchifier | |||||
from fastNLP.modules import utils | from fastNLP.modules import utils | ||||
class BaseTester(Action): | |||||
class BaseTester(object): | |||||
"""docstring for Tester""" | """docstring for Tester""" | ||||
def __init__(self, test_args, action=None): | |||||
def __init__(self, test_args): | |||||
""" | """ | ||||
:param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]" | :param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]" | ||||
""" | """ | ||||
super(BaseTester, self).__init__() | super(BaseTester, self).__init__() | ||||
self.action = action if action is not None else Action() | |||||
self.validate_in_training = test_args["validate_in_training"] | self.validate_in_training = test_args["validate_in_training"] | ||||
self.save_dev_data = None | self.save_dev_data = None | ||||
self.save_output = test_args["save_output"] | self.save_output = test_args["save_output"] | ||||
@@ -39,16 +37,23 @@ class BaseTester(Action): | |||||
else: | else: | ||||
self.model = network | self.model = network | ||||
# no backward setting for model | |||||
for param in network.parameters(): | |||||
param.requires_grad = False | |||||
# turn on the testing mode; clean up the history | # turn on the testing mode; clean up the history | ||||
self.action.mode(network, test=True) | |||||
self.mode(network, test=True) | |||||
self.eval_history.clear() | self.eval_history.clear() | ||||
self.batch_output.clear() | self.batch_output.clear() | ||||
dev_data = self.prepare_input(self.pickle_path) | dev_data = self.prepare_input(self.pickle_path) | ||||
iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=True)) | iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=True)) | ||||
n_batches = len(dev_data) // self.batch_size | |||||
n_print = 1 | |||||
step = 0 | |||||
for batch_x, batch_y in self.action.make_batch(iterator, dev_data): | |||||
for batch_x, batch_y in self.make_batch(iterator, dev_data): | |||||
prediction = self.data_forward(network, batch_x) | prediction = self.data_forward(network, batch_x) | ||||
@@ -58,6 +63,7 @@ class BaseTester(Action): | |||||
self.batch_output.append(prediction) | self.batch_output.append(prediction) | ||||
if self.save_loss: | if self.save_loss: | ||||
self.eval_history.append(eval_results) | self.eval_history.append(eval_results) | ||||
step += 1 | |||||
def prepare_input(self, data_path): | def prepare_input(self, data_path): | ||||
""" | """ | ||||
@@ -70,6 +76,9 @@ class BaseTester(Action): | |||||
self.save_dev_data = data_dev | self.save_dev_data = data_dev | ||||
return self.save_dev_data | return self.save_dev_data | ||||
def mode(self, model, test): | |||||
Action.mode(model, test) | |||||
def data_forward(self, network, x): | def data_forward(self, network, x): | ||||
raise NotImplementedError | raise NotImplementedError | ||||
@@ -87,17 +96,20 @@ class BaseTester(Action): | |||||
""" | """ | ||||
raise NotImplementedError | raise NotImplementedError | ||||
def make_batch(self, iterator, data): | |||||
raise NotImplementedError | |||||
class POSTester(BaseTester): | |||||
class SeqLabelTester(BaseTester): | |||||
""" | """ | ||||
Tester for sequence labeling. | Tester for sequence labeling. | ||||
""" | """ | ||||
def __init__(self, test_args, action=None): | |||||
def __init__(self, test_args): | |||||
""" | """ | ||||
:param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]" | :param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]" | ||||
""" | """ | ||||
super(POSTester, self).__init__(test_args, action) | |||||
super(SeqLabelTester, self).__init__(test_args) | |||||
self.max_len = None | self.max_len = None | ||||
self.mask = None | self.mask = None | ||||
self.batch_result = None | self.batch_result = None | ||||
@@ -107,13 +119,10 @@ class POSTester(BaseTester): | |||||
raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.") | raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.") | ||||
# unpack the returned value from make_batch | # unpack the returned value from make_batch | ||||
x, seq_len = inputs[0], inputs[1] | x, seq_len = inputs[0], inputs[1] | ||||
x = torch.Tensor(x).long() | |||||
batch_size, max_len = x.size(0), x.size(1) | batch_size, max_len = x.size(0), x.size(1) | ||||
mask = utils.seq_mask(seq_len, max_len) | mask = utils.seq_mask(seq_len, max_len) | ||||
mask = mask.byte().view(batch_size, max_len) | mask = mask.byte().view(batch_size, max_len) | ||||
if torch.cuda.is_available() and self.use_cuda: | if torch.cuda.is_available() and self.use_cuda: | ||||
x = x.cuda() | |||||
mask = mask.cuda() | mask = mask.cuda() | ||||
self.mask = mask | self.mask = mask | ||||
@@ -121,9 +130,6 @@ class POSTester(BaseTester): | |||||
return y | return y | ||||
def evaluate(self, predict, truth): | def evaluate(self, predict, truth): | ||||
truth = torch.Tensor(truth) | |||||
if torch.cuda.is_available() and self.use_cuda: | |||||
truth = truth.cuda() | |||||
batch_size, max_len = predict.size(0), predict.size(1) | batch_size, max_len = predict.size(0), predict.size(1) | ||||
loss = self.model.loss(predict, truth, self.mask) / batch_size | loss = self.model.loss(predict, truth, self.mask) / batch_size | ||||
@@ -147,8 +153,11 @@ class POSTester(BaseTester): | |||||
loss, accuracy = self.metrics() | loss, accuracy = self.metrics() | ||||
return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy) | return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy) | ||||
def make_batch(self, iterator, data): | |||||
return Action.make_batch(iterator, data, use_cuda=self.use_cuda, output_length=True) | |||||
class ClassTester(BaseTester): | |||||
class ClassificationTester(BaseTester): | |||||
"""Tester for classification.""" | """Tester for classification.""" | ||||
def __init__(self, test_args): | def __init__(self, test_args): | ||||
@@ -156,7 +165,7 @@ class ClassTester(BaseTester): | |||||
:param test_args: a dict-like object that has __getitem__ method, \ | :param test_args: a dict-like object that has __getitem__ method, \ | ||||
can be accessed by "test_args["key_str"]" | can be accessed by "test_args["key_str"]" | ||||
""" | """ | ||||
# super(ClassTester, self).__init__() | |||||
super(ClassificationTester, self).__init__(test_args) | |||||
self.pickle_path = test_args["pickle_path"] | self.pickle_path = test_args["pickle_path"] | ||||
self.save_dev_data = None | self.save_dev_data = None | ||||
@@ -164,111 +173,8 @@ class ClassTester(BaseTester): | |||||
self.mean_loss = None | self.mean_loss = None | ||||
self.iterator = None | self.iterator = None | ||||
if "test_name" in test_args: | |||||
self.test_name = test_args["test_name"] | |||||
else: | |||||
self.test_name = "data_test.pkl" | |||||
if "validate_in_training" in test_args: | |||||
self.validate_in_training = test_args["validate_in_training"] | |||||
else: | |||||
self.validate_in_training = False | |||||
if "save_output" in test_args: | |||||
self.save_output = test_args["save_output"] | |||||
else: | |||||
self.save_output = False | |||||
if "save_loss" in test_args: | |||||
self.save_loss = test_args["save_loss"] | |||||
else: | |||||
self.save_loss = True | |||||
if "batch_size" in test_args: | |||||
self.batch_size = test_args["batch_size"] | |||||
else: | |||||
self.batch_size = 50 | |||||
if "use_cuda" in test_args: | |||||
self.use_cuda = test_args["use_cuda"] | |||||
else: | |||||
self.use_cuda = True | |||||
if "max_len" in test_args: | |||||
self.max_len = test_args["max_len"] | |||||
else: | |||||
self.max_len = None | |||||
self.model = None | |||||
self.eval_history = [] | |||||
self.batch_output = [] | |||||
def test(self, network): | |||||
# prepare model | |||||
if torch.cuda.is_available() and self.use_cuda: | |||||
self.model = network.cuda() | |||||
else: | |||||
self.model = network | |||||
# no backward setting for model | |||||
for param in self.model.parameters(): | |||||
param.requires_grad = False | |||||
# turn on the testing mode; clean up the history | |||||
self.mode(network, test=True) | |||||
# prepare test data | |||||
data_test = self.prepare_input(self.pickle_path, self.test_name) | |||||
# data generator | |||||
self.iterator = iter(Batchifier( | |||||
RandomSampler(data_test), self.batch_size, drop_last=False)) | |||||
# test | |||||
n_batches = len(data_test) // self.batch_size | |||||
n_print = n_batches // 10 | |||||
step = 0 | |||||
for batch_x, batch_y in self.make_batch(data_test, max_len=self.max_len): | |||||
prediction = self.data_forward(network, batch_x) | |||||
eval_results = self.evaluate(prediction, batch_y) | |||||
if self.save_output: | |||||
self.batch_output.append(prediction) | |||||
if self.save_loss: | |||||
self.eval_history.append(eval_results) | |||||
if step % n_print == 0: | |||||
print("step: {:>5}".format(step)) | |||||
step += 1 | |||||
def prepare_input(self, data_dir, file_name): | |||||
"""Prepare data.""" | |||||
file_path = os.path.join(data_dir, file_name) | |||||
with open(file_path, 'rb') as f: | |||||
data = _pickle.load(f) | |||||
return data | |||||
def make_batch(self, data, max_len=None): | |||||
"""Batch and pad data.""" | |||||
for indices in self.iterator: | |||||
# generate batch and pad | |||||
batch = [data[idx] for idx in indices] | |||||
batch_x = [sample[0] for sample in batch] | |||||
batch_y = [sample[1] for sample in batch] | |||||
batch_x = self.pad(batch_x) | |||||
# convert to tensor | |||||
batch_x = torch.tensor(batch_x, dtype=torch.long) | |||||
batch_y = torch.tensor(batch_y, dtype=torch.long) | |||||
if torch.cuda.is_available() and self.use_cuda: | |||||
batch_x = batch_x.cuda() | |||||
batch_y = batch_y.cuda() | |||||
# trim data to max_len | |||||
if max_len is not None and batch_x.size(1) > max_len: | |||||
batch_x = batch_x[:, :max_len] | |||||
yield batch_x, batch_y | |||||
def make_batch(self, iterator, data, max_len=None): | |||||
return Action.make_batch(iterator, data, use_cuda=self.use_cuda, max_len=max_len) | |||||
def data_forward(self, network, x): | def data_forward(self, network, x): | ||||
"""Forward through network.""" | """Forward through network.""" | ||||
@@ -289,10 +195,3 @@ class ClassTester(BaseTester): | |||||
acc = float(torch.sum(y_pred == y_true)) / len(y_true) | acc = float(torch.sum(y_pred == y_true)) / len(y_true) | ||||
return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc | return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc | ||||
def mode(self, model, test=True): | |||||
"""TODO: combine this function with Trainer ?? """ | |||||
if test: | |||||
model.eval() | |||||
else: | |||||
model.train() | |||||
self.eval_history.clear() |
@@ -9,12 +9,12 @@ import torch.nn as nn | |||||
from fastNLP.core.action import Action | from fastNLP.core.action import Action | ||||
from fastNLP.core.action import RandomSampler, Batchifier | from fastNLP.core.action import RandomSampler, Batchifier | ||||
from fastNLP.core.tester import POSTester | |||||
from fastNLP.core.tester import SeqLabelTester, ClassificationTester | |||||
from fastNLP.modules import utils | from fastNLP.modules import utils | ||||
from fastNLP.saver.model_saver import ModelSaver | from fastNLP.saver.model_saver import ModelSaver | ||||
class BaseTrainer(Action): | |||||
class BaseTrainer(object): | |||||
"""Base trainer for all trainers. | """Base trainer for all trainers. | ||||
Trainer receives a model and data, and then performs training. | Trainer receives a model and data, and then performs training. | ||||
@@ -24,10 +24,9 @@ class BaseTrainer(Action): | |||||
- get_loss | - get_loss | ||||
""" | """ | ||||
def __init__(self, train_args, action=None): | |||||
def __init__(self, train_args): | |||||
""" | """ | ||||
:param train_args: dict of (key, value), or dict-like object. key is str. | :param train_args: dict of (key, value), or dict-like object. key is str. | ||||
:param action: (optional) an Action object that wrap most operations shared by Trainer, Tester, and Inference. | |||||
The base trainer requires the following keys: | The base trainer requires the following keys: | ||||
- epochs: int, the number of epochs in training | - epochs: int, the number of epochs in training | ||||
@@ -36,7 +35,6 @@ class BaseTrainer(Action): | |||||
- pickle_path: str, the path to pickle files for pre-processing | - pickle_path: str, the path to pickle files for pre-processing | ||||
""" | """ | ||||
super(BaseTrainer, self).__init__() | super(BaseTrainer, self).__init__() | ||||
self.action = action if action is not None else Action() | |||||
self.n_epochs = train_args["epochs"] | self.n_epochs = train_args["epochs"] | ||||
self.batch_size = train_args["batch_size"] | self.batch_size = train_args["batch_size"] | ||||
self.pickle_path = train_args["pickle_path"] | self.pickle_path = train_args["pickle_path"] | ||||
@@ -79,7 +77,7 @@ class BaseTrainer(Action): | |||||
default_valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, | default_valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, | ||||
"save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path, | "save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path, | ||||
"use_cuda": self.use_cuda} | "use_cuda": self.use_cuda} | ||||
validator = POSTester(default_valid_args, self.action) | |||||
validator = self._create_validator(default_valid_args) | |||||
self.define_optimizer() | self.define_optimizer() | ||||
@@ -92,12 +90,12 @@ class BaseTrainer(Action): | |||||
for epoch in range(1, self.n_epochs + 1): | for epoch in range(1, self.n_epochs + 1): | ||||
# turn on network training mode; prepare batch iterator | # turn on network training mode; prepare batch iterator | ||||
self.action.mode(network, test=False) | |||||
self.mode(network, test=False) | |||||
iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=False)) | iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=False)) | ||||
# training iterations in one epoch | # training iterations in one epoch | ||||
step = 0 | step = 0 | ||||
for batch_x, batch_y in self.action.make_batch(iterator, data_train, output_length=True): | |||||
for batch_x, batch_y in self.make_batch(iterator, data_train): | |||||
prediction = self.data_forward(network, batch_x) | prediction = self.data_forward(network, batch_x) | ||||
@@ -142,6 +140,12 @@ class BaseTrainer(Action): | |||||
files.append(data) | files.append(data) | ||||
return tuple(files) | return tuple(files) | ||||
def make_batch(self, iterator, data): | |||||
raise NotImplementedError | |||||
def mode(self, network, test): | |||||
Action.mode(network, test) | |||||
def define_optimizer(self): | def define_optimizer(self): | ||||
""" | """ | ||||
Define framework-specific optimizer specified by the models. | Define framework-specific optimizer specified by the models. | ||||
@@ -203,6 +207,9 @@ class BaseTrainer(Action): | |||||
""" | """ | ||||
ModelSaver(self.model_saved_path + "model_best_dev.pkl").save_pytorch(network) | ModelSaver(self.model_saved_path + "model_best_dev.pkl").save_pytorch(network) | ||||
def _create_validator(self, valid_args): | |||||
raise NotImplementedError | |||||
class ToyTrainer(BaseTrainer): | class ToyTrainer(BaseTrainer): | ||||
""" | """ | ||||
@@ -217,12 +224,6 @@ class ToyTrainer(BaseTrainer): | |||||
data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb")) | data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb")) | ||||
return data_train, data_dev, 0, 1 | return data_train, data_dev, 0, 1 | ||||
def mode(self, test=False): | |||||
if test: | |||||
self.model.eval() | |||||
else: | |||||
self.model.train() | |||||
def data_forward(self, network, x): | def data_forward(self, network, x): | ||||
return network(x) | return network(x) | ||||
@@ -246,8 +247,8 @@ class SeqLabelTrainer(BaseTrainer): | |||||
""" | """ | ||||
def __init__(self, train_args, action=None): | |||||
super(SeqLabelTrainer, self).__init__(train_args, action) | |||||
def __init__(self, train_args): | |||||
super(SeqLabelTrainer, self).__init__(train_args) | |||||
self.vocab_size = train_args["vocab_size"] | self.vocab_size = train_args["vocab_size"] | ||||
self.num_classes = train_args["num_classes"] | self.num_classes = train_args["num_classes"] | ||||
self.max_len = None | self.max_len = None | ||||
@@ -269,14 +270,12 @@ class SeqLabelTrainer(BaseTrainer): | |||||
raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.") | raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.") | ||||
# unpack the returned value from make_batch | # unpack the returned value from make_batch | ||||
x, seq_len = inputs[0], inputs[1] | x, seq_len = inputs[0], inputs[1] | ||||
x = torch.Tensor(x).long() | |||||
batch_size, max_len = x.size(0), x.size(1) | batch_size, max_len = x.size(0), x.size(1) | ||||
mask = utils.seq_mask(seq_len, max_len) | mask = utils.seq_mask(seq_len, max_len) | ||||
mask = mask.byte().view(batch_size, max_len) | mask = mask.byte().view(batch_size, max_len) | ||||
if torch.cuda.is_available() and self.use_cuda: | if torch.cuda.is_available() and self.use_cuda: | ||||
x = x.cuda() | |||||
mask = mask.cuda() | mask = mask.cuda() | ||||
self.mask = mask | self.mask = mask | ||||
@@ -290,9 +289,6 @@ class SeqLabelTrainer(BaseTrainer): | |||||
:param truth: ground truth label vector, [batch_size, max_len] | :param truth: ground truth label vector, [batch_size, max_len] | ||||
:return: a scalar | :return: a scalar | ||||
""" | """ | ||||
truth = torch.Tensor(truth) | |||||
if torch.cuda.is_available() and self.use_cuda: | |||||
truth = truth.cuda() | |||||
batch_size, max_len = predict.size(0), predict.size(1) | batch_size, max_len = predict.size(0), predict.size(1) | ||||
assert truth.shape == (batch_size, max_len) | assert truth.shape == (batch_size, max_len) | ||||
@@ -307,32 +303,18 @@ class SeqLabelTrainer(BaseTrainer): | |||||
else: | else: | ||||
return False | return False | ||||
def make_batch(self, iterator, data): | |||||
return Action.make_batch(iterator, data, output_length=True, use_cuda=self.use_cuda) | |||||
class LanguageModelTrainer(BaseTrainer): | |||||
""" | |||||
Trainer for Language Model | |||||
""" | |||||
def __init__(self, train_args): | |||||
super(LanguageModelTrainer, self).__init__(train_args) | |||||
def prepare_input(self, data_path): | |||||
pass | |||||
def _create_validator(self, valid_args): | |||||
return SeqLabelTester(valid_args) | |||||
class ClassTrainer(BaseTrainer): | |||||
class ClassificationTrainer(BaseTrainer): | |||||
"""Trainer for classification.""" | """Trainer for classification.""" | ||||
def __init__(self, train_args, action=None): | |||||
super(ClassTrainer, self).__init__(train_args, action) | |||||
self.n_epochs = train_args["epochs"] | |||||
self.batch_size = train_args["batch_size"] | |||||
self.pickle_path = train_args["pickle_path"] | |||||
if "validate" in train_args: | |||||
self.validate = train_args["validate"] | |||||
else: | |||||
self.validate = False | |||||
def __init__(self, train_args): | |||||
super(ClassificationTrainer, self).__init__(train_args) | |||||
if "learn_rate" in train_args: | if "learn_rate" in train_args: | ||||
self.learn_rate = train_args["learn_rate"] | self.learn_rate = train_args["learn_rate"] | ||||
else: | else: | ||||
@@ -341,15 +323,11 @@ class ClassTrainer(BaseTrainer): | |||||
self.momentum = train_args["momentum"] | self.momentum = train_args["momentum"] | ||||
else: | else: | ||||
self.momentum = 0.9 | self.momentum = 0.9 | ||||
if "use_cuda" in train_args: | |||||
self.use_cuda = train_args["use_cuda"] | |||||
else: | |||||
self.use_cuda = True | |||||
self.model = None | |||||
self.iterator = None | self.iterator = None | ||||
self.loss_func = None | self.loss_func = None | ||||
self.optimizer = None | self.optimizer = None | ||||
self.best_accuracy = 0 | |||||
def define_loss(self): | def define_loss(self): | ||||
self.loss_func = nn.CrossEntropyLoss() | self.loss_func = nn.CrossEntropyLoss() | ||||
@@ -365,9 +343,6 @@ class ClassTrainer(BaseTrainer): | |||||
def data_forward(self, network, x): | def data_forward(self, network, x): | ||||
"""Forward through network.""" | """Forward through network.""" | ||||
x = torch.Tensor(x).long() | |||||
if torch.cuda.is_available() and self.use_cuda: | |||||
x = x.cuda() | |||||
logits = network(x) | logits = network(x) | ||||
return logits | return logits | ||||
@@ -380,31 +355,21 @@ class ClassTrainer(BaseTrainer): | |||||
"""Apply gradient.""" | """Apply gradient.""" | ||||
self.optimizer.step() | self.optimizer.step() | ||||
""" | |||||
def make_batch(self, data): | |||||
for indices in self.iterator: | |||||
batch = [data[idx] for idx in indices] | |||||
batch_x = [sample[0] for sample in batch] | |||||
batch_y = [sample[1] for sample in batch] | |||||
batch_x = self.pad(batch_x) | |||||
batch_x = torch.Tensor(batch_x).long() | |||||
batch_y = torch.Tensor(batch_y).long() | |||||
if torch.cuda.is_available() and self.use_cuda: | |||||
batch_x = batch_x.cuda() | |||||
batch_y = batch_y.cuda() | |||||
yield batch_x, batch_y | |||||
""" | |||||
def make_batch(self, iterator, data): | |||||
return Action.make_batch(iterator, data, output_length=False, use_cuda=self.use_cuda) | |||||
def get_acc(self, y_logit, y_true): | def get_acc(self, y_logit, y_true): | ||||
"""Compute accuracy.""" | """Compute accuracy.""" | ||||
y_pred = torch.argmax(y_logit, dim=-1) | y_pred = torch.argmax(y_logit, dim=-1) | ||||
return int(torch.sum(y_true == y_pred)) / len(y_true) | return int(torch.sum(y_true == y_pred)) / len(y_true) | ||||
def best_eval_result(self, validator): | |||||
_, _, accuracy = validator.metrics() | |||||
if accuracy > self.best_accuracy: | |||||
self.best_accuracy = accuracy | |||||
return True | |||||
else: | |||||
return False | |||||
if __name__ == "__name__": | |||||
train_args = {"epochs": 1, "validate": False, "batch_size": 3, "pickle_path": "./"} | |||||
trainer = BaseTrainer(train_args) | |||||
data_train = [[[1, 2, 3, 4], [0]] * 10] + [[[1, 3, 5, 2], [1]] * 10] | |||||
trainer.make_batch(data=data_train) | |||||
def _create_validator(self, valid_args): | |||||
return ClassificationTester(valid_args) |
@@ -1,13 +1,14 @@ | |||||
# python: 3.6 | # python: 3.6 | ||||
# encoding: utf-8 | # encoding: utf-8 | ||||
import torch | |||||
import torch.nn as nn | import torch.nn as nn | ||||
# import torch.nn.functional as F | # import torch.nn.functional as F | ||||
from fastNLP.models.base_model import BaseModel | |||||
from fastNLP.modules.encoder.conv_maxpool import ConvMaxpool | from fastNLP.modules.encoder.conv_maxpool import ConvMaxpool | ||||
class CNNText(BaseModel): | |||||
class CNNText(torch.nn.Module): | |||||
""" | """ | ||||
Text classification model by character CNN, the implementation of paper | Text classification model by character CNN, the implementation of paper | ||||
'Yoon Kim. 2014. Convolution Neural Networks for Sentence | 'Yoon Kim. 2014. Convolution Neural Networks for Sentence | ||||
@@ -8,7 +8,7 @@ from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader | |||||
from fastNLP.loader.preprocess import POSPreprocess, load_pickle | from fastNLP.loader.preprocess import POSPreprocess, load_pickle | ||||
from fastNLP.saver.model_saver import ModelSaver | from fastNLP.saver.model_saver import ModelSaver | ||||
from fastNLP.loader.model_loader import ModelLoader | from fastNLP.loader.model_loader import ModelLoader | ||||
from fastNLP.core.tester import POSTester | |||||
from fastNLP.core.tester import SeqLabelTester | |||||
from fastNLP.models.sequence_modeling import SeqLabeling | from fastNLP.models.sequence_modeling import SeqLabeling | ||||
from fastNLP.core.inference import Inference | from fastNLP.core.inference import Inference | ||||
@@ -96,7 +96,7 @@ def test(): | |||||
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | ||||
# Tester | # Tester | ||||
tester = POSTester(test_args) | |||||
tester = SeqLabelTester(test_args) | |||||
# Start testing | # Start testing | ||||
tester.test(model) | tester.test(model) | ||||
@@ -8,7 +8,7 @@ from fastNLP.loader.dataset_loader import POSDatasetLoader, BaseLoader | |||||
from fastNLP.loader.preprocess import POSPreprocess, load_pickle | from fastNLP.loader.preprocess import POSPreprocess, load_pickle | ||||
from fastNLP.saver.model_saver import ModelSaver | from fastNLP.saver.model_saver import ModelSaver | ||||
from fastNLP.loader.model_loader import ModelLoader | from fastNLP.loader.model_loader import ModelLoader | ||||
from fastNLP.core.tester import POSTester | |||||
from fastNLP.core.tester import SeqLabelTester | |||||
from fastNLP.models.sequence_modeling import SeqLabeling | from fastNLP.models.sequence_modeling import SeqLabeling | ||||
from fastNLP.core.inference import SeqLabelInfer | from fastNLP.core.inference import SeqLabelInfer | ||||
@@ -101,7 +101,7 @@ def train_and_test(): | |||||
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | ||||
# Tester | # Tester | ||||
tester = POSTester(test_args) | |||||
tester = SeqLabelTester(test_args) | |||||
# Start testing | # Start testing | ||||
tester.test(model) | tester.test(model) | ||||
@@ -112,5 +112,5 @@ def train_and_test(): | |||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
train_and_test() | |||||
# train_and_test() | |||||
infer() | infer() |
@@ -8,7 +8,7 @@ from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader | |||||
from fastNLP.loader.preprocess import POSPreprocess, load_pickle | from fastNLP.loader.preprocess import POSPreprocess, load_pickle | ||||
from fastNLP.saver.model_saver import ModelSaver | from fastNLP.saver.model_saver import ModelSaver | ||||
from fastNLP.loader.model_loader import ModelLoader | from fastNLP.loader.model_loader import ModelLoader | ||||
from fastNLP.core.tester import POSTester | |||||
from fastNLP.core.tester import SeqLabelTester | |||||
from fastNLP.models.sequence_modeling import SeqLabeling | from fastNLP.models.sequence_modeling import SeqLabeling | ||||
from fastNLP.core.inference import Inference | from fastNLP.core.inference import Inference | ||||
@@ -101,7 +101,7 @@ def train_test(): | |||||
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | ||||
# Tester | # Tester | ||||
tester = POSTester(test_args) | |||||
tester = SeqLabelTester(test_args) | |||||
# Start testing | # Start testing | ||||
tester.test(model) | tester.test(model) | ||||
@@ -1,4 +1,4 @@ | |||||
from fastNLP.core.tester import POSTester | |||||
from fastNLP.core.tester import SeqLabelTester | |||||
from fastNLP.loader.config_loader import ConfigSection, ConfigLoader | from fastNLP.loader.config_loader import ConfigSection, ConfigLoader | ||||
from fastNLP.loader.dataset_loader import TokenizeDatasetLoader | from fastNLP.loader.dataset_loader import TokenizeDatasetLoader | ||||
from fastNLP.loader.preprocess import POSPreprocess | from fastNLP.loader.preprocess import POSPreprocess | ||||
@@ -26,7 +26,7 @@ def foo(): | |||||
valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, | valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, | ||||
"save_loss": True, "batch_size": 8, "pickle_path": "./data_for_tests/", | "save_loss": True, "batch_size": 8, "pickle_path": "./data_for_tests/", | ||||
"use_cuda": True} | "use_cuda": True} | ||||
validator = POSTester(valid_args) | |||||
validator = SeqLabelTester(valid_args) | |||||
print("start validation.") | print("start validation.") | ||||
validator.test(model) | validator.test(model) | ||||
@@ -3,16 +3,45 @@ | |||||
import os | import os | ||||
from fastNLP.core.trainer import ClassTrainer | |||||
from fastNLP.core.inference import ClassificationInfer | |||||
from fastNLP.core.trainer import ClassificationTrainer | |||||
from fastNLP.loader.dataset_loader import ClassDatasetLoader | from fastNLP.loader.dataset_loader import ClassDatasetLoader | ||||
from fastNLP.loader.model_loader import ModelLoader | |||||
from fastNLP.loader.preprocess import ClassPreprocess | from fastNLP.loader.preprocess import ClassPreprocess | ||||
from fastNLP.models.cnn_text_classification import CNNText | from fastNLP.models.cnn_text_classification import CNNText | ||||
from fastNLP.saver.model_saver import ModelSaver | |||||
if __name__ == "__main__": | |||||
data_dir = "./data_for_tests/" | |||||
train_file = 'text_classify.txt' | |||||
model_name = "model_class.pkl" | |||||
data_dir = "./data_for_tests/" | |||||
train_file = 'text_classify.txt' | |||||
model_name = "model_class.pkl" | |||||
def infer(): | |||||
# load dataset | |||||
print("Loading data...") | |||||
ds_loader = ClassDatasetLoader("train", os.path.join(data_dir, train_file)) | |||||
data = ds_loader.load() | |||||
unlabeled_data = [x[0] for x in data] | |||||
# pre-process data | |||||
pre = ClassPreprocess(data_dir) | |||||
vocab_size, n_classes = pre.process(data, "data_train.pkl") | |||||
print("vocabulary size:", vocab_size) | |||||
print("number of classes:", n_classes) | |||||
# construct model | |||||
print("Building model...") | |||||
cnn = CNNText(class_num=n_classes, embed_num=vocab_size) | |||||
# Dump trained parameters into the model | |||||
ModelLoader.load_pytorch(cnn, "./data_for_tests/saved_model.pkl") | |||||
print("model loaded!") | |||||
infer = ClassificationInfer(data_dir) | |||||
results = infer.predict(cnn, unlabeled_data) | |||||
print(results) | |||||
def train(): | |||||
# load dataset | # load dataset | ||||
print("Loading data...") | print("Loading data...") | ||||
ds_loader = ClassDatasetLoader("train", os.path.join(data_dir, train_file)) | ds_loader = ClassDatasetLoader("train", os.path.join(data_dir, train_file)) | ||||
@@ -40,5 +69,16 @@ if __name__ == "__main__": | |||||
"model_saved_path": "./data_for_tests/", | "model_saved_path": "./data_for_tests/", | ||||
"use_cuda": True | "use_cuda": True | ||||
} | } | ||||
trainer = ClassTrainer(train_args) | |||||
trainer = ClassificationTrainer(train_args) | |||||
trainer.train(cnn) | trainer.train(cnn) | ||||
print("Training finished!") | |||||
saver = ModelSaver("./data_for_tests/saved_model.pkl") | |||||
saver.save_pytorch(cnn) | |||||
print("Model saved!") | |||||
if __name__ == "__main__": | |||||
# train() | |||||
infer() |