Action collects shared operations: data_forward, mode, pad, make_batch
- Trainer and Tester receive an Action as a parameter
- seq_labeling works in this setting
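In short, Trainer and Tester no longer implement these operations themselves; they receive one shared Action object at construction time. A condensed sketch of the new wiring, based on the seq_labeling test changes at the end of this diff (the module path of POSTester is assumed here; train_args and test_args stand for the ConfigSection objects the script loads):

```python
from fastNLP.core.action import SeqLabelAction
from fastNLP.core.trainer import POSTrainer
from fastNLP.core.tester import POSTester  # assumed module path

action = SeqLabelAction(train_args)        # one shared Action instance
trainer = POSTrainer(train_args, action)   # Trainer delegates make_batch,
                                           # data_forward and mode to it
tester = POSTester(test_args, action)      # Tester reuses the same instance,
                                           # so state such as seq_len is shared
```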
@@ -1,16 +1,129 @@
+"""
+This file defines Action(s) and sample methods.
+"""
 from collections import Counter
+
+import torch
 import numpy as np
+
+import _pickle


 class Action(object):
     """
-    base class for Trainer and Tester
+    Operations shared by Trainer, Tester, and Inference.
+    This is designed to reduce duplicated code.
+        - prepare_input: data preparation before a forward pass.
+        - make_batch: produce a mini-batch of data. @staticmethod
+        - pad: padding method used in sequence modeling. @staticmethod
+        - mode: switch the network between train and test mode. (for PyTorch) @staticmethod
+        - data_forward: a forward pass of the network.
+    The base Action should define operations shared by as many task-specific Actions as possible.
     """

     def __init__(self):
         super(Action, self).__init__()
+
+    @staticmethod
+    def make_batch(iterator, data, output_length=True):
+        """
+        1. Perform batching over data and produce a batch of training data.
+        2. Add padding.
+        :param iterator: an iterator (an object implementing the __next__ method), which returns the next list of sample indices.
+        :param data: list. Each entry is a sample, which is also a list of features and label(s).
+            E.g.
+                [
+                    [[word_11, word_12, word_13], [label_11, label_12]],  # sample 1
+                    [[word_21, word_22, word_23], [label_21, label_22]],  # sample 2
+                    ...
+                ]
+        :param output_length: whether to also return the original lengths of the sequences before padding.
+        :return: (batch_x, seq_len), batch_y, if output_length is True.
+            batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
+            seq_len: list. The original length of each sequence before padding.
+            batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
+            batch_x, batch_y, if output_length is False.
+        """
+        indices = next(iterator)
+        batch = [data[idx] for idx in indices]
+        batch_x = [sample[0] for sample in batch]
+        batch_y = [sample[1] for sample in batch]
+        # record the original lengths before padding, because pad() modifies the lists in place
+        seq_len = [len(x) for x in batch_x]
+        batch_x_pad = Action.pad(batch_x)
+        batch_y_pad = Action.pad(batch_y)
+        if output_length:
+            return (batch_x_pad, seq_len), batch_y_pad
+        else:
+            return batch_x_pad, batch_y_pad
+
+    @staticmethod
+    def pad(batch, fill=0):
+        """
+        Pad a batch of samples to the maximum length of this batch. Note: modifies batch in place.
+        :param batch: list of list
+        :param fill: index to pad with, default 0.
+        :return: the padded batch
+        """
+        max_length = max([len(x) for x in batch])
+        for idx, sample in enumerate(batch):
+            if len(sample) < max_length:
+                batch[idx] = sample + ([fill] * (max_length - len(sample)))
+        return batch
+
+    @staticmethod
+    def mode(model, test=False):
+        """
+        Switch between train mode and test mode. (PyTorch-specific for now.)
+        :param model: a PyTorch model
+        :param test: bool, True to set eval mode, False to set train mode.
+        """
+        if test:
+            model.eval()
+        else:
+            model.train()
+
+    def data_forward(self, network, x):
+        """
+        Forward pass of the data.
+        :param network: a model
+        :param x: input feature matrix and label vector
+        :return: output of the model
+        For PyTorch, just do "network(*x)"
+        """
+        raise NotImplementedError
+
+
+class SeqLabelAction(Action):
+    def __init__(self, action_args):
+        """
+        Define task-specific member variables.
+        :param action_args: dict-like object holding task settings, e.g. "use_cuda".
+        """
+        super(SeqLabelAction, self).__init__()
+        self.max_len = None
+        self.mask = None
+        self.best_accuracy = 0.0
+        self.use_cuda = action_args["use_cuda"]
+        self.seq_len = None
+        self.batch_size = None
+
+    def data_forward(self, network, inputs):
+        # unpack the value returned by make_batch
+        if isinstance(inputs, tuple):
+            x = inputs[0]
+            self.seq_len = inputs[1]
+        else:
+            x = inputs
+        x = torch.Tensor(x).long()
+        if torch.cuda.is_available() and self.use_cuda:
+            x = x.cuda()
+        self.batch_size = x.size(0)
+        self.max_len = x.size(1)
+        y = network(x)
+        return y
+
 def k_means_1d(x, k, max_iter=100):
     """
@@ -11,11 +11,12 @@ from fastNLP.core.action import RandomSampler, Batchifier

 class BaseTester(Action):
     """docstring for Tester"""

-    def __init__(self, test_args):
+    def __init__(self, test_args, action):
         """
         :param test_args: a dict-like object with a __getitem__ method, accessed as test_args["key_str"]
         """
         super(BaseTester, self).__init__()
+        self.action = action
         self.validate_in_training = test_args["validate_in_training"]
         self.save_dev_data = None
         self.save_output = test_args["save_output"]
@@ -38,18 +39,21 @@ class BaseTester(Action):
         self.model = network

         # turn on the testing mode; clean up the history
-        self.mode(network, test=True)
+        self.action.mode(network, test=True)
+        self.eval_history.clear()
+        self.batch_output.clear()

         dev_data = self.prepare_input(self.pickle_path)

-        self.iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=True))
+        iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=True))
         num_iter = len(dev_data) // self.batch_size

         for step in range(num_iter):
-            batch_x, batch_y = self.make_batch(dev_data)
-            prediction = self.data_forward(network, batch_x)
+            batch_x, batch_y = self.action.make_batch(iterator, dev_data)
+            prediction = self.action.data_forward(network, batch_x)

             eval_results = self.evaluate(prediction, batch_y)

             if self.save_output:
@@ -64,53 +68,10 @@ class BaseTester(Action):
         :return save_dev_data: list. Each entry is a sample, which is also a list of features and label(s).
         """
         if self.save_dev_data is None:
-            data_dev = _pickle.load(open(data_path + "/data_dev.pkl", "rb"))
+            data_dev = _pickle.load(open(data_path + "data_dev.pkl", "rb"))
             self.save_dev_data = data_dev
         return self.save_dev_data
-    def make_batch(self, data, output_length=True):
-        """
-        1. Perform batching from data and produce a batch of training data.
-        2. Add padding.
-        :param data: list. Each entry is a sample, which is also a list of features and label(s).
-            E.g.
-                [
-                    [[word_11, word_12, word_13], [label_11. label_12]],  # sample 1
-                    [[word_21, word_22, word_23], [label_21. label_22]],  # sample 2
-                    ...
-                ]
-        :return batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
-                batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
-        """
-        indices = next(self.iterator)
-        batch = [data[idx] for idx in indices]
-        batch_x = [sample[0] for sample in batch]
-        batch_y = [sample[1] for sample in batch]
-        batch_x_pad = self.pad(batch_x)
-        batch_y_pad = self.pad(batch_y)
-        if output_length:
-            seq_len = [len(x) for x in batch_x]
-            return (batch_x_pad, seq_len), batch_y_pad
-        else:
-            return batch_x_pad, batch_y_pad
-
-    @staticmethod
-    def pad(batch, fill=0):
-        """
-        Pad a batch of samples to maximum length.
-        :param batch: list of list
-        :param fill: word index to pad, default 0.
-        :return: a padded batch
-        """
-        max_length = max([len(x) for x in batch])
-        for idx, sample in enumerate(batch):
-            if len(sample) < max_length:
-                batch[idx] = sample + ([fill] * (max_length - len(sample)))
-        return batch
-
-    def data_forward(self, network, data):
-        raise NotImplementedError
-
     def evaluate(self, predict, truth):
         raise NotImplementedError
@@ -118,14 +79,6 @@ class BaseTester(Action):
     def metrics(self):
         raise NotImplementedError

-    def mode(self, model, test=True):
-        """TODO: combine this function with Trainer ?? """
-        if test:
-            model.eval()
-        else:
-            model.train()
-        self.eval_history.clear()
-
     def show_matrices(self):
         """
         This is called by Trainer to print evaluation on dev set.
@@ -139,43 +92,21 @@ class POSTester(BaseTester):
     Tester for sequence labeling.
     """

-    def __init__(self, test_args):
+    def __init__(self, test_args, action):
         """
         :param test_args: a dict-like object with a __getitem__ method, accessed as test_args["key_str"]
         """
-        super(POSTester, self).__init__(test_args)
+        super(POSTester, self).__init__(test_args, action)
         self.max_len = None
         self.mask = None
         self.batch_result = None

-    def data_forward(self, network, inputs):
-        """TODO: combine with Trainer
-        :param network: the PyTorch model
-        :param x: list of list, [batch_size, max_len]
-        :return y: [batch_size, num_classes]
-        """
-        # unpack the returned value from make_batch
-        if isinstance(inputs, tuple):
-            x = inputs[0]
-            self.seq_len = inputs[1]
-        else:
-            x = inputs
-        x = torch.Tensor(x).long()
-        if torch.cuda.is_available() and self.use_cuda:
-            x = x.cuda()
-        self.batch_size = x.size(0)
-        self.max_len = x.size(1)
-        y = network(x)
-        return y
-
     def evaluate(self, predict, truth):
         truth = torch.Tensor(truth)
         if torch.cuda.is_available() and self.use_cuda:
             truth = truth.cuda()
-        loss = self.model.loss(predict, truth, self.seq_len) / self.batch_size
-        prediction = self.model.prediction(predict, self.seq_len)
+        loss = self.model.loss(predict, truth, self.action.seq_len) / self.batch_size
+        prediction = self.model.prediction(predict, self.action.seq_len)
         results = torch.Tensor(prediction).view(-1,)
         if torch.cuda.is_available() and self.use_cuda:
             results = results.cuda()
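The batch state (seq_len, batch_size, max_len) now lives on the shared Action: SeqLabelAction.data_forward records it during the forward pass, and Trainer/Tester read it back (self.action.seq_len in evaluate above and in get_loss below). A toy illustration, with an nn.Embedding standing in for the real SeqLabeling model (an assumption for illustration):

```python
import torch.nn as nn

from fastNLP.core.action import SeqLabelAction

action = SeqLabelAction({"use_cuda": False})
toy_net = nn.Embedding(10, 4)  # stand-in for the real model

batch_x, seq_len = [[4, 5, 6], [7, 8, 0]], [3, 2]
y = action.data_forward(toy_net, (batch_x, seq_len))

print(action.seq_len)     # [3, 2] -- recorded during the forward pass
print(action.batch_size)  # 2
print(action.max_len)     # 3
print(y.shape)            # torch.Size([2, 3, 4])
```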
@@ -18,17 +18,15 @@ class BaseTrainer(Action):
     Trainer receives a model and data, and then performs training.

     Subclasses must implement the following abstract methods:
-        - prepare_input
-        - mode
         - define_optimizer
-        - data_forward
         - grad_backward
         - get_loss
     """

-    def __init__(self, train_args):
+    def __init__(self, train_args, action):
         """
         :param train_args: dict of (key, value) pairs, or a dict-like object. Keys are str.
+        :param action: an Action object that wraps most of the operations shared by Trainer, Tester, and Inference.

         The base trainer requires the following keys:
             - epochs: int, the number of epochs in training
@@ -37,6 +35,7 @@ class BaseTrainer(Action):
             - pickle_path: str, the path to pickle files for pre-processing
         """
         super(BaseTrainer, self).__init__()
+        self.action = action
         self.n_epochs = train_args["epochs"]
         self.batch_size = train_args["batch_size"]
         self.pickle_path = train_args["pickle_path"]
@@ -72,14 +71,14 @@ class BaseTrainer(Action):
         else:
             self.model = network

-        data_train, data_dev, data_test, embedding = self.prepare_input(self.pickle_path)
+        data_train = self.prepare_input(self.pickle_path)

         # define tester over dev data
         # TODO: more flexible
-        valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
-                      "save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path,
-                      "use_cuda": self.use_cuda}
-        validator = POSTester(valid_args)
+        default_valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
+                              "save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path,
+                              "use_cuda": self.use_cuda}
+        validator = POSTester(default_valid_args, self.action)

         # main training epochs
         iterations = len(data_train) // self.batch_size
@@ -88,14 +87,14 @@ class BaseTrainer(Action):
         for epoch in range(1, self.n_epochs + 1):

             # turn on network training mode; define optimizer; prepare batch iterator
-            self.mode(test=False)
-            self.iterator = iter(Batchifier(BucketSampler(data_train), self.batch_size, drop_last=True))
+            self.action.mode(self.model, test=False)
+            iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=True))

             # training iterations in one epoch
             for step in range(iterations):
-                batch_x, batch_y = self.make_batch(data_train)
-                prediction = self.data_forward(network, batch_x)
+                batch_x, batch_y = self.action.make_batch(iterator, data_train)
+                prediction = self.action.data_forward(network, batch_x)

                 loss = self.get_loss(prediction, batch_y)
                 self.grad_backward(loss)
@@ -105,8 +104,6 @@ class BaseTrainer(Action):
                     print("[epoch {} step {}] train loss={:.2f}".format(epoch, step, loss.data))

         if self.validate:
-            if data_dev is None:
-                raise RuntimeError("No validation data provided.")
             validator.test(network)

             if self.save_best_dev and self.best_eval_result(validator):
@@ -118,19 +115,13 @@ class BaseTrainer(Action):
         # finish training

-    def prepare_input(self, data_path):
-        data_train = _pickle.load(open(data_path + "data_train.pkl", "rb"))
-        data_dev = _pickle.load(open(data_path + "data_dev.pkl", "rb"))
-        data_test = _pickle.load(open(data_path + "data_test.pkl", "rb"))
-        embedding = _pickle.load(open(data_path + "embedding.pkl", "rb"))
-        return data_train, data_dev, data_test, embedding
-
-    def mode(self, test=False):
+    def prepare_input(self, pickle_path):
         """
-        Tell the network to be trained or not.
-        :param test: bool
+        Load the training data. This is reserved for task-specific processing.
+        :param pickle_path: str, the path to pickle files for pre-processing
+        :return: data_train, loaded from "data_train.pkl"
         """
-        raise NotImplementedError
+        return _pickle.load(open(pickle_path + "/data_train.pkl", "rb"))

     def define_optimizer(self):
         """
@@ -146,17 +137,6 @@ class BaseTrainer(Action):
         """
         raise NotImplementedError

-    def data_forward(self, network, x):
-        """
-        Forward pass of the data.
-        :param network: a model
-        :param x: input feature matrix and label vector
-        :return: output by the models
-        For PyTorch, just do "network(*x)"
-        """
-        raise NotImplementedError
-
     def grad_backward(self, loss):
         """
         Compute gradient with the chain rule.
@@ -187,50 +167,6 @@
         """
         raise NotImplementedError

-    def make_batch(self, data, output_length=True):
-        """
-        1. Perform batching from data and produce a batch of training data.
-        2. Add padding.
-        :param data: list. Each entry is a sample, which is also a list of features and label(s).
-            E.g.
-                [
-                    [[word_11, word_12, word_13], [label_11. label_12]],  # sample 1
-                    [[word_21, word_22, word_23], [label_21. label_22]],  # sample 2
-                    ...
-                ]
-        :return (batch_x, seq_len): tuple of two elements, if output_length is true.
-                batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
-                seq_len: list. The length of the pre-padded sequence, if output_length is True.
-                batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
-                return batch_x and batch_y, if output_length is False
-        """
-        indices = next(self.iterator)
-        batch = [data[idx] for idx in indices]
-        batch_x = [sample[0] for sample in batch]
-        batch_y = [sample[1] for sample in batch]
-        batch_x_pad = self.pad(batch_x)
-        batch_y_pad = self.pad(batch_y)
-        if output_length:
-            seq_len = [len(x) for x in batch_x]
-            return (batch_x_pad, seq_len), batch_y_pad
-        else:
-            return batch_x_pad, batch_y_pad
-
-    @staticmethod
-    def pad(batch, fill=0):
-        """
-        Pad a batch of samples to maximum length.
-        :param batch: list of list
-        :param fill: word index to pad, default 0.
-        :return: a padded batch
-        """
-        max_length = max([len(x) for x in batch])
-        for idx, sample in enumerate(batch):
-            if len(sample) < max_length:
-                batch[idx] = sample + ([fill] * (max_length - len(sample)))
-        return batch
-
     def best_eval_result(self, validator):
         """
         :param validator: a Tester instance
@@ -287,48 +223,14 @@ class POSTrainer(BaseTrainer):
     Trainer for Sequence Modeling
     """

-    def __init__(self, train_args):
-        super(POSTrainer, self).__init__(train_args)
+    def __init__(self, train_args, action):
+        super(POSTrainer, self).__init__(train_args, action)
         self.vocab_size = train_args["vocab_size"]
         self.num_classes = train_args["num_classes"]
         self.max_len = None
         self.mask = None
         self.best_accuracy = 0.0

-    def prepare_input(self, data_path):
-        data_train = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
-        data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
-        return data_train, data_dev, 0, 1
-
-    def data_forward(self, network, inputs):
-        """
-        :param network: the PyTorch model
-        :param inputs: list of list, [batch_size, max_len],
-                       or tuple of (batch_x, seq_len), batch_x == [batch_size, max_len]
-        :return y: [batch_size, max_len, tag_size]
-        """
-        # unpack the returned value from make_batch
-        if isinstance(inputs, tuple):
-            x = inputs[0]
-            self.seq_len = inputs[1]
-        else:
-            x = inputs
-        x = torch.Tensor(x).long()
-        if torch.cuda.is_available() and self.use_cuda:
-            x = x.cuda()
-        self.batch_size = x.size(0)
-        self.max_len = x.size(1)
-        y = network(x)
-        return y
-
-    def mode(self, test=False):
-        if test:
-            self.model.eval()
-        else:
-            self.model.train()
-
     def define_optimizer(self):
         self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9)
@@ -349,14 +251,13 @@ class POSTrainer(BaseTrainer):
         truth = torch.Tensor(truth)
         if torch.cuda.is_available() and self.use_cuda:
             truth = truth.cuda()
-        assert truth.shape == (self.batch_size, self.max_len)
+        assert truth.shape == (self.batch_size, self.action.max_len)

         if self.loss_func is None:
             if hasattr(self.model, "loss"):
                 self.loss_func = self.model.loss
             else:
                 self.define_loss()
-        loss = self.loss_func(predict, truth, self.seq_len)
-        # print("loss={:.2f}".format(loss.data))
+        loss = self.loss_func(predict, truth, self.action.seq_len)
         return loss

     def best_eval_result(self, validator):
@@ -367,36 +268,6 @@ class POSTrainer(BaseTrainer):
         else:
             return False

-    def make_batch(self, data, output_length=True):
-        """
-        1. Perform batching from data and produce a batch of training data.
-        2. Add padding.
-        :param data: list. Each entry is a sample, which is also a list of features and label(s).
-            E.g.
-                [
-                    [[word_11, word_12, word_13], [label_11. label_12]],  # sample 1
-                    [[word_21, word_22, word_23], [label_21. label_22]],  # sample 2
-                    ...
-                ]
-        :return (batch_x, seq_len): tuple of two elements, if output_length is true.
-                batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
-                seq_len: list. The length of the pre-padded sequence, if output_length is True.
-                batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
-                return batch_x and batch_y, if output_length is False
-        """
-        indices = next(self.iterator)
-        batch = [data[idx] for idx in indices]
-        batch_x = [sample[0] for sample in batch]
-        batch_y = [sample[1] for sample in batch]
-        batch_x_pad = self.pad(batch_x)
-        batch_y_pad = self.pad(batch_y)
-        if output_length:
-            seq_len = [len(x) for x in batch_x]
-            return (batch_x_pad, seq_len), batch_y_pad
-        else:
-            return batch_x_pad, batch_y_pad
-

 class LanguageModelTrainer(BaseTrainer):
     """
@@ -2,6 +2,7 @@ import sys
 sys.path.append("..")

+from fastNLP.core.action import SeqLabelAction
 from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
 from fastNLP.core.trainer import POSTrainer
 from fastNLP.loader.dataset_loader import POSDatasetLoader, BaseLoader
@@ -57,7 +58,7 @@ def infer():
     print("Inference finished!")


-def train_test():
+def train_and_test():
     # Config Loader
     train_args = ConfigSection()
     ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args})
@@ -67,12 +68,14 @@ def train_and_test():
     train_data = pos_loader.load_lines()

     # Preprocessor
-    p = POSPreprocess(train_data, pickle_path)
+    p = POSPreprocess(train_data, pickle_path, train_dev_split=0.5)
     train_args["vocab_size"] = p.vocab_size
     train_args["num_classes"] = p.num_classes

+    action = SeqLabelAction(train_args)
+
     # Trainer
-    trainer = POSTrainer(train_args)
+    trainer = POSTrainer(train_args, action)

     # Model
     model = SeqLabeling(train_args)
@@ -100,7 +103,7 @@ def train_and_test():
     ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})

     # Tester
-    tester = POSTester(test_args)
+    tester = POSTester(test_args, action)

     # Start testing
     tester.test(model)
@@ -111,5 +114,5 @@ def train_and_test():

 if __name__ == "__main__":
-    train_test()
+    train_and_test()
     # infer()