@@ -1,16 +1,111 @@
"""
This file defines Action(s) and sample methods.
"""
from collections import Counter
import numpy as np
import torch

class Action(object):
    """
    base class for Trainer and Tester
    Operations shared by Trainer, Tester, and Inference.
    This is designed to reduce duplicated code.
        - make_batch: produce a mini-batch of data. @staticmethod
        - pad: padding method used in sequence modeling. @staticmethod
        - mode: change network mode for either train or test. (for PyTorch) @staticmethod
    The base Action shall define operations shared by as many task-specific Actions as possible.
    """
    def __init__(self):
        super(Action, self).__init__()

    @staticmethod
    def make_batch(iterator, data, use_cuda, output_length=True, max_len=None):
        """Batch and Pad data.
        :param iterator: an iterator, (object that implements __next__ method) which returns the next sample.
        :param data: list. Each entry is a sample, which is also a list of features and label(s).
            E.g.
                [
                    [[word_11, word_12, word_13], [label_11, label_12]],  # sample 1
                    [[word_21, word_22, word_23], [label_21, label_22]],  # sample 2
                    ...
                ]
        :param use_cuda: bool
        :param output_length: whether to output the original length of the sequence before padding.
        :param max_len: int, maximum sequence length
        :return (batch_x, seq_len), batch_y: if output_length is True.
            batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
            seq_len: list. The original length of each sequence before padding.
            batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
        :return batch_x, batch_y: if output_length is False.
        """
        for indices in iterator:
            batch = [data[idx] for idx in indices]
            batch_x = [sample[0] for sample in batch]
            batch_y = [sample[1] for sample in batch]
            # record the original sequence lengths before padding; after padding
            # and tensor conversion every row has the same (padded) length
            seq_len = [len(x) for x in batch_x]
            batch_x = Action.pad(batch_x)
            # pad batch_y only if it is a 2-level list
            if len(batch_y) > 0 and isinstance(batch_y[0], list):
                batch_y = Action.pad(batch_y)
            # convert list to tensor
            batch_x = convert_to_torch_tensor(batch_x, use_cuda)
            batch_y = convert_to_torch_tensor(batch_y, use_cuda)
            # trim data to max_len
            if max_len is not None and batch_x.size(1) > max_len:
                batch_x = batch_x[:, :max_len]
                seq_len = [min(length, max_len) for length in seq_len]
            if output_length:
                yield (batch_x, seq_len), batch_y
            else:
                yield batch_x, batch_y
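    # Usage sketch (assumes the RandomSampler and Batchifier defined later in this
    # file, plus a toy dataset; not part of the original code):
    #
    #     data = [[[1, 2, 3], [1]], [[4, 5], [0]]]
    #     it = iter(Batchifier(RandomSampler(data), batch_size=2, drop_last=False))
    #     for (batch_x, seq_len), batch_y in Action.make_batch(it, data, use_cuda=False):
    #         print(batch_x.size(), seq_len)  # e.g. torch.Size([2, 3]) and [3, 2]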
    @staticmethod
    def pad(batch, fill=0):
        """
        Pad a batch of samples to maximum length of this batch.
        :param batch: list of list
        :param fill: word index to pad, default 0.
        :return: a padded batch
        """
        max_length = max([len(x) for x in batch])
        for idx, sample in enumerate(batch):
            if len(sample) < max_length:
                batch[idx] = sample + ([fill] * (max_length - len(sample)))
        return batch
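    # Example of the padding behavior (a sketch, not in the original code):
    # Action.pad([[1, 2, 3], [4]]) returns [[1, 2, 3], [4, 0, 0]]; note that the
    # input list is modified in place as well.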
    @staticmethod
    def mode(model, test=False):
        """
        Train mode or Test mode. This is for PyTorch currently.
        :param model: a PyTorch model
        :param test: bool, True to switch the model to evaluation mode.
        """
        if test:
            model.eval()
        else:
            model.train()
def convert_to_torch_tensor(data_list, use_cuda):
    """
    Convert lists into (cuda) Tensors.
    :param data_list: 2-level lists
    :param use_cuda: bool, whether to move the tensor to CUDA
    :return: PyTorch Tensor of shape [batch_size, max_seq_len]
    """
    data_list = torch.Tensor(data_list).long()
    if torch.cuda.is_available() and use_cuda:
        data_list = data_list.cuda()
    return data_list
def k_means_1d(x, k, max_iter=100):
    """
@@ -140,11 +235,10 @@ class Batchifier(object):
    def __iter__(self):
        batch = []
        while True:
            for idx in self.sampler:
                batch.append(idx)
                if len(batch) == self.batch_size:
                    yield batch
                    batch = []
            if 0 < len(batch) < self.batch_size and self.drop_last is False:
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if 0 < len(batch) < self.batch_size and self.drop_last is False:
            yield batch
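# Behavior sketch (an illustration, not part of the diff): with batch_size=2 and
# drop_last=False, a sampler over 5 indices yields the batches [0, 1], [2, 3],
# and finally the short remainder [4]; with drop_last=True the remainder is dropped.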
@@ -1,7 +1,45 @@
import numpy as np
import torch

from fastNLP.core.action import Batchifier, SequentialSampler
from fastNLP.core.action import convert_to_torch_tensor
from fastNLP.loader.preprocess import load_pickle, DEFAULT_UNKNOWN_LABEL
from fastNLP.modules import utils
def make_batch(iterator, data, use_cuda, output_length=False, max_len=None, min_len=None):
    for indices in iterator:
        batch_x = [data[idx] for idx in indices]
        # record the original sequence lengths before padding
        seq_len = [len(x) for x in batch_x]
        batch_x = pad(batch_x)
        # convert list to tensor
        batch_x = convert_to_torch_tensor(batch_x, use_cuda)
        # trim data to max_len
        if max_len is not None and batch_x.size(1) > max_len:
            batch_x = batch_x[:, :max_len]
            seq_len = [min(length, max_len) for length in seq_len]
        if min_len is not None and batch_x.size(1) < min_len:
            pad_tensor = torch.zeros(batch_x.size(0), min_len - batch_x.size(1)).to(batch_x)
            batch_x = torch.cat((batch_x, pad_tensor), 1)
        if output_length:
            yield tuple([batch_x, seq_len])
        else:
            yield batch_x
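# Why min_len is useful (a hedged guess based on ClassificationInfer calling this
# function with min_len=5, e.g. so convolution kernels always fit): very short
# batches are right-padded with zeros up to min_len. A standalone sketch:
#
#     import torch
#     batch = torch.ones(2, 3).long()                  # a batch of length-3 sequences
#     pad_tensor = torch.zeros(2, 5 - batch.size(1)).to(batch)
#     batch = torch.cat((batch, pad_tensor), 1)        # now shaped [2, 5]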
def pad(batch, fill=0):
    """
    Pad a batch of samples to maximum length.
    :param batch: list of list
    :param fill: word index to pad, default 0.
    :return: a padded batch
    """
    max_length = max([len(x) for x in batch])
    for idx, sample in enumerate(batch):
        if len(sample) < max_length:
            batch[idx] = sample + ([fill] * (max_length - len(sample)))
    return batch
class Inference(object):
@@ -9,7 +47,8 @@ class Inference(object):
    This is an interface focusing on predicting output based on trained models.
    It does not care about evaluations of the model, which is different from Tester.
    This is a high-level model wrapper to be called by FastNLP.
    This class does not share any operations with Trainer and Tester.
    Currently, Inference does not support GPU.
    """

    def __init__(self, pickle_path):
@@ -32,13 +71,11 @@ class Inference(object):
        # turn on the testing mode; clean up the history
        self.mode(network, test=True)
        self.batch_output.clear()

        self.iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False))
        num_iter = len(data) // self.batch_size
        iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False))

        for step in range(num_iter):
            batch_x = self.make_batch(data)
        for batch_x in self.make_batch(iterator, data, use_cuda=False):
            prediction = self.data_forward(network, batch_x)
@@ -51,43 +88,12 @@ class Inference(object):
            network.eval()
        else:
            network.train()
        self.batch_output.clear()

    def data_forward(self, network, x):
        """
        This is only for sequence labeling with CRF decoder. TODO: more general ?
        :param network:
        :param x:
        :return:
        """
        seq_len = [len(seq) for seq in x]
        x = torch.Tensor(x).long()
        y = network(x)
        prediction = network.prediction(y, seq_len)
        # To do: hide framework
        results = torch.Tensor(prediction).view(-1, )
        return list(results.data)
        raise NotImplementedError
    def make_batch(self, data):
        indices = next(self.iterator)
        batch_x = [data[idx] for idx in indices]
        if self.batch_size > 1:
            batch_x = self.pad(batch_x)
        return batch_x

    @staticmethod
    def pad(batch, fill=0):
        """
        Pad a batch of samples to maximum length.
        :param batch: list of list
        :param fill: word index to pad, default 0.
        :return: a padded batch
        """
        max_length = max([len(x) for x in batch])
        for idx, sample in enumerate(batch):
            if len(sample) < max_length:
                batch[idx] = sample + [fill * (max_length - len(sample))]
        return batch
    def make_batch(self, iterator, data, use_cuda):
        raise NotImplementedError

    def prepare_input(self, data):
        """
@@ -106,13 +112,76 @@ class Inference(object):
            data_index.append([self.word2index.get(w, default_unknown_index) for w in example])
        return data_index

    def prepare_output(self, data):
        raise NotImplementedError
class SeqLabelInfer(Inference):
    """
    Inference on sequence labeling models.
    """

    def __init__(self, pickle_path):
        super(SeqLabelInfer, self).__init__(pickle_path)
    def data_forward(self, network, inputs):
        """
        This is only for sequence labeling with CRF decoder.
        :param network: a PyTorch model
        :param inputs: tuple of (x, seq_len)
        :return: Tensor
        """
        if not isinstance(inputs, tuple):
            raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.")
        # unpack the returned value from make_batch
        x, seq_len = inputs[0], inputs[1]
        batch_size, max_len = x.size(0), x.size(1)
        mask = utils.seq_mask(seq_len, max_len)
        mask = mask.byte().view(batch_size, max_len)
        y = network(x)
        prediction = network.prediction(y, mask)
        # wrap the decoded paths in a plain Tensor; no gradient is needed at inference time
        return torch.Tensor(prediction)
    def make_batch(self, iterator, data, use_cuda):
        return make_batch(iterator, data, use_cuda, output_length=True)
    def prepare_output(self, batch_outputs):
        """
        Transform list of batch outputs into strings.
        :param batch_outputs: list of list, of shape [num_batch, tag_seq_length]. Element type is Tensor.
        :param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch_size, tag_seq_length].
        :return:
        """
        results = []
        for batch in batch_outputs:
            results.append([self.index2label[int(x.data)] for x in batch])
            for example in np.array(batch):
                results.append([self.index2label[int(x)] for x in example])
        return results
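    # Sketch of the transformation with a hypothetical tag vocabulary (not from
    # the diff): given index2label = {0: "O", 1: "B-n", 2: "E-n"}, a decoded row
    # [1, 2, 0] becomes ["B-n", "E-n", "O"].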
class ClassificationInfer(Inference):
    """
    Inference on Classification models.
    """

    def __init__(self, pickle_path):
        super(ClassificationInfer, self).__init__(pickle_path)
    def data_forward(self, network, x):
        """Forward through network."""
        logits = network(x)
        return logits

    def make_batch(self, iterator, data, use_cuda):
        return make_batch(iterator, data, use_cuda, output_length=False, min_len=5)
    def prepare_output(self, batch_outputs):
        """
        Transform list of batch outputs into strings.
        :param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch_size, num_classes].
        :return:
        """
        results = []
        for batch_out in batch_outputs:
            # argmax over the class dimension: one predicted label per example in the batch
            idx = np.argmax(batch_out.detach().numpy(), axis=-1)
            results.extend([self.index2label[i] for i in idx])
        return results
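# End-to-end usage sketch (the model, input, and pickle directory are assumptions,
# not from the diff):
#
#     infer = ClassificationInfer(pickle_path="./save/")
#     labels = infer.predict(network, [["this", "movie", "rocks"]])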
@@ -1,14 +1,14 @@
import _pickle
import os

import numpy as np
import torch

from fastNLP.core.action import Action
from fastNLP.core.action import RandomSampler, Batchifier
from fastNLP.modules import utils

class BaseTester(Action):
class BaseTester(object):
    """docstring for Tester"""

    def __init__(self, test_args):
@@ -37,25 +37,33 @@ class BaseTester(Action):
        else:
            self.model = network

        # no backward setting for model
        for param in network.parameters():
            param.requires_grad = False

        # turn on the testing mode; clean up the history
        self.mode(network, test=True)
        self.eval_history.clear()
        self.batch_output.clear()

        dev_data = self.prepare_input(self.pickle_path)

        self.iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=True))
        num_iter = len(dev_data) // self.batch_size
        iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=True))
        n_batches = len(dev_data) // self.batch_size
        n_print = 1
        step = 0

        for step in range(num_iter):
            batch_x, batch_y = self.make_batch(dev_data)
        for batch_x, batch_y in self.make_batch(iterator, dev_data):
            prediction = self.data_forward(network, batch_x)
            eval_results = self.evaluate(prediction, batch_y)

            if self.save_output:
                self.batch_output.append(prediction)
            if self.save_loss:
                self.eval_history.append(eval_results)
            step += 1
    def prepare_input(self, data_path):
        """
@@ -64,51 +72,14 @@ class BaseTester(Action):
        :return save_dev_data: list. Each entry is a sample, which is also a list of features and label(s).
        """
        if self.save_dev_data is None:
            data_dev = _pickle.load(open(data_path + "/data_dev.pkl", "rb"))
            data_dev = _pickle.load(open(data_path + "data_dev.pkl", "rb"))
            self.save_dev_data = data_dev
        return self.save_dev_data
    def make_batch(self, data, output_length=True):
        """
        1. Perform batching from data and produce a batch of training data.
        2. Add padding.
        :param data: list. Each entry is a sample, which is also a list of features and label(s).
            E.g.
                [
                    [[word_11, word_12, word_13], [label_11, label_12]],  # sample 1
                    [[word_21, word_22, word_23], [label_21, label_22]],  # sample 2
                    ...
                ]
        :return batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
                batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
        """
        indices = next(self.iterator)
        batch = [data[idx] for idx in indices]
        batch_x = [sample[0] for sample in batch]
        batch_y = [sample[1] for sample in batch]
        batch_x_pad = self.pad(batch_x)
        batch_y_pad = self.pad(batch_y)
        if output_length:
            seq_len = [len(x) for x in batch_x]
            return (batch_x_pad, seq_len), batch_y_pad
        else:
            return batch_x_pad, batch_y_pad
    @staticmethod
    def pad(batch, fill=0):
        """
        Pad a batch of samples to maximum length.
        :param batch: list of list
        :param fill: word index to pad, default 0.
        :return: a padded batch
        """
        max_length = max([len(x) for x in batch])
        for idx, sample in enumerate(batch):
            if len(sample) < max_length:
                batch[idx] = sample + ([fill] * (max_length - len(sample)))
        return batch
    def mode(self, model, test):
        Action.mode(model, test)

    def data_forward(self, network, data):
    def data_forward(self, network, x):
        raise NotImplementedError

    def evaluate(self, predict, truth):
@@ -118,14 +89,6 @@ class BaseTester(Action):
    def metrics(self):
        raise NotImplementedError
    def mode(self, model, test=True):
        """TODO: combine this function with Trainer ?? """
        if test:
            model.eval()
        else:
            model.train()
        self.eval_history.clear()

    def show_matrices(self):
        """
        This is called by Trainer to print evaluation on dev set.
@@ -133,8 +96,11 @@ class BaseTester(Action):
        """
        raise NotImplementedError

    def make_batch(self, iterator, data):
        raise NotImplementedError
class POSTester(BaseTester):
class SeqLabelTester(BaseTester):
    """
    Tester for sequence labeling.
    """
@@ -143,44 +109,36 @@ class POSTester(BaseTester):
        """
        :param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]"
        """
        super(POSTester, self).__init__(test_args)
        super(SeqLabelTester, self).__init__(test_args)
        self.max_len = None
        self.mask = None
        self.batch_result = None
    def data_forward(self, network, inputs):
        """TODO: combine with Trainer
        :param network: the PyTorch model
        :param x: list of list, [batch_size, max_len]
        :return y: [batch_size, num_classes]
        """
        if not isinstance(inputs, tuple):
            raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.")
        # unpack the returned value from make_batch
        if isinstance(inputs, tuple):
            x = inputs[0]
            self.seq_len = inputs[1]
        else:
            x = inputs
        x = torch.Tensor(x).long()
        x, seq_len = inputs[0], inputs[1]
        batch_size, max_len = x.size(0), x.size(1)
        mask = utils.seq_mask(seq_len, max_len)
        mask = mask.byte().view(batch_size, max_len)
        if torch.cuda.is_available() and self.use_cuda:
            x = x.cuda()
            self.batch_size = x.size(0)
            self.max_len = x.size(1)
            mask = mask.cuda()
        self.mask = mask
        y = network(x)
        return y
    def evaluate(self, predict, truth):
        truth = torch.Tensor(truth)
        if torch.cuda.is_available() and self.use_cuda:
            truth = truth.cuda()
        loss = self.model.loss(predict, truth, self.seq_len) / self.batch_size
        prediction = self.model.prediction(predict, self.seq_len)
        batch_size, max_len = predict.size(0), predict.size(1)
        loss = self.model.loss(predict, truth, self.mask) / batch_size
        prediction = self.model.prediction(predict, self.mask)
        results = torch.Tensor(prediction).view(-1,)
        if torch.cuda.is_available() and self.use_cuda:
            results = results.cuda()
        accuracy = float(torch.sum(results == truth.view((-1,)))) / results.shape[0]
        return [loss.data, accuracy]
        # make sure "results" is in the same device as "truth"
        results = results.to(truth)
        # cast the match count to float so integer division cannot floor the accuracy to zero
        accuracy = torch.sum(results == truth.view((-1,))).float() / results.shape[0]
        return [loss.data, accuracy.data]
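    # Shape sketch (illustrative values only): with batch_size=2 and max_len=4,
    # predict is [2, 4, tag_size], truth is [2, 4], and self.mask marks the real
    # versus padded positions per sequence, e.g. [[1, 1, 1, 0], [1, 1, 0, 0]].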
    def metrics(self):
        batch_loss = np.mean([x[0] for x in self.eval_history])
@@ -195,8 +153,11 @@ class POSTester(BaseTester):
        loss, accuracy = self.metrics()
        return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy)

    def make_batch(self, iterator, data):
        return Action.make_batch(iterator, data, use_cuda=self.use_cuda, output_length=True)
class ClassTester(BaseTester):
class ClassificationTester(BaseTester):
    """Tester for classification."""

    def __init__(self, test_args):
@@ -204,7 +165,7 @@ class ClassTester(BaseTester):
        :param test_args: a dict-like object that has __getitem__ method, \
            can be accessed by "test_args["key_str"]"
        """
        # super(ClassTester, self).__init__()
        super(ClassificationTester, self).__init__(test_args)
        self.pickle_path = test_args["pickle_path"]
        self.save_dev_data = None
@@ -212,111 +173,8 @@ class ClassTester(BaseTester):
        self.mean_loss = None
        self.iterator = None

        if "test_name" in test_args:
            self.test_name = test_args["test_name"]
        else:
            self.test_name = "data_test.pkl"
        if "validate_in_training" in test_args:
            self.validate_in_training = test_args["validate_in_training"]
        else:
            self.validate_in_training = False
        if "save_output" in test_args:
            self.save_output = test_args["save_output"]
        else:
            self.save_output = False
        if "save_loss" in test_args:
            self.save_loss = test_args["save_loss"]
        else:
            self.save_loss = True
        if "batch_size" in test_args:
            self.batch_size = test_args["batch_size"]
        else:
            self.batch_size = 50
        if "use_cuda" in test_args:
            self.use_cuda = test_args["use_cuda"]
        else:
            self.use_cuda = True
        if "max_len" in test_args:
            self.max_len = test_args["max_len"]
        else:
            self.max_len = None

        self.model = None
        self.eval_history = []
        self.batch_output = []
    def test(self, network):
        # prepare model
        if torch.cuda.is_available() and self.use_cuda:
            self.model = network.cuda()
        else:
            self.model = network

        # no backward setting for model
        for param in self.model.parameters():
            param.requires_grad = False

        # turn on the testing mode; clean up the history
        self.mode(network, test=True)

        # prepare test data
        data_test = self.prepare_input(self.pickle_path, self.test_name)

        # data generator
        self.iterator = iter(Batchifier(
            RandomSampler(data_test), self.batch_size, drop_last=False))

        # test
        n_batches = len(data_test) // self.batch_size
        n_print = n_batches // 10
        step = 0
        for batch_x, batch_y in self.make_batch(data_test, max_len=self.max_len):
            prediction = self.data_forward(network, batch_x)
            eval_results = self.evaluate(prediction, batch_y)

            if self.save_output:
                self.batch_output.append(prediction)
            if self.save_loss:
                self.eval_history.append(eval_results)

            if step % n_print == 0:
                print("step: {:>5}".format(step))
            step += 1
    def prepare_input(self, data_dir, file_name):
        """Prepare data."""
        file_path = os.path.join(data_dir, file_name)
        with open(file_path, 'rb') as f:
            data = _pickle.load(f)
        return data
    def make_batch(self, data, max_len=None):
        """Batch and pad data."""
        for indices in self.iterator:
            # generate batch and pad
            batch = [data[idx] for idx in indices]
            batch_x = [sample[0] for sample in batch]
            batch_y = [sample[1] for sample in batch]
            batch_x = self.pad(batch_x)

            # convert to tensor
            batch_x = torch.tensor(batch_x, dtype=torch.long)
            batch_y = torch.tensor(batch_y, dtype=torch.long)
            if torch.cuda.is_available() and self.use_cuda:
                batch_x = batch_x.cuda()
                batch_y = batch_y.cuda()

            # trim data to max_len
            if max_len is not None and batch_x.size(1) > max_len:
                batch_x = batch_x[:, :max_len]

            yield batch_x, batch_y

    def make_batch(self, iterator, data, max_len=None):
        return Action.make_batch(iterator, data, use_cuda=self.use_cuda, max_len=max_len)
    def data_forward(self, network, x):
        """Forward through network."""
@@ -337,10 +195,3 @@ class ClassTester(BaseTester):
        acc = float(torch.sum(y_pred == y_true)) / len(y_true)
        return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc

    def mode(self, model, test=True):
        """TODO: combine this function with Trainer ?? """
        if test:
            model.eval()
        else:
            model.train()
        self.eval_history.clear()
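# Minimal tester wiring (a sketch; the paths and argument values are assumptions):
#
#     test_args = {"save_output": True, "save_loss": True, "batch_size": 8,
#                  "pickle_path": "./data_for_tests/", "use_cuda": False}
#     tester = SeqLabelTester(test_args)
#     tester.test(model)
#     print(tester.show_matrices())  # e.g. "dev loss=1.23, accuracy=0.87"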
@@ -8,20 +8,18 @@ import torch
import torch.nn as nn

from fastNLP.core.action import Action
from fastNLP.core.action import RandomSampler, Batchifier, BucketSampler
from fastNLP.core.tester import POSTester
from fastNLP.core.action import RandomSampler, Batchifier
from fastNLP.core.tester import SeqLabelTester, ClassificationTester
from fastNLP.modules import utils
from fastNLP.saver.model_saver import ModelSaver

class BaseTrainer(Action):
class BaseTrainer(object):
    """Base trainer for all trainers.
    Trainer receives a model and data, and then performs training.

    Subclasses must implement the following abstract methods:
        - prepare_input
        - mode
        - define_optimizer
        - data_forward
        - grad_backward
        - get_loss
    """
@@ -75,25 +73,29 @@ class BaseTrainer(Action):
        data_train, data_dev, data_test, embedding = self.prepare_input(self.pickle_path)

        # define tester over dev data
        # TODO: more flexible
        valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
                      "save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path,
                      "use_cuda": self.use_cuda}
        validator = POSTester(valid_args)
        if self.validate:
            default_valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
                                  "save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path,
                                  "use_cuda": self.use_cuda}
            validator = self._create_validator(default_valid_args)

        # main training epochs
        iterations = len(data_train) // self.batch_size
        self.define_optimizer()

        # main training epochs
        start = time()
        n_samples = len(data_train)
        n_batches = n_samples // self.batch_size
        n_print = 1

        for epoch in range(1, self.n_epochs + 1):
            # turn on network training mode; define optimizer; prepare batch iterator
            self.mode(test=False)
            self.iterator = iter(Batchifier(BucketSampler(data_train), self.batch_size, drop_last=True))
            # turn on network training mode; prepare batch iterator
            self.mode(network, test=False)
            iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=False))

            # training iterations in one epoch
            for step in range(iterations):
                batch_x, batch_y = self.make_batch(data_train)
            step = 0
            for batch_x, batch_y in self.make_batch(iterator, data_train):
                prediction = self.data_forward(network, batch_x)
@@ -101,12 +103,14 @@ class BaseTrainer(Action):
                self.grad_backward(loss)
                self.update()

                if step % 10 == 0:
                    print("[epoch {} step {}] train loss={:.2f}".format(epoch, step, loss.data))
                if step % n_print == 0:
                    end = time()
                    diff = timedelta(seconds=round(end - start))
                    print("[epoch: {:>3} step: {:>4}] train loss: {:>4.2} time: {}".format(
                        epoch, step, loss.data, diff))
                step += 1

            if self.validate:
                if data_dev is None:
                    raise RuntimeError("No validation data provided.")
                validator.test(network)

                if self.save_best_dev and self.best_eval_result(validator):
@@ -116,22 +120,32 @@ class BaseTrainer(Action):
                print("[epoch {}]".format(epoch), end=" ")
                print(validator.show_matrices())

        # finish training
    def prepare_input(self, data_path):
        data_train = _pickle.load(open(data_path + "data_train.pkl", "rb"))
        data_dev = _pickle.load(open(data_path + "data_dev.pkl", "rb"))
        data_test = _pickle.load(open(data_path + "data_test.pkl", "rb"))
        embedding = _pickle.load(open(data_path + "embedding.pkl", "rb"))
        return data_train, data_dev, data_test, embedding

    def mode(self, test=False):
    def prepare_input(self, pickle_path):
        """
        Tell the network to be trained or not.
        :param test: bool
        For task-specific processing.
        :param pickle_path: str, the directory holding the pickled data files
        :return data_train, data_dev, data_test, embedding:
        """
        names = [
            "data_train.pkl", "data_dev.pkl",
            "data_test.pkl", "embedding.pkl"]
        files = []
        for name in names:
            file_path = os.path.join(pickle_path, name)
            if os.path.exists(file_path):
                with open(file_path, 'rb') as f:
                    data = _pickle.load(f)
            else:
                data = []
            files.append(data)
        return tuple(files)

    def make_batch(self, iterator, data):
        raise NotImplementedError

    def mode(self, network, test):
        Action.mode(network, test)
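    # Expected pickle layout (an illustration; only the file names are fixed by the
    # code above, and missing files are replaced by empty lists):
    #
    #     pickle_path/
    #         data_train.pkl   # list of [features, label(s)] samples
    #         data_dev.pkl
    #         data_test.pkl
    #         embedding.pkl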
    def define_optimizer(self):
        """
        Define framework-specific optimizer specified by the models.
@@ -147,14 +161,6 @@ class BaseTrainer(Action):
        raise NotImplementedError

    def data_forward(self, network, x):
        """
        Forward pass of the data.
        :param network: a model
        :param x: input feature matrix and label vector
        :return: output by the models
        For PyTorch, just do "network(*x)"
        """
        raise NotImplementedError

    def grad_backward(self, loss):
@@ -187,50 +193,6 @@ class BaseTrainer(Action):
        """
        raise NotImplementedError
    def make_batch(self, data, output_length=True):
        """
        1. Perform batching from data and produce a batch of training data.
        2. Add padding.
        :param data: list. Each entry is a sample, which is also a list of features and label(s).
            E.g.
                [
                    [[word_11, word_12, word_13], [label_11, label_12]],  # sample 1
                    [[word_21, word_22, word_23], [label_21, label_22]],  # sample 2
                    ...
                ]
        :return (batch_x, seq_len), batch_y: if output_length is True.
            batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
            seq_len: list. The original length of each sequence before padding.
            batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
        :return batch_x, batch_y: if output_length is False.
        """
        indices = next(self.iterator)
        batch = [data[idx] for idx in indices]
        batch_x = [sample[0] for sample in batch]
        batch_y = [sample[1] for sample in batch]
        batch_x_pad = self.pad(batch_x)
        batch_y_pad = self.pad(batch_y)
        if output_length:
            seq_len = [len(x) for x in batch_x]
            return (batch_x_pad, seq_len), batch_y_pad
        else:
            return batch_x_pad, batch_y_pad
    @staticmethod
    def pad(batch, fill=0):
        """
        Pad a batch of samples to maximum length.
        :param batch: list of list
        :param fill: word index to pad, default 0.
        :return: a padded batch
        """
        max_length = max([len(x) for x in batch])
        for idx, sample in enumerate(batch):
            if len(sample) < max_length:
                batch[idx] = sample + ([fill] * (max_length - len(sample)))
        return batch
    def best_eval_result(self, validator):
        """
        :param validator: a Tester instance
@@ -245,6 +207,9 @@ class BaseTrainer(Action):
        """
        ModelSaver(self.model_saved_path + "model_best_dev.pkl").save_pytorch(network)

    def _create_validator(self, valid_args):
        raise NotImplementedError
class ToyTrainer(BaseTrainer):
    """
@@ -259,12 +224,6 @@ class ToyTrainer(BaseTrainer):
        data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
        return data_train, data_dev, 0, 1

    def mode(self, test=False):
        if test:
            self.model.eval()
        else:
            self.model.train()

    def data_forward(self, network, x):
        return network(x)

@@ -282,53 +241,20 @@ class ToyTrainer(BaseTrainer):
        self.optimizer.step()
class POSTrainer(BaseTrainer):
class SeqLabelTrainer(BaseTrainer):
    """
    Trainer for Sequence Modeling
    """

    def __init__(self, train_args):
        super(POSTrainer, self).__init__(train_args)
        super(SeqLabelTrainer, self).__init__(train_args)
        self.vocab_size = train_args["vocab_size"]
        self.num_classes = train_args["num_classes"]
        self.max_len = None
        self.mask = None
        self.best_accuracy = 0.0

    def prepare_input(self, data_path):
        data_train = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
        data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
        return data_train, data_dev, 0, 1
    def data_forward(self, network, inputs):
        """
        :param network: the PyTorch model
        :param inputs: list of list, [batch_size, max_len],
            or tuple of (batch_x, seq_len), batch_x == [batch_size, max_len]
        :return y: [batch_size, max_len, tag_size]
        """
        # unpack the returned value from make_batch
        if isinstance(inputs, tuple):
            x = inputs[0]
            self.seq_len = inputs[1]
        else:
            x = inputs
        x = torch.Tensor(x).long()
        if torch.cuda.is_available() and self.use_cuda:
            x = x.cuda()
        self.batch_size = x.size(0)
        self.max_len = x.size(1)
        y = network(x)
        return y

    def mode(self, test=False):
        if test:
            self.model.eval()
        else:
            self.model.train()
    def define_optimizer(self):
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9)
@@ -339,6 +265,23 @@ class POSTrainer(BaseTrainer):
    def update(self):
        self.optimizer.step()
    def data_forward(self, network, inputs):
        if not isinstance(inputs, tuple):
            raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.")
        # unpack the returned value from make_batch
        x, seq_len = inputs[0], inputs[1]
        batch_size, max_len = x.size(0), x.size(1)
        mask = utils.seq_mask(seq_len, max_len)
        mask = mask.byte().view(batch_size, max_len)
        if torch.cuda.is_available() and self.use_cuda:
            mask = mask.cuda()
        self.mask = mask
        y = network(x)
        return y
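    # What the mask looks like (illustrative values): for seq_len=[3, 1] and
    # max_len=3, utils.seq_mask is expected to produce
    #
    #     [[1, 1, 1],
    #      [1, 0, 0]]
    #
    # so the CRF loss and decoding can ignore padded positions.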
    def get_loss(self, predict, truth):
        """
        Compute loss given prediction and ground truth.
@@ -346,17 +289,10 @@ class POSTrainer(BaseTrainer):
        :param truth: ground truth label vector, [batch_size, max_len]
        :return: a scalar
        """
        truth = torch.Tensor(truth)
        if torch.cuda.is_available() and self.use_cuda:
            truth = truth.cuda()
        assert truth.shape == (self.batch_size, self.max_len)
        if self.loss_func is None:
            if hasattr(self.model, "loss"):
                self.loss_func = self.model.loss
            else:
                self.define_loss()
        loss = self.loss_func(predict, truth, self.seq_len)
        # print("loss={:.2f}".format(loss.data))
        batch_size, max_len = predict.size(0), predict.size(1)
        assert truth.shape == (batch_size, max_len)
        loss = self.model.loss(predict, truth, self.mask)
        return loss
    def best_eval_result(self, validator):
@@ -367,62 +303,18 @@ class POSTrainer(BaseTrainer):
        else:
            return False
    def make_batch(self, data, output_length=True):
        """
        1. Perform batching from data and produce a batch of training data.
        2. Add padding.
        :param data: list. Each entry is a sample, which is also a list of features and label(s).
            E.g.
                [
                    [[word_11, word_12, word_13], [label_11, label_12]],  # sample 1
                    [[word_21, word_22, word_23], [label_21, label_22]],  # sample 2
                    ...
                ]
        :return (batch_x, seq_len), batch_y: if output_length is True.
            batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
            seq_len: list. The original length of each sequence before padding.
            batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
        :return batch_x, batch_y: if output_length is False.
        """
        indices = next(self.iterator)
        batch = [data[idx] for idx in indices]
        batch_x = [sample[0] for sample in batch]
        batch_y = [sample[1] for sample in batch]
        batch_x_pad = self.pad(batch_x)
        batch_y_pad = self.pad(batch_y)
        if output_length:
            seq_len = [len(x) for x in batch_x]
            return (batch_x_pad, seq_len), batch_y_pad
        else:
            return batch_x_pad, batch_y_pad

    def make_batch(self, iterator, data):
        return Action.make_batch(iterator, data, output_length=True, use_cuda=self.use_cuda)
class LanguageModelTrainer(BaseTrainer):
    """
    Trainer for Language Model
    """

    def _create_validator(self, valid_args):
        return SeqLabelTester(valid_args)

    def __init__(self, train_args):
        super(LanguageModelTrainer, self).__init__(train_args)

    def prepare_input(self, data_path):
        pass
class ClassTrainer(BaseTrainer):
class ClassificationTrainer(BaseTrainer):
    """Trainer for classification."""

    def __init__(self, train_args):
        # super(ClassTrainer, self).__init__(train_args)
        self.n_epochs = train_args["epochs"]
        self.batch_size = train_args["batch_size"]
        self.pickle_path = train_args["pickle_path"]
        if "validate" in train_args:
            self.validate = train_args["validate"]
        else:
            self.validate = False
        super(ClassificationTrainer, self).__init__(train_args)
        if "learn_rate" in train_args:
            self.learn_rate = train_args["learn_rate"]
        else:
@@ -431,127 +323,14 @@ class ClassTrainer(BaseTrainer):
            self.momentum = train_args["momentum"]
        else:
            self.momentum = 0.9

        if "use_cuda" in train_args:
            self.use_cuda = train_args["use_cuda"]
        else:
            self.use_cuda = True

        self.model = None
        self.iterator = None
        self.loss_func = None
        self.optimizer = None
    def train(self, network):
        """General Training Steps
        :param network: a model

        The method is framework independent.
        Work by calling the following methods:
            - prepare_input
            - mode
            - define_optimizer
            - data_forward
            - get_loss
            - grad_backward
            - update
        Subclasses must implement these methods with a specific framework.
        """
        # prepare model and data, transfer model to gpu if available
        if torch.cuda.is_available() and self.use_cuda:
            self.model = network.cuda()
        else:
            self.model = network
        data_train, data_dev, data_test, embedding = self.prepare_input(
            self.pickle_path)

        # define tester over dev data
        # valid_args = {
        #     "save_output": True, "validate_in_training": True,
        #     "save_dev_input": True, "save_loss": True,
        #     "batch_size": self.batch_size, "pickle_path": self.pickle_path}
        # validator = POSTester(valid_args)

        # turn on network training mode, define loss and optimizer
        self.define_loss()
        self.define_optimizer()
        self.mode(test=False)

        # main training epochs
        start = time()
        n_samples = len(data_train)
        n_batches = n_samples // self.batch_size
        n_print = n_batches // 10
        for epoch in range(self.n_epochs):
            # prepare batch iterator
            self.iterator = iter(Batchifier(
                RandomSampler(data_train), self.batch_size, drop_last=False))

            # training iterations in one epoch
            step = 0
            for batch_x, batch_y in self.make_batch(data_train):
                prediction = self.data_forward(network, batch_x)
                loss = self.get_loss(prediction, batch_y)
                self.grad_backward(loss)
                self.update()

                if step % n_print == 0:
                    acc = self.get_acc(prediction, batch_y)
                    end = time()
                    diff = timedelta(seconds=round(end - start))
                    print("epoch: {:>3} step: {:>4} loss: {:>4.2}"
                          " train acc: {:>5.1%} time: {}".format(
                              epoch, step, loss, acc, diff))
                step += 1

            # if self.validate:
            #     if data_dev is None:
            #         raise RuntimeError("No validation data provided.")
            #     validator.test(network)
            #     print("[epoch {}]".format(epoch), end=" ")
            #     print(validator.show_matrices())

        # finish training
    def prepare_input(self, data_path):
        names = [
            "data_train.pkl", "data_dev.pkl",
            "data_test.pkl", "embedding.pkl"]
        files = []
        for name in names:
            file_path = os.path.join(data_path, name)
            if os.path.exists(file_path):
                with open(file_path, 'rb') as f:
                    data = _pickle.load(f)
            else:
                data = []
            files.append(data)
        return tuple(files)
    def mode(self, test=False):
        """
        Tell the network to be trained or not.
        :param test: bool
        """
        if test:
            self.model.eval()
        else:
            self.model.train()

        self.best_accuracy = 0
    def define_loss(self):
        """
        Assign an instance of loss function to self.loss_func
        E.g. self.loss_func = nn.CrossEntropyLoss()
        """
        if self.loss_func is None:
            if hasattr(self.model, "loss"):
                self.loss_func = self.model.loss
            else:
                self.loss_func = nn.CrossEntropyLoss()
        self.loss_func = nn.CrossEntropyLoss()
    def define_optimizer(self):
        """
@@ -567,10 +346,6 @@ class ClassTrainer(BaseTrainer):
        logits = network(x)
        return logits

    def get_loss(self, predict, truth):
        """Calculate loss."""
        return self.loss_func(predict, truth)

    def grad_backward(self, loss):
        """Compute gradient backward."""
        self.model.zero_grad()
@@ -580,30 +355,21 @@ class ClassTrainer(BaseTrainer):
        """Apply gradient."""
        self.optimizer.step()
    def make_batch(self, data):
        """Batch and pad data."""
        for indices in self.iterator:
            batch = [data[idx] for idx in indices]
            batch_x = [sample[0] for sample in batch]
            batch_y = [sample[1] for sample in batch]
            batch_x = self.pad(batch_x)

            batch_x = torch.tensor(batch_x, dtype=torch.long)
            batch_y = torch.tensor(batch_y, dtype=torch.long)
            if torch.cuda.is_available() and self.use_cuda:
                batch_x = batch_x.cuda()
                batch_y = batch_y.cuda()

            yield batch_x, batch_y

    def make_batch(self, iterator, data):
        return Action.make_batch(iterator, data, output_length=False, use_cuda=self.use_cuda)
    def get_acc(self, y_logit, y_true):
        """Compute accuracy."""
        y_pred = torch.argmax(y_logit, dim=-1)
        return int(torch.sum(y_true == y_pred)) / len(y_true)

    def best_eval_result(self, validator):
        _, _, accuracy = validator.metrics()
        if accuracy > self.best_accuracy:
            self.best_accuracy = accuracy
            return True
        else:
            return False

if __name__ == "__name__":
    train_args = {"epochs": 1, "validate": False, "batch_size": 3, "pickle_path": "./"}
    trainer = BaseTrainer(train_args)
    data_train = [[[1, 2, 3, 4], [0]] * 10] + [[[1, 3, 5, 2], [1]] * 10]
    trainer.make_batch(data=data_train)

    def _create_validator(self, valid_args):
        return ClassificationTester(valid_args)
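# Typical trainer construction (the keys follow the args read above; the values
# and paths are illustrative, not from the diff):
#
#     train_args = {"epochs": 20, "batch_size": 32, "pickle_path": "./save/",
#                   "validate": True, "save_best_dev": True,
#                   "model_saved_path": "./save/", "use_cuda": False,
#                   "vocab_size": 10000, "num_classes": 27}
#     trainer = SeqLabelTrainer(train_args)
#     trainer.train(SeqLabeling(train_args))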
@@ -1,13 +1,14 @@
# python: 3.6
# encoding: utf-8

import torch
import torch.nn as nn

# import torch.nn.functional as F
from fastNLP.models.base_model import BaseModel
from fastNLP.modules.encoder.conv_maxpool import ConvMaxpool

class CNNText(BaseModel):
class CNNText(torch.nn.Module):
    """
    Text classification model by character CNN, the implementation of paper
    'Yoon Kim. 2014. Convolutional Neural Networks for Sentence
@@ -1,7 +1,7 @@
import torch

from fastNLP.models.base_model import BaseModel
from fastNLP.modules import decoder, encoder, utils
from fastNLP.modules import decoder, encoder

class SeqLabeling(BaseModel):
@@ -34,46 +34,25 @@ class SeqLabeling(BaseModel):
        # [batch_size, max_len, num_classes]
        return x
    def loss(self, x, y, seq_length):
    def loss(self, x, y, mask):
        """
        Negative log likelihood loss.
        :param x: FloatTensor, [batch_size, max_len, tag_size]
        :param y: LongTensor, [batch_size, max_len]
        :param seq_length: list of int. [batch_size]
        :param x: Tensor, [batch_size, max_len, tag_size]
        :param y: Tensor, [batch_size, max_len]
        :param mask: ByteTensor, [batch_size, max_len]
        :return loss: a scalar Tensor
        """
        x = x.float()
        y = y.long()
        batch_size = x.size(0)
        max_len = x.size(1)
        mask = utils.seq_mask(seq_length, max_len)
        mask = mask.byte().view(batch_size, max_len)
        # TODO: remove
        if torch.cuda.is_available():
            mask = mask.cuda()
        # mask = x.new(batch_size, max_len)
        total_loss = self.Crf(x, y, mask)
        return torch.mean(total_loss)
    def prediction(self, x, seq_length):
    def prediction(self, x, mask):
        """
        :param x: FloatTensor, [batch_size, max_len, tag_size]
        :param seq_length: int
        :return prediction: list of tuple of (decode path(list), best score)
        :param mask: ByteTensor, [batch_size, max_len]
        :return prediction: list of [decode path(list)]
        """
        x = x.float()
        max_len = x.size(1)
        mask = utils.seq_mask(seq_length, max_len)
        # hack: make sure mask has the same device as x
        mask = mask.to(x).byte()
        tag_seq = self.Crf.viterbi_decode(x, mask)
        return tag_seq
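    # Putting loss and prediction together (the shapes are illustrative): with x
    # of shape [2, 4, tag_size], y of shape [2, 4], and a [2, 4] byte mask,
    #
    #     nll = model.loss(x, y, mask)        # scalar CRF negative log likelihood
    #     paths = model.prediction(x, mask)   # best tag sequence for each sample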
@@ -132,6 +132,7 @@ class ConditionalRandomField(nn.Module):
        Given a feats matrix, return best decode path and best score.
        :param feats: FloatTensor, [batch_size, max_len, tag_size]
        :param masks: ByteTensor, [batch_size, max_len]
        :param get_score: bool, whether to output the decode score.
        :return: List[Tuple(List, float)],
        """
        batch_size, max_len, tag_size = feats.size()
@@ -3,12 +3,12 @@ import sys
sys.path.append("..")

from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.core.trainer import POSTrainer
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader
from fastNLP.loader.preprocess import POSPreprocess, load_pickle
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.tester import POSTester
from fastNLP.core.tester import SeqLabelTester
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.core.inference import Inference

@@ -64,7 +64,7 @@ def train():
    train_args["num_classes"] = p.num_classes

    # Trainer
    trainer = POSTrainer(train_args)
    trainer = SeqLabelTrainer(train_args)

    # Model
    model = SeqLabeling(train_args)

@@ -96,7 +96,7 @@ def test():
    ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})

    # Tester
    tester = POSTester(test_args)
    tester = SeqLabelTester(test_args)

    # Start testing
    tester.test(model)
@@ -64,4 +64,90 @@
3 B-t
1 M-t
日 E-t
, S-w
, S-w
迈 B-v
向 E-v
充 B-v
满 E-v
希 B-n
望 E-n
的 S-u
新 S-a
世 B-n
纪 E-n
— B-w
— E-w
一 B-t
九 M-t
九 M-t
八 M-t
年 E-t
新 B-t
年 E-t
讲 B-n
话 E-n
( S-w
附 S-v
图 B-n
片 E-n
1 S-m
张 S-q
) S-w
迈 B-v
向 E-v
充 B-v
满 E-v
希 B-n
望 E-n
的 S-u
新 S-a
世 B-n
纪 E-n
— B-w
— E-w
一 B-t
九 M-t
九 M-t
八 M-t
年 E-t
新 B-t
年 E-t
讲 B-n
话 E-n
( S-w
附 S-v
图 B-n
片 E-n
1 S-m
张 S-q
) S-w
迈 B-v
向 E-v
充 B-v
满 E-v
希 B-n
望 E-n
的 S-u
新 S-a
世 B-n
纪 E-n
— B-w
— E-w
一 B-t
九 M-t
九 M-t
八 M-t
年 E-t
新 B-t
年 E-t
讲 B-n
话 E-n
( S-w
附 S-v
图 B-n
片 E-n
1 S-m
张 S-q
) S-w
@@ -0,0 +1,100 @@
entertainment 台 媒 预 测 周 冬 雨 金 马 奖 封 后 , 大 气 的 倪 妮 却 佳 作 难 出
food 农 村 就 是 好 , 能 吃 到 纯 天 然 无 添 加 的 野 生 蜂 蜜 , 营 养 又 健 康
fashion 1 4 款 知 性 美 装 , 时 尚 惊 艳 搁 浅 的 阳 光 轻 熟 的 优 雅
history 火 焰 喷 射 器 1 0 0 0 度 火 焰 烧 死 鬼 子 4 连 拍
society 1 8 岁 青 年 砍 死 8 8 岁 老 兵
fashion 醋 洗 脸 的 正 确 方 法 洗 对 了 不 仅 美 容 肌 肤 还 能 收 缩 毛 孔
game 大 家 都 说 说 除 了 这 1 0 个 英 雄 , L O L 还 有 哪 些 英 雄 可 以 单 挑 男 爵
sports 王 仕 鹏 退 役 担 任 N B A 总 决 赛 现 场 解 说 嘉 宾
regimen 天 天 吃 “ 洋 快 餐 ” , 5 岁 女 童 患 上 肝 炎
food 汤 里 的 蛋 花 怎 样 才 能 如 花 朵 般 漂 亮 , 注 意 这 一 点 即 可 !
tech 英 退 休 人 士 把 谷 歌 当 活 人 以 礼 貌 搜 索 请 求 征 服 整 个 互 联 网
discovery N A S A 探 测 器 拍 摄 地 球 、 火 星 和 冥 王 星 合 影
society 当 骗 子 遇 上 撒 贝 宁 ! 几 句 话 过 后 骗 子 赔 礼 道 歉 . . . . .
history 红 军 长 征 在 中 国 革 命 史 上 的 地 位
world 实 拍 神 秘 之 国 , 带 你 走 进 真 实 的 朝 鲜
tech 逼 格 爆 表 ! 古 文 版 2 0 1 6 网 络 热 词 : 燃 尽 洪 荒 之 力
story 因 为 一 样 东 西 这 个 后 娘 竟 然 给 孩 子 磕 头
game L O L : 皮 肤 对 操 作 没 影 响 ? 细 数 那 些 有 加 成 效 果 的 皮 肤
fashion 冬 天 想 穿 裙 子 又 怕 冷 ? 学 了 这 些 搭 配 就 能 好 看 又 温 暖 !
entertainment 贾 建 军 少 林 三 光 剑 视 频
food 再 也 不 用 出 去 吃 羊 肉 串 , 自 己 做 又 卫 生 又 健 康
regimen 男 人 多 吃 这 几 道 菜 , 效 果 胜 “ 伟 哥 ”
baby 宝 贝 厨 房 丨 肉 类 辅 食 第 一 步 宝 宝 的 生 长 发 育 每 天 都 离 不 开 它 !
travel 近 8 0 亿 的 顶 级 豪 华 邮 轮 上 到 底 有 什 么 ?
sports 厄 齐 尔 心 中 最 想 签 约 的 三 个 人
food 东 北 的 粘 豆 包 啊 , 想 死 你 们 了 !
military 强 军 足 音
sports 奥 运 赛 场 上 , 被 喷 子 痛 批 的 十 大 知 名 运 动 员
game 老 玩 家 分 享 对 2 0 1 6 L P L 夏 季 赛 R N G 的 分 析
military 揭 秘 : 关 于 战 争 的 五 大 真 相 , 不 要 再 被 影 视 所 欺 骗 了 !
food 小 丫 厨 房 : 夏 天 怎 么 吃 辣 不 长 痘 ? 告 诉 你 火 锅 鸡 、 香 辣 鱼 的 正 确 做 法
travel 中 国 首 个 内 陆 城 市 群 上 的 9 座 城 市 , 看 看 有 你 的 家 乡 吗
fashion 李 小 璐 做 榜 样 接 亲 吻 脚 大 流 行 新 娘 玉 足 怎 样 才 有 好 味 道 ?
game 黄 金 吊 打 钻 石 ? L O L 最 强 刷 钱 毒 瘤 打 法 诞 生
history 奇 事 ! 上 万 只 青 蛙 拦 路 告 状 , 竟 然 牵 扯 出 一 桩 命 案
baby 奶 奶 , 你 为 什 么 不 让 我 用 尿 不 湿
game L O L 当 5 个 大 发 明 家 炮 台 围 住 泉 水 的 时 候 : 这 是 真 虐 泉 !
essay 文 友 忠 告 暖 人 心 : 人 到 中 年 “ 不 交 五 友 ”
travel 这 一 年 , 我 们 去 日 本
food 好 吃 早 饭 近 似 吃 补 药
fashion 夏 天 太 热 , 唇 膏 化 了 如 何 办 ?
society 厂 里 面 的 9 0 后 打 工 妹 , 辛 苦 来 之 不 易
history 罕 见 老 照 片 展 示 美 国 大 萧 条 时 期 景 象
world 美 国 总 统 奥 巴 马 , 是 童 心 未 泯 的 温 情 奥 大 大 , 还 是 个 超 级 老 顽 童
finance 脱 欧 公 投 前 一 天 抛 售 英 镑 这 一 次 索 罗 斯 也 被 “ 打 败 ” 了 . . .
history 翻 越 长 征 路 上 第 一 座 大 山
world 朝 鲜 批 奥 巴 马 涉 朝 言 论 , 称 只 要 核 威 胁 存 在 将 继 续 强 化 核 武 力 量
game 《 巫 师 3 : 狂 猎 》 不 良 因 素 解 析 攻 略
travel 在 郑 州 有 个 地 方 , 时 光 仿 佛 在 那 儿 停 下 脚 步
history 它 号 称 “ 天 下 第 一 团 ” , 走 出 过 1 4 位 共 和 国 将 军 以 及 一 位 著 名 作 家
car 煤 老 板 去 黄 江 买 车 , 以 为 占 了 便 宜 没 想 被 坑 了 1 0 0 多 万
society “ 试 管 婴 儿 之 母 ” 张 丽 珠 遗 体 告 别 仪 式 8 日 举 行
sports 东 京 奥 运 会 , 中 国 女 排 卫 冕 的 几 率 有 多 大 ?
travel 成 都 我 们 永 远 依 恋 的 城 市
tech 雷 布 斯 除 了 小 米 还 有 这 些 秘 密 , 你 知 道 吗 ?
world “ 仲 裁 庭 损 害 国 际 法 体 系 公 正 性 ” — — 访 武 汉 大 学 中 国 边 界 与 海 洋 研 究 院 首 席 专 家 易 显 河
entertainment 上 海 观 众 和 欧 洲 三 大 影 展 之 间 的 距 离 : 零 时 差
essay 关 系 好 , 一 切 便 好
baby 刚 出 生 不 到 1 小 时 的 白 鲸 宝 宝 被 冲 上 岸 , 被 救 后 对 恩 人 露 出 微 笑
tech 赚 足 眼 球 , 诺 基 亚 五 边 形 W i n 1 0 M o b i l e 概 念 手 机 : 棱 镜
essay 2 4 句 经 典 语 录 : 穷 三 年 可 以 怨 命 , 穷 十 年 就 得 自 省
food 这 道 菜 真 下 饭 ! 做 法 简 单 , 防 辐 射 、 抗 衰 老 , 关 键 还 便 宜
entertainment 《 继 承 者 们 》 要 拍 中 国 版 , 众 角 色 你 期 待 谁 来 演 ?
game D N F 暴 走 改 版 后 怎 么 样 D N F 暴 走 改 版 红 眼 变 弱 了 吗
entertainment 郑 佩 佩 自 曝 与 李 小 龙 的 过 去 他 是 个 “ 疯 子 ”
baby 女 性 只 有 8 4 次 最 佳 受 孕 机 会
travel 月 初 一 个 人 去 了 日 本 . .
military 不 为 人 知 的 8 0 万 苏 联 女 兵 ! 最 后 一 张 很 美 !
tech 网 络 商 家 提 供 小 米 5 运 存 升 级 服 务 : 3 G B 秒 变 6 G B
history 宋 太 祖 、 宋 太 宗 凌 辱 亡 国 皇 后 , 徽 钦 二 帝 后 宫 被 金 人 凌 辱
history 人 有 三 面 最 “ 难 吃 ” ! 黑 帮 大 佬 杜 月 笙 论 江 湖 规 矩 ! 一 生 只 怕 这 一 人
game 来 了 ! 索 尼 P S 4 独 占 大 作 《 战 神 4 》 正 式 公 布
discovery 延 时 视 频 显 示 珊 瑚 如 何 “ 驱 逐 ” 共 生 藻 类
car 传 祺 G A 8 和 东 风 A 9 谁 才 是 自 主 “ 豪 车 ” 大 佬
fashion 娶 老 婆 就 要 娶 这 种 ! 蒋 欣 这 样 微 胖 的 女 人 好 看 又 实 用
sports 黄 山 姑 娘 吕 秀 芝 勇 夺 奥 运 铜 牌 数 百 父 老 彻 夜 为 她 加 油
military [ 每 日 军 图 ] 土 豪 补 仓 ! 沙 特 再 次 购 买 上 百 辆 美 国 M 1 A 2 主 战 坦 克
military 美 军 这 款 武 器 号 称 能 让 半 个 中 国 陷 入 黑 暗 , 解 放 军 少 将 : 我 们 也 有
world 邓 小 平 与 日 本 天 皇 的 历 史 性 会 谈 , 对 中 日 两 国 都 具 有 深 远 的 意 义 啊 !
baby 为 什 么 有 人 上 个 厕 所 都 能 生 出 孩 子 ?
fashion 欣 宜 举 行 首 次 个 唱 十 万 颗 宝 仕 奥 莎 仿 水 晶 闪 耀 全 场
food 小 两 口 上 周 的 晚 餐
society 在 北 京 就 要 守 规 矩
entertainment 知 情 人 曝 翰 爽 分 手 内 幕 : 郑 爽 想 结 婚 却 被 一 直 拖 着
military 中 国 反 舰 导 弹 世 界 第 一 远 远 超 过 美 国 但 为 何 却 还 不 如 俄 罗 斯 ?
entertainment 他 除 了 是 《 我 歌 》 音 乐 总 监 , 还 曾 组 乐 队 玩 摇 滚 , 是 黄 家 驹 旧 日 知 己
baby 长 鹅 口 疮 的 孩 子 怎 么 照 顾 ? 不 要 再 说 拿 他 没 办 法 了 !
discovery 微 重 力 不 需 使 用 肌 肉 , 太 空 人 返 回 地 球 后 脊 椎 旁 肌 肉 萎 缩 约 1 9 %
regimen 这 6 种 人 将 来 会 得 老 年 痴 呆 ! 预 防 老 年 痴 呆 症 , 这 些 办 法 被 全 世 界 公 认
society 2 0 1 6 年 上 海 即 将 发 生 哪 些 大 事 件 。 。 。 。
car 北 汽 自 主 品 牌 亏 损 3 3 . 4 1 亿 额 外 促 销 成 主 因
car 在 那 山 的 那 边 海 的 那 边 , 有 一 群 自 由 侠
history 一 个 小 城 就 屠 杀 了 4 0 0 0 苏 军 战 俘 , 希 特 勒 死 神 战 队 的 崛 起 与 覆 灭
baby 给 孩 子 洗 澡 时 , 这 些 部 位 再 脏 也 不 要 碰 !
essay 好 久 不 见 , 你 还 好 么
baby 被 娃 误 伤 的 9 种 痛 , 数 一 数 你 中 了 几 枪 ?
food 初 秋 的 小 炖 品 放 冰 糖 就 比 较 滋 润 , 放 红 糖 就 补 血 又 不 燥 热
game 佩 服 佩 服 ! 羊 驼 D e f t 单 排 重 回 韩 服 最 强 王 者 第 一 名 !
game 三 个 时 代 的 标 志 炉 石 传 说 三 大 远 古 毒 瘤 卡 组
discovery 2 0 世 纪 最 伟 大 科 学 发 现 — — 魔 术 般 的 超 导 材 料 !
@@ -3,14 +3,14 @@ import sys
sys.path.append("..")

from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.core.trainer import POSTrainer
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.dataset_loader import POSDatasetLoader, BaseLoader
from fastNLP.loader.preprocess import POSPreprocess, load_pickle
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.tester import POSTester
from fastNLP.core.tester import SeqLabelTester
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.core.inference import Inference
from fastNLP.core.inference import SeqLabelInfer

data_name = "people.txt"
data_path = "data_for_tests/people.txt"

@@ -50,14 +50,15 @@ def infer():
    """
    # Inference interface
    infer = Inference(pickle_path)
    infer = SeqLabelInfer(pickle_path)
    results = infer.predict(model, infer_data)

    print(results)
    for res in results:
        print(res)
    print("Inference finished!")

def train_test():
def train_and_test():
    # Config Loader
    train_args = ConfigSection()
    ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args})

@@ -67,12 +68,12 @@ def train_test():
    train_data = pos_loader.load_lines()

    # Preprocessor
    p = POSPreprocess(train_data, pickle_path)
    p = POSPreprocess(train_data, pickle_path, train_dev_split=0.5)
    train_args["vocab_size"] = p.vocab_size
    train_args["num_classes"] = p.num_classes

    # Trainer
    trainer = POSTrainer(train_args)
    trainer = SeqLabelTrainer(train_args)

    # Model
    model = SeqLabeling(train_args)

@@ -100,7 +101,7 @@ def train_test():
    ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})

    # Tester
    tester = POSTester(test_args)
    tester = SeqLabelTester(test_args)

    # Start testing
    tester.test(model)

@@ -111,5 +112,5 @@ def train_test():
if __name__ == "__main__":
    train_test()
    # infer()
    # train_and_test()
    infer()
@@ -3,12 +3,12 @@ import sys
sys.path.append("..")

from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.core.trainer import POSTrainer
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader
from fastNLP.loader.preprocess import POSPreprocess, load_pickle
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.tester import POSTester
from fastNLP.core.tester import SeqLabelTester
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.core.inference import Inference

@@ -73,7 +73,7 @@ def train_test():
    train_args["num_classes"] = p.num_classes

    # Trainer
    trainer = POSTrainer(train_args)
    trainer = SeqLabelTrainer(train_args)

    # Model
    model = SeqLabeling(train_args)

@@ -101,7 +101,7 @@ def train_test():
    ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})

    # Tester
    tester = POSTester(test_args)
    tester = SeqLabelTester(test_args)

    # Start testing
    tester.test(model)

@@ -113,4 +113,4 @@ def train_test():
if __name__ == "__main__":
    train_test()
    #infer()
    infer()
@@ -1,4 +1,4 @@
from fastNLP.core.tester import POSTester
from fastNLP.core.tester import SeqLabelTester
from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
from fastNLP.loader.dataset_loader import TokenizeDatasetLoader
from fastNLP.loader.preprocess import POSPreprocess

@@ -26,7 +26,7 @@ def foo():
    valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
                  "save_loss": True, "batch_size": 8, "pickle_path": "./data_for_tests/",
                  "use_cuda": True}
    validator = POSTester(valid_args)
    validator = SeqLabelTester(valid_args)

    print("start validation.")
    validator.test(model)
@@ -0,0 +1,84 @@
# Python: 3.5
# encoding: utf-8

import os

from fastNLP.core.inference import ClassificationInfer
from fastNLP.core.trainer import ClassificationTrainer
from fastNLP.loader.dataset_loader import ClassDatasetLoader
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.loader.preprocess import ClassPreprocess
from fastNLP.models.cnn_text_classification import CNNText
from fastNLP.saver.model_saver import ModelSaver

data_dir = "./data_for_tests/"
train_file = 'text_classify.txt'
model_name = "model_class.pkl"

def infer():
    # load dataset
    print("Loading data...")
    ds_loader = ClassDatasetLoader("train", os.path.join(data_dir, train_file))
    data = ds_loader.load()
    unlabeled_data = [x[0] for x in data]

    # pre-process data
    pre = ClassPreprocess(data_dir)
    vocab_size, n_classes = pre.process(data, "data_train.pkl")
    print("vocabulary size:", vocab_size)
    print("number of classes:", n_classes)

    # construct model
    print("Building model...")
    cnn = CNNText(class_num=n_classes, embed_num=vocab_size)

    # Dump trained parameters into the model
    ModelLoader.load_pytorch(cnn, "./data_for_tests/saved_model.pkl")
    print("model loaded!")

    infer = ClassificationInfer(data_dir)
    results = infer.predict(cnn, unlabeled_data)
    print(results)

def train():
    # load dataset
    print("Loading data...")
    ds_loader = ClassDatasetLoader("train", os.path.join(data_dir, train_file))
    data = ds_loader.load()
    print(data[0])

    # pre-process data
    pre = ClassPreprocess(data_dir)
    vocab_size, n_classes = pre.process(data, "data_train.pkl")
    print("vocabulary size:", vocab_size)
    print("number of classes:", n_classes)

    # construct model
    print("Building model...")
    cnn = CNNText(class_num=n_classes, embed_num=vocab_size)

    # train
    print("Training...")
    train_args = {
        "epochs": 1,
        "batch_size": 10,
        "pickle_path": data_dir,
        "validate": False,
        "save_best_dev": False,
        "model_saved_path": "./data_for_tests/",
        "use_cuda": True
    }
    trainer = ClassificationTrainer(train_args)
    trainer.train(cnn)

    print("Training finished!")

    saver = ModelSaver("./data_for_tests/saved_model.pkl")
    saver.save_pytorch(cnn)
    print("Model saved!")

if __name__ == "__main__":
    # train()
    infer()