@@ -18,7 +18,7 @@ pre-processing data, constructing model and training model.
 from fastNLP.modules import aggregation
 from fastNLP.modules import decoder
-from fastNLP.loader.dataset_loader import ClassDatasetLoader
+from fastNLP.loader.dataset_loader import ClassDataSetLoader
 from fastNLP.loader.preprocess import ClassPreprocess
 from fastNLP.core.trainer import ClassificationTrainer
 from fastNLP.core.inference import ClassificationInfer
@@ -50,7 +50,7 @@ pre-processing data, constructing model and training model.
 train_path = 'test/data_for_tests/text_classify.txt'  # training set file

 # load dataset
-ds_loader = ClassDatasetLoader("train", train_path)
+ds_loader = ClassDataSetLoader("train", train_path)
 data = ds_loader.load()

 # pre-process dataset
@@ -3,7 +3,7 @@ from fastNLP.core.optimizer import Optimizer
 from fastNLP.core.predictor import ClassificationInfer
 from fastNLP.core.preprocess import ClassPreprocess
 from fastNLP.core.trainer import ClassificationTrainer
-from fastNLP.loader.dataset_loader import ClassDatasetLoader
+from fastNLP.loader.dataset_loader import ClassDataSetLoader
 from fastNLP.models.base_model import BaseModel
 from fastNLP.modules import aggregator
 from fastNLP.modules import decoder
@@ -36,7 +36,7 @@ data_dir = 'save/'  # directory to save data and model
 train_path = './data_for_tests/text_classify.txt'  # training set file

 # load dataset
-ds_loader = ClassDatasetLoader(train_path)
+ds_loader = ClassDataSetLoader()
 data = ds_loader.load()

 # pre-process dataset
@@ -17,7 +17,7 @@ class Batch(object):
         :param dataset: a DataSet object
         :param batch_size: int, the size of the batch
         :param sampler: a Sampler object
-        :param use_cuda: bool, whetjher to use GPU
+        :param use_cuda: bool, whether to use GPU
         """
         self.dataset = dataset
@@ -37,15 +37,12 @@ class Batch(object):
         """
         :return batch_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length])
-                batch_x also contains an item (str: list of int) about origin lengths,
-                which means ("field_name_origin_len": origin lengths).
                 E.g.
                 ::
                     {'text': tensor([[ 0, 1, 2, 3, 0, 0, 0], [ 4, 5, 2, 6, 7, 8, 9]]), 'text_origin_len': [4, 7]}
                 batch_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length])
         All tensors in both batch_x and batch_y will be cuda tensors if use_cuda is True.
-        The names of fields are defined in preprocessor's convert_to_dataset method.
         """
         if self.curidx >= len(self.idx_list):
@@ -54,10 +51,9 @@ class Batch(object):
         endidx = min(self.curidx + self.batch_size, len(self.idx_list))
         padding_length = {field_name: max(field_length[self.curidx: endidx])
                           for field_name, field_length in self.lengths.items()}
-        origin_lengths = {field_name: field_length[self.curidx: endidx]
-                          for field_name, field_length in self.lengths.items()}
         batch_x, batch_y = defaultdict(list), defaultdict(list)

+        # transform index to tensor and do padding for sequences
         for idx in range(self.curidx, endidx):
             x, y = self.dataset.to_tensor(idx, padding_length)
             for name, tensor in x.items():
@@ -65,8 +61,7 @@ class Batch(object):
             for name, tensor in y.items():
                 batch_y[name].append(tensor)

-        batch_origin_length = {}
-        # combine instances into a batch
+        # combine instances to form a batch
         for batch in (batch_x, batch_y):
             for name, tensor_list in batch.items():
                 if self.use_cuda:
@@ -74,14 +69,6 @@ class Batch(object):
                 else:
                     batch[name] = torch.stack(tensor_list, dim=0)

-        # add origin lengths in batch_x
-        for name, tensor in batch_x.items():
-            if self.use_cuda:
-                batch_origin_length[name + "_origin_len"] = torch.LongTensor(origin_lengths[name]).cuda()
-            else:
-                batch_origin_length[name + "_origin_len"] = torch.LongTensor(origin_lengths[name])
-        batch_x.update(batch_origin_length)
-
         self.curidx += endidx
         return batch_x, batch_y
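With the origin-length bookkeeping stripped out above, `Batch` is now a plain padded-minibatch iterator. For orientation, a minimal sketch of how it is driven (not part of the commit; the data path is hypothetical, and the `fastNLP.core.dataset` module path for the dataset classes introduced later in this diff is an assumption):

    from fastNLP.core.batch import Batch
    from fastNLP.core.dataset import TextClassifyDataSet
    from fastNLP.core.sampler import SequentialSampler

    data_set = TextClassifyDataSet()                        # uses ClassDataSetLoader() by default
    data_set.load("test/data_for_tests/text_classify.txt")  # hypothetical path

    batches = Batch(data_set, batch_size=8, sampler=SequentialSampler(), use_cuda=False)
    for batch_x, batch_y in batches:
        # batch_x["word_seq"]: LongTensor of shape [batch_size, padding_length]
        # batch_y["label"]: the indexed labels of the batch
        print(batch_x["word_seq"].shape)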
@@ -1,7 +1,11 @@
+import random
 from collections import defaultdict
+from copy import deepcopy

-from fastNLP.core.field import TextField
+from fastNLP.core.field import TextField, LabelField
 from fastNLP.core.instance import Instance
+from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.loader.dataset_loader import POSDataSetLoader, ClassDataSetLoader


 def create_dataset_from_lists(str_lists: list, word_vocab: dict, has_target: bool = False, label_vocab: dict = None):
@@ -65,7 +69,8 @@ class DataSet(list):
     """A DataSet object is a list of Instance objects.

     """
-    def __init__(self, name="", instances=None):
+    def __init__(self, name="", instances=None, loader=None):
         """
         :param name: str, the name of the dataset. (default: "")
@@ -76,6 +81,7 @@ class DataSet(list):
         self.name = name
         if instances is not None:
             self.extend(instances)
+        self.dataset_loader = loader

     def index_all(self, vocab):
         for ins in self:
@@ -109,3 +115,171 @@ class DataSet(list):
             for field_name, field_length in ins.get_length().items():
                 lengths[field_name].append(field_length)
         return lengths
+
+    def convert(self, data):
+        """Convert lists of strings into Instances with Fields."""
+        raise NotImplementedError
+
+    def convert_with_vocabs(self, data, vocabs):
+        """Convert lists of strings into Instances with Fields, using existing Vocabulary. Useful in predicting."""
+        raise NotImplementedError
+
+    def convert_for_infer(self, data, vocabs):
+        """Convert lists of strings into Instances with Fields."""
+        raise NotImplementedError
+
+    def load(self, data_path, vocabs=None, infer=False):
+        """Load data from the given file.
+
+        :param data_path: str, the path to the data
+        :param infer: bool. If True, there is no label information in the data. Default: False.
+        :param vocabs: dict of (name: Vocabulary object), used to index data. If not provided, a new vocabulary will be constructed.
+        """
+        raw_data = self.dataset_loader.load(data_path)
+        if infer is True:
+            self.convert_for_infer(raw_data, vocabs)
+        else:
+            if vocabs is not None:
+                self.convert_with_vocabs(raw_data, vocabs)
+            else:
+                self.convert(raw_data)
+
+    def split(self, ratio, shuffle=True):
+        """Train/dev splitting.
+
+        :param ratio: float, between 0 and 1. The ratio of the development set in the original data set.
+        :param shuffle: bool, whether to shuffle the data set before splitting. Default: True.
+        :return train_set: a DataSet object, representing the training set
+                dev_set: a DataSet object, representing the validation set
+        """
+        assert 0 < ratio < 1
+        if shuffle:
+            random.shuffle(self)
+        split_idx = int(len(self) * ratio)
+        dev_set = deepcopy(self)
+        train_set = deepcopy(self)
+        del train_set[:split_idx]
+        del dev_set[split_idx:]
+        return train_set, dev_set
+
+
+class SeqLabelDataSet(DataSet):
+    def __init__(self, instances=None, loader=POSDataSetLoader()):
+        super(SeqLabelDataSet, self).__init__(name="", instances=instances, loader=loader)
+        self.word_vocab = Vocabulary()
+        self.label_vocab = Vocabulary()
+
+    def convert(self, data):
+        """Convert lists of strings into Instances with Fields.
+
+        :param data: 3-level lists. Entries are strings.
+        """
+        for example in data:
+            word_seq, label_seq = example[0], example[1]
+            # list, list
+            self.word_vocab.update(word_seq)
+            self.label_vocab.update(label_seq)
+            x = TextField(word_seq, is_target=False)
+            x_len = LabelField(len(word_seq), is_target=False)
+            y = TextField(label_seq, is_target=False)
+            instance = Instance()
+            instance.add_field("word_seq", x)
+            instance.add_field("truth", y)
+            instance.add_field("word_seq_origin_len", x_len)
+            self.append(instance)
+        self.index_field("word_seq", self.word_vocab)
+        self.index_field("truth", self.label_vocab)
+        # no need to index "word_seq_origin_len"
+
+    def convert_with_vocabs(self, data, vocabs):
+        for example in data:
+            word_seq, label_seq = example[0], example[1]
+            # list, list
+            x = TextField(word_seq, is_target=False)
+            x_len = LabelField(len(word_seq), is_target=False)
+            y = TextField(label_seq, is_target=False)
+            instance = Instance()
+            instance.add_field("word_seq", x)
+            instance.add_field("truth", y)
+            instance.add_field("word_seq_origin_len", x_len)
+            self.append(instance)
+        self.index_field("word_seq", vocabs["word_vocab"])
+        self.index_field("truth", vocabs["label_vocab"])
+        # no need to index "word_seq_origin_len"
+
+    def convert_for_infer(self, data, vocabs):
+        for word_seq in data:
+            # list
+            x = TextField(word_seq, is_target=False)
+            x_len = LabelField(len(word_seq), is_target=False)
+            instance = Instance()
+            instance.add_field("word_seq", x)
+            instance.add_field("word_seq_origin_len", x_len)
+            self.append(instance)
+        self.index_field("word_seq", vocabs["word_vocab"])
+        # no need to index "word_seq_origin_len"
+
+
+class TextClassifyDataSet(DataSet):
+    def __init__(self, instances=None, loader=ClassDataSetLoader()):
+        super(TextClassifyDataSet, self).__init__(name="", instances=instances, loader=loader)
+        self.word_vocab = Vocabulary()
+        self.label_vocab = Vocabulary(need_default=False)
+
+    def convert(self, data):
+        for example in data:
+            word_seq, label = example[0], example[1]
+            # list, str
+            self.word_vocab.update(word_seq)
+            self.label_vocab.update(label)
+            x = TextField(word_seq, is_target=False)
+            y = LabelField(label, is_target=True)
+            instance = Instance()
+            instance.add_field("word_seq", x)
+            instance.add_field("label", y)
+            self.append(instance)
+        self.index_field("word_seq", self.word_vocab)
+        self.index_field("label", self.label_vocab)
+
+    def convert_with_vocabs(self, data, vocabs):
+        for example in data:
+            word_seq, label = example[0], example[1]
+            # list, str
+            x = TextField(word_seq, is_target=False)
+            y = LabelField(label, is_target=True)
+            instance = Instance()
+            instance.add_field("word_seq", x)
+            instance.add_field("label", y)
+            self.append(instance)
+        self.index_field("word_seq", vocabs["word_vocab"])
+        self.index_field("label", vocabs["label_vocab"])
+
+    def convert_for_infer(self, data, vocabs):
+        for word_seq in data:
+            # list
+            x = TextField(word_seq, is_target=False)
+            instance = Instance()
+            instance.add_field("word_seq", x)
+            self.append(instance)
+        self.index_field("word_seq", vocabs["word_vocab"])
+
+
+def change_field_is_target(data_set, field_name, new_target):
+    """Change the flag of is_target in a field.
+
+    :param data_set: a DataSet object
+    :param field_name: str, the name of the field
+    :param new_target: one of (True, False, None). The field will go into batch_x, batch_y, or neither, respectively.
+    """
+    for inst in data_set:
+        inst.fields[field_name].is_target = new_target
+
+
+if __name__ == "__main__":
+    data_set = SeqLabelDataSet()
+    data_set.load("../../test/data_for_tests/people.txt")
+    a, b = data_set.split(0.3)
+    print(type(data_set), type(a), type(b))
+    print(len(data_set), len(a), len(b))
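Taken together, the new `load`/`split`/`convert_*` methods let a dataset object own its loader and vocabularies. A sketch of the intended round trip (not from the commit; paths are hypothetical and the `fastNLP.core.dataset` module path is an assumption):

    from fastNLP.core.dataset import TextClassifyDataSet

    train_full = TextClassifyDataSet()           # ClassDataSetLoader() by default
    train_full.load("data/train.txt")            # builds word_vocab / label_vocab via convert()
    train_set, dev_set = train_full.split(0.1)   # first 10% (after shuffling) becomes dev_set

    # At inference time, reuse the training vocabularies instead of rebuilding them.
    vocabs = {"word_vocab": train_full.word_vocab, "label_vocab": train_full.label_vocab}
    infer_set = TextClassifyDataSet()
    infer_set.load("data/unlabeled.txt", vocabs=vocabs, infer=True)  # convert_for_infer()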
@@ -59,6 +59,9 @@ class TextField(Field):

 class LabelField(Field):
+    """The Field representing a single label. Can be a string or integer.
+
+    """
     def __init__(self, label, is_target=True):
         super(LabelField, self).__init__(is_target)
         self.label = label
@@ -73,13 +76,14 @@ class LabelField(Field):

     def index(self, vocab):
         if self._index is None:
-            self._index = vocab[self.label]
+            if isinstance(self.label, str):
+                self._index = vocab[self.label]
         return self._index

     def to_tensor(self, padding_length):
         if self._index is None:
             if isinstance(self.label, int):
-                return torch.LongTensor([self.label])
+                return torch.tensor(self.label)
             elif isinstance(self.label, str):
                 raise RuntimeError("Field {} not indexed. Call index method.".format(self.label))
             else:
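These two changes let `LabelField` carry integer labels, such as sequence lengths, without a vocabulary: `index` now only consults the vocabulary for string labels, and un-indexed integers tensorize directly via `torch.tensor`. A toy illustration (the vocabulary contents, and the indexed branch of `to_tensor`, are assumptions from context):

    from fastNLP.core.field import LabelField
    from fastNLP.core.vocabulary import Vocabulary

    vocab = Vocabulary()
    vocab.update("positive")

    str_label = LabelField("positive", is_target=True)
    str_label.index(vocab)           # strings must be indexed before to_tensor
    print(str_label.to_tensor(0))    # tensor built from the stored index

    int_label = LabelField(7, is_target=False)   # e.g. "word_seq_origin_len"
    print(int_label.to_tensor(0))    # tensor(7); no vocabulary involved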
@@ -46,8 +46,11 @@ class Instance(object):
         tensor_x = {}
         tensor_y = {}
         for name, field in self.fields.items():
-            if field.is_target:
+            if field.is_target is True:
                 tensor_y[name] = field.to_tensor(padding_length[name])
-            else:
+            elif field.is_target is False:
                 tensor_x[name] = field.to_tensor(padding_length[name])
+            else:
+                # is_target is None
+                continue
         return tensor_x, tensor_y
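`to_tensor` now routes fields three ways instead of two, so a field flagged `is_target=None` (for example via `change_field_is_target` in the dataset code above) stays out of both batches. A toy check (not from the commit), using integer `LabelField`s since they need no indexing:

    from fastNLP.core.field import LabelField
    from fastNLP.core.instance import Instance

    ins = Instance()
    ins.add_field("length", LabelField(7, is_target=False))   # -> tensor_x
    ins.add_field("label", LabelField(1, is_target=True))     # -> tensor_y
    ins.add_field("scratch", LabelField(0, is_target=None))   # skipped entirely

    tensor_x, tensor_y = ins.to_tensor({"length": 0, "label": 0, "scratch": 0})
    print(sorted(tensor_x), sorted(tensor_y))   # ['length'] ['label']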
@@ -39,8 +39,19 @@ class Loss(object):
         :return loss: a PyTorch loss
         """

+        class InnerCrossEntropy:
+            """A simple wrapper to guarantee input shapes."""
+
+            def __init__(self):
+                self.f = torch.nn.CrossEntropyLoss()
+
+            def __call__(self, predict, truth):
+                truth = truth.view(-1, )
+                return self.f(predict, truth)
+
         if loss_name == "cross_entropy":
-            return torch.nn.CrossEntropyLoss()
+            return InnerCrossEntropy()
         elif loss_name == 'nll':
             return torch.nn.NLLLoss()
         else:
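The wrapper exists because `torch.nn.CrossEntropyLoss` expects targets of shape `[N]`, while labels coming out of `batch_y` can be stacked as `[batch_size, 1]`; flattening makes both shapes work. A standalone illustration of the shape problem it solves:

    import torch

    predict = torch.randn(4, 3)                 # logits: [batch_size, num_classes]
    truth = torch.tensor([[0], [2], [1], [0]])  # stacked labels: [batch_size, 1]

    loss_fn = torch.nn.CrossEntropyLoss()
    # loss_fn(predict, truth) would fail on the 2-D target;
    # the InnerCrossEntropy wrapper applies exactly this flattening:
    print(float(loss_fn(predict, truth.view(-1))))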
@@ -4,6 +4,51 @@ import numpy as np

 import torch

+
+class Evaluator(object):
+    def __init__(self):
+        pass
+
+    def __call__(self, predict, truth):
+        """
+        :param predict: list of tensors, the network outputs from all batches.
+        :param truth: list of dict, the ground truths from all batch_y.
+        :return:
+        """
+        raise NotImplementedError
+
+
+class ClassifyEvaluator(Evaluator):
+    def __init__(self):
+        super(ClassifyEvaluator, self).__init__()
+
+    def __call__(self, predict, truth):
+        y_prob = [torch.nn.functional.softmax(y_logit, dim=-1) for y_logit in predict]
+        y_prob = torch.cat(y_prob, dim=0)
+        y_pred = torch.argmax(y_prob, dim=-1)
+        y_true = torch.cat(truth, dim=0)
+        acc = float(torch.sum(y_pred == y_true)) / len(y_true)
+        return {"accuracy": acc}
+
+
+class SeqLabelEvaluator(Evaluator):
+    def __init__(self):
+        super(SeqLabelEvaluator, self).__init__()
+
+    def __call__(self, predict, truth):
+        """
+        :param predict: list of tensors, the network outputs from all batches.
+        :param truth: list of dict, the ground truths from all batch_y.
+        :return accuracy:
+        """
+        truth = [item["truth"] for item in truth]
+        truth = torch.cat(truth).view(-1, )
+        results = torch.Tensor(predict).view(-1, )
+        accuracy = torch.sum(results.to(truth) == truth).to(torch.float) / results.shape[0]
+        return {"accuracy": float(accuracy)}
+
+
 def _conver_numpy(x):
     """convert input data to numpy array
@@ -16,18 +16,18 @@ class Predictor(object):
     Currently, Predictor does not support GPU.
     """

-    def __init__(self, pickle_path, task):
+    def __init__(self, pickle_path, post_processor):
         """
         :param pickle_path: str, the path to the pickle files.
-        :param task: str, specify which task the predictor will perform. One of ("seq_label", "text_classify").
+        :param post_processor: a function or callable object that takes a list of batch outputs as input
         """
         self.batch_size = 1
         self.batch_output = []
         self.pickle_path = pickle_path
-        self._task = task  # one of ("seq_label", "text_classify")
-        self.label_vocab = load_pickle(self.pickle_path, "class2id.pkl")
+        self._post_processor = post_processor
+        self.label_vocab = load_pickle(self.pickle_path, "label2id.pkl")
         self.word_vocab = load_pickle(self.pickle_path, "word2id.pkl")

     def predict(self, network, data):
@@ -38,21 +38,20 @@ class Predictor(object):
         :return: list of list of strings, [num_examples, tag_seq_length]
         """
         # transform strings into DataSet object
-        data = self.prepare_input(data)
+        # data = self.prepare_input(data)

         # turn on the testing mode; clean up the history
         self.mode(network, test=True)
-        self.batch_output.clear()
+        batch_output = []

         data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), use_cuda=False)

         for batch_x, _ in data_iterator:
             with torch.no_grad():
                 prediction = self.data_forward(network, batch_x)
-            self.batch_output.append(prediction)
+            batch_output.append(prediction)

-        return self.prepare_output(self.batch_output)
+        return self._post_processor(batch_output, self.label_vocab)

     def mode(self, network, test=True):
         if test:
@@ -62,13 +61,7 @@ class Predictor(object):

     def data_forward(self, network, x):
         """Forward through network."""
-        if self._task == "seq_label":
-            y = network(x["word_seq"], x["word_seq_origin_len"])
-            y = network.prediction(y)
-        elif self._task == "text_classify":
-            y = network(x["word_seq"])
-        else:
-            raise NotImplementedError("Unknown task type {}.".format(self._task))
+        y = network(**x)
         return y

     def prepare_input(self, data):
@@ -88,39 +81,32 @@ class Predictor(object):
         assert isinstance(data, list)
         return create_dataset_from_lists(data, self.word_vocab, has_target=False)

-    def prepare_output(self, data):
-        """Transform list of batch outputs into strings."""
-        if self._task == "seq_label":
-            return self._seq_label_prepare_output(data)
-        elif self._task == "text_classify":
-            return self._text_classify_prepare_output(data)
-        else:
-            raise NotImplementedError("Unknown task type {}".format(self._task))
-
-    def _seq_label_prepare_output(self, batch_outputs):
-        results = []
-        for batch in batch_outputs:
-            for example in np.array(batch):
-                results.append([self.label_vocab.to_word(int(x)) for x in example])
-        return results
-
-    def _text_classify_prepare_output(self, batch_outputs):
-        results = []
-        for batch_out in batch_outputs:
-            idx = np.argmax(batch_out.detach().numpy(), axis=-1)
-            results.extend([self.label_vocab.to_word(i) for i in idx])
-        return results
-

 class SeqLabelInfer(Predictor):
     def __init__(self, pickle_path):
         print(
-            "[FastNLP Warning] SeqLabelInfer will be deprecated. Please use Predictor with argument 'task'='seq_label'.")
-        super(SeqLabelInfer, self).__init__(pickle_path, "seq_label")
+            "[FastNLP Warning] SeqLabelInfer will be deprecated. Please use Predictor directly.")
+        super(SeqLabelInfer, self).__init__(pickle_path, seq_label_post_processor)


 class ClassificationInfer(Predictor):
     def __init__(self, pickle_path):
         print(
-            "[FastNLP Warning] ClassificationInfer will be deprecated. Please use Predictor with argument 'task'='text_classify'.")
-        super(ClassificationInfer, self).__init__(pickle_path, "text_classify")
+            "[FastNLP Warning] ClassificationInfer will be deprecated. Please use Predictor directly.")
+        super(ClassificationInfer, self).__init__(pickle_path, text_classify_post_processor)
+
+
+def seq_label_post_processor(batch_outputs, label_vocab):
+    results = []
+    for batch in batch_outputs:
+        for example in np.array(batch):
+            results.append([label_vocab.to_word(int(x)) for x in example])
+    return results
+
+
+def text_classify_post_processor(batch_outputs, label_vocab):
+    results = []
+    for batch_out in batch_outputs:
+        idx = np.argmax(batch_out.detach().numpy(), axis=-1)
+        results.extend([label_vocab.to_word(i) for i in idx])
+    return results
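With the task dispatch gone, any callable with the `(batch_outputs, label_vocab)` signature can be plugged into `Predictor`. A hedged sketch of a custom post-processor (the pickle path is hypothetical, and a trained `network` plus an inference dataset are assumed to exist):

    import numpy as np
    from fastNLP.core.predictor import Predictor

    def top2_post_processor(batch_outputs, label_vocab):
        """Return the two best labels per example instead of the single argmax."""
        results = []
        for batch_out in batch_outputs:
            for row in batch_out.detach().numpy():
                top2 = np.argsort(row)[-2:][::-1]
                results.append([label_vocab.to_word(int(i)) for i in top2])
        return results

    predictor = Predictor("./save/", top2_post_processor)
    # results = predictor.predict(network, infer_data_set)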
@@ -114,7 +114,6 @@ class Preprocessor(object):
         If train_dev_split > 0, return one more dataset - the dev set. If cross_val is True, each dataset
         is a list of DataSet objects; Otherwise, each dataset is a DataSet object.
         """
-
         if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"):
             self.data_vocab = load_pickle(pickle_path, "word2id.pkl")
             self.label_vocab = load_pickle(pickle_path, "class2id.pkl")
@@ -1,7 +1,7 @@
-import numpy as np
 import torch

 from fastNLP.core.batch import Batch
+from fastNLP.core.metrics import Evaluator
 from fastNLP.core.sampler import RandomSampler
 from fastNLP.saver.logger import create_logger
@@ -22,28 +22,23 @@ class Tester(object):
         "kwargs" must have the same type as "default_args" on corresponding keys.
         Otherwise, an error will be raised.
         """
-        default_args = {"save_output": True,  # collect outputs of validation set
-                        "save_loss": True,  # collect losses in validation
-                        "save_best_dev": False,  # save best model during validation
-                        "batch_size": 8,
+        default_args = {"batch_size": 8,
                         "use_cuda": False,
                         "pickle_path": "./save/",
                         "model_name": "dev_best_model.pkl",
-                        "print_every_step": 1,
+                        "evaluator": Evaluator()
                         }
         """
         "required_args" is the collection of arguments that users must pass to Tester explicitly.
         This is used to warn users of essential settings in the training.
         Specifically, "required_args" does not have a default value, so they have nothing to do with "default_args".
         """
-        required_args = {"task"  # one of ("seq_label", "text_classify")
-                         }
+        required_args = {}

         for req_key in required_args:
             if req_key not in kwargs:
                 logger.error("Tester lacks argument {}".format(req_key))
                 raise ValueError("Tester lacks argument {}".format(req_key))
-        self._task = kwargs["task"]

         for key in default_args:
             if key in kwargs:
@@ -59,17 +54,13 @@ class Tester(object):
                 pass
         print(default_args)

-        self.save_output = default_args["save_output"]
-        self.save_best_dev = default_args["save_best_dev"]
-        self.save_loss = default_args["save_loss"]
         self.batch_size = default_args["batch_size"]
         self.pickle_path = default_args["pickle_path"]
         self.use_cuda = default_args["use_cuda"]
-        self.print_every_step = default_args["print_every_step"]
+        self._evaluator = default_args["evaluator"]

         self._model = None
         self.eval_history = []  # evaluation results of all batches
-        self.batch_output = []  # outputs of all batches

     def test(self, network, dev_data):
         if torch.cuda.is_available() and self.use_cuda:
@@ -80,26 +71,18 @@ class Tester(object):
         # turn on the testing mode; clean up the history
         self.mode(network, is_test=True)
         self.eval_history.clear()
-        self.batch_output.clear()
+        output_list = []
+        truth_list = []

         data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda)
-        step = 0

         for batch_x, batch_y in data_iterator:
             with torch.no_grad():
                 prediction = self.data_forward(network, batch_x)
-            eval_results = self.evaluate(prediction, batch_y)
-            if self.save_output:
-                self.batch_output.append(prediction)
-            if self.save_loss:
-                self.eval_history.append(eval_results)
-            print_output = "[test step {}] {}".format(step, eval_results)
-            logger.info(print_output)
-            if self.print_every_step > 0 and step % self.print_every_step == 0:
-                print(self.make_eval_output(prediction, eval_results))
-            step += 1
+            output_list.append(prediction)
+            truth_list.append(batch_y)
+        eval_results = self.evaluate(output_list, truth_list)
+        print("[tester] {}".format(self.print_eval_results(eval_results)))

     def mode(self, model, is_test=False):
         """Train mode or Test mode. This is for PyTorch currently.
@@ -121,104 +104,30 @@ class Tester(object):
     def evaluate(self, predict, truth):
         """Compute evaluation metrics.

-        :param predict: Tensor
-        :param truth: Tensor
+        :param predict: list of Tensor
+        :param truth: list of dict
         :return eval_results: can be anything. It will be stored in self.eval_history
         """
-        if "label_seq" in truth:
-            truth = truth["label_seq"]
-        elif "label" in truth:
-            truth = truth["label"]
-        else:
-            raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys()))
-        if self._task == "seq_label":
-            return self._seq_label_evaluate(predict, truth)
-        elif self._task == "text_classify":
-            return self._text_classify_evaluate(predict, truth)
-        else:
-            raise NotImplementedError("Unknown task type {}.".format(self._task))
-
-    def _seq_label_evaluate(self, predict, truth):
-        batch_size, max_len = predict.size(0), predict.size(1)
-        loss = self._model.loss(predict, truth) / batch_size
-        prediction = self._model.prediction(predict)
-        # pad prediction to equal length
-        for pred in prediction:
-            if len(pred) < max_len:
-                pred += [0] * (max_len - len(pred))
-        results = torch.Tensor(prediction).view(-1, )
-        # make sure "results" is in the same device as "truth"
-        results = results.to(truth)
-        accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0]
-        return [float(loss), float(accuracy)]
-
-    def _text_classify_evaluate(self, y_logit, y_true):
-        y_prob = torch.nn.functional.softmax(y_logit, dim=-1)
-        return [y_prob, y_true]
+        return self._evaluator(predict, truth)

-    @property
-    def metrics(self):
-        """Compute and return metrics.
-        Use self.eval_history to compute metrics over the whole dev set.
-        Please refer to metrics.py for common metric functions.
-
-        :return : variable number of outputs
-        """
-        if self._task == "seq_label":
-            return self._seq_label_metrics
-        elif self._task == "text_classify":
-            return self._text_classify_metrics
-        else:
-            raise NotImplementedError("Unknown task type {}.".format(self._task))
-
-    @property
-    def _seq_label_metrics(self):
-        batch_loss = np.mean([x[0] for x in self.eval_history])
-        batch_accuracy = np.mean([x[1] for x in self.eval_history])
-        return batch_loss, batch_accuracy
-
-    @property
-    def _text_classify_metrics(self):
-        y_prob, y_true = zip(*self.eval_history)
-        y_prob = torch.cat(y_prob, dim=0)
-        y_pred = torch.argmax(y_prob, dim=-1)
-        y_true = torch.cat(y_true, dim=0)
-        acc = float(torch.sum(y_pred == y_true)) / len(y_true)
-        return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc
-
-    def show_metrics(self):
-        """Customize evaluation outputs in Trainer.
-        Called by Trainer to print evaluation results on dev set during training.
-        Use self.metrics to fetch available metrics.
-
-        :return print_str: str
-        """
-        loss, accuracy = self.metrics
-        return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy)
-
-    def make_eval_output(self, predictions, eval_results):
-        """Customize Tester outputs.
+    def print_eval_results(self, results):
+        """Override this method to support more print formats.

-        :param predictions: Tensor
-        :param eval_results: Tensor
-        :return: str, to be printed.
+        :param results: dict, (str: float) is (metrics name: value)
         """
-        return self.show_metrics()
+        return ", ".join([str(key) + "=" + str(value) for key, value in results.items()])
 class SeqLabelTester(Tester):
     def __init__(self, **test_args):
-        test_args.update({"task": "seq_label"})
         print(
-            "[FastNLP Warning] SeqLabelTester will be deprecated. Please use Tester with argument 'task'='seq_label'.")
+            "[FastNLP Warning] SeqLabelTester will be deprecated. Please use Tester directly.")
         super(SeqLabelTester, self).__init__(**test_args)


 class ClassificationTester(Tester):
     def __init__(self, **test_args):
-        test_args.update({"task": "text_classify"})
         print(
-            "[FastNLP Warning] ClassificationTester will be deprecated. Please use Tester with argument 'task'='text_classify'.")
+            "[FastNLP Warning] ClassificationTester will be deprecated. Please use Tester directly.")
         super(ClassificationTester, self).__init__(**test_args)
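Metric selection is now a constructor argument rather than a `task` string. A sketch of the new call pattern (the model, dev set, and pickle path are assumed to exist):

    from fastNLP.core.metrics import SeqLabelEvaluator
    from fastNLP.core.tester import Tester

    tester = Tester(batch_size=8,
                    use_cuda=False,
                    pickle_path="./save/",
                    evaluator=SeqLabelEvaluator())
    # tester.test(network, dev_data)   # prints "[tester] accuracy=..."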
@@ -8,6 +8,7 @@ from tensorboardX import SummaryWriter

 from fastNLP.core.batch import Batch
 from fastNLP.core.loss import Loss
+from fastNLP.core.metrics import Evaluator
 from fastNLP.core.optimizer import Optimizer
 from fastNLP.core.sampler import RandomSampler
 from fastNLP.core.tester import SeqLabelTester, ClassificationTester
@@ -43,21 +44,20 @@ class Trainer(object):
         default_args = {"epochs": 1, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/",
                         "save_best_dev": False, "model_name": "default_model_name.pkl", "print_every_step": 1,
                         "loss": Loss(None),  # used to pass type check
-                        "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0)
+                        "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
+                        "evaluator": Evaluator()
                         }
         """
         "required_args" is the collection of arguments that users must pass to Trainer explicitly.
         This is used to warn users of essential settings in the training.
         Specifically, "required_args" does not have a default value, so they have nothing to do with "default_args".
         """
-        required_args = {"task"  # one of ("seq_label", "text_classify")
-                         }
+        required_args = {}

         for req_key in required_args:
             if req_key not in kwargs:
                 logger.error("Trainer lacks argument {}".format(req_key))
                 raise ValueError("Trainer lacks argument {}".format(req_key))
-        self._task = kwargs["task"]

         for key in default_args:
             if key in kwargs:
@@ -86,6 +86,7 @@ class Trainer(object):
         self._loss_func = default_args["loss"].get()  # return a pytorch loss function or None
         self._optimizer = None
         self._optimizer_proto = default_args["optimizer"]
+        self._evaluator = default_args["evaluator"]
         self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs')
         self._graph_summaried = False
         self._best_accuracy = 0.0
@@ -106,9 +107,8 @@ class Trainer(object):
         # define Tester over dev data
         if self.validate:
-            default_valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
-                                  "save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path,
-                                  "use_cuda": self.use_cuda, "print_every_step": 0}
+            default_valid_args = {"batch_size": self.batch_size, "pickle_path": self.pickle_path,
+                                  "use_cuda": self.use_cuda, "evaluator": self._evaluator}
             validator = self._create_validator(default_valid_args)
             logger.info("validator defined as {}".format(str(validator)))
@@ -229,18 +229,9 @@ class Trainer(object):
         self._optimizer.step()

     def data_forward(self, network, x):
-        if self._task == "seq_label":
-            y = network(x["word_seq"], x["word_seq_origin_len"])
-        elif self._task == "text_classify" or self._task == "language_model":
-            y = network(x["word_seq"])
-        else:
-            raise NotImplementedError("Unknown task type {}.".format(self._task))
-
+        y = network(**x)
         if not self._graph_summaried:
-            if self._task == "seq_label":
-                self._summary_writer.add_graph(network, (x["word_seq"], x["word_seq_origin_len"]), verbose=False)
-            elif self._task == "text_classify" or self._task == "language_model":
-                self._summary_writer.add_graph(network, x["word_seq"], verbose=False)
+            # self._summary_writer.add_graph(network, x, verbose=False)
             self._graph_summaried = True
         return y
@@ -261,13 +252,9 @@ class Trainer(object):
         :param truth: ground truth label vector
         :return: a scalar
         """
-        if "label_seq" in truth:
-            truth = truth["label_seq"]
-        elif "label" in truth:
-            truth = truth["label"]
-            truth = truth.view((-1,))
-        else:
-            raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys()))
+        if len(truth) > 1:
+            raise NotImplementedError("Not ready to handle multi-labels.")
+        truth = list(truth.values())[0] if len(truth) > 0 else None
         return self._loss_func(predict, truth)
def define_loss(self): | def define_loss(self): | ||||
@@ -278,8 +265,8 @@ class Trainer(object): | |||||
These two losses cannot be defined at the same time. | These two losses cannot be defined at the same time. | ||||
Trainer does not handle loss definition or choose default losses. | Trainer does not handle loss definition or choose default losses. | ||||
""" | """ | ||||
if hasattr(self._model, "loss") and self._loss_func is not None: | |||||
raise ValueError("Both the model and Trainer define loss. Please take out your loss.") | |||||
# if hasattr(self._model, "loss") and self._loss_func is not None: | |||||
# raise ValueError("Both the model and Trainer define loss. Please take out your loss.") | |||||
if hasattr(self._model, "loss"): | if hasattr(self._model, "loss"): | ||||
self._loss_func = self._model.loss | self._loss_func = self._model.loss | ||||
@@ -322,9 +309,8 @@ class SeqLabelTrainer(Trainer):
     """
     def __init__(self, **kwargs):
-        kwargs.update({"task": "seq_label"})
         print(
-            "[FastNLP Warning] SeqLabelTrainer will be deprecated. Please use Trainer with argument 'task'='seq_label'.")
+            "[FastNLP Warning] SeqLabelTrainer will be deprecated. Please use Trainer directly.")
         super(SeqLabelTrainer, self).__init__(**kwargs)

     def _create_validator(self, valid_args):
@@ -335,9 +321,8 @@ class ClassificationTrainer(Trainer):
     """Trainer for text classification."""

     def __init__(self, **train_args):
-        train_args.update({"task": "text_classify"})
         print(
-            "[FastNLP Warning] ClassificationTrainer will be deprecated. Please use Trainer with argument 'task'='text_classify'.")
+            "[FastNLP Warning] ClassificationTrainer will be deprecated. Please use Trainer directly.")
         super(ClassificationTrainer, self).__init__(**train_args)

     def _create_validator(self, valid_args):
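On the training side the same evaluator object is threaded through `default_valid_args` into the validation Tester. A sketch (the model, data sets, and `train()` keyword names are assumptions; `ClassificationTrainer` still supplies `_create_validator`, at the cost of the deprecation warning above):

    from fastNLP.core.metrics import ClassifyEvaluator
    from fastNLP.core.trainer import ClassificationTrainer

    trainer = ClassificationTrainer(epochs=5,
                                    batch_size=32,
                                    validate=True,
                                    pickle_path="./save/",
                                    evaluator=ClassifyEvaluator())
    # trainer.train(network=model, train_data=train_set, dev_data=dev_set)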
@@ -54,7 +54,7 @@ class Vocabulary(object):
     def update(self, word):
         """add word or list of words into Vocabulary

-        :param word: a list of str or str
+        :param word: a list of string or a single string
         """
         if not isinstance(word, str) and isiterable(word):
             # it's a nested list
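For reference, `update` accepts a single string, a list of strings, or nested lists; the `isiterable` branch recurses into nesting element by element. A toy illustration (lookup via `__getitem__` is how `Field.index` uses the vocabulary):

    from fastNLP.core.vocabulary import Vocabulary

    vocab = Vocabulary()
    vocab.update("hello")                          # a single string
    vocab.update(["hello", "world"])               # a list of strings
    vocab.update([["nested", "list"], ["works"]])  # nested lists are flattened recursively
    print(vocab["world"])                          # the index assigned to "world"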
@@ -1,23 +1,18 @@
 class BaseLoader(object):
-    """docstring for BaseLoader"""
-    def __init__(self, data_path):
+    def __init__(self):
         super(BaseLoader, self).__init__()
-        self.data_path = data_path
-
-    def load(self):
-        """
-        :return: string
-        """
-        with open(self.data_path, "r", encoding="utf-8") as f:
-            text = f.read()
-        return text

-    def load_lines(self):
-        with open(self.data_path, "r", encoding="utf-8") as f:
+    def load_lines(self, data_path):
+        with open(data_path, "r", encoding="utf-8") as f:
             text = f.readlines()
         return [line.strip() for line in text]

+    def load(self, data_path):
+        with open(data_path, "r", encoding="utf-8") as f:
+            text = f.readlines()
+        return [[word for word in sent.strip()] for sent in text]
+

 class ToyLoader0(BaseLoader):
     """
@@ -8,9 +8,9 @@ from fastNLP.loader.base_loader import BaseLoader

 class ConfigLoader(BaseLoader):
     """loader for configuration files"""

-    def __init__(self, data_name, data_path):
-        super(ConfigLoader, self).__init__(data_path)
-        self.config = self.parse(super(ConfigLoader, self).load())
+    def __init__(self, data_path):
+        super(ConfigLoader, self).__init__()
+        self.config = self.parse(super(ConfigLoader, self).load(data_path))

     @staticmethod
     def parse(string):
@@ -3,14 +3,17 @@ import os

 from fastNLP.loader.base_loader import BaseLoader


-class DatasetLoader(BaseLoader):
+class DataSetLoader(BaseLoader):
     """loader for data sets"""

-    def __init__(self, data_path):
-        super(DatasetLoader, self).__init__(data_path)
+    def __init__(self):
+        super(DataSetLoader, self).__init__()
+
+    def load(self, path):
+        raise NotImplementedError


-class POSDatasetLoader(DatasetLoader):
+class POSDataSetLoader(DataSetLoader):
     """Dataset Loader for POS Tag datasets.

     In these datasets, each line is divided by '\t'
@@ -31,16 +34,10 @@ class POSDatasetLoader(DatasetLoader):
     to label5.
     """

-    def __init__(self, data_path):
-        super(POSDatasetLoader, self).__init__(data_path)
-
-    def load(self):
-        assert os.path.exists(self.data_path)
-        with open(self.data_path, "r", encoding="utf-8") as f:
-            line = f.read()
-        return line
+    def __init__(self):
+        super(POSDataSetLoader, self).__init__()

-    def load_lines(self):
+    def load(self, data_path):
         """
         :return data: three-level list
         [
@@ -49,7 +46,7 @@ class POSDatasetLoader(DatasetLoader):
             ...
         ]
         """
-        with open(self.data_path, "r", encoding="utf-8") as f:
+        with open(data_path, "r", encoding="utf-8") as f:
             lines = f.readlines()
         return self.parse(lines)
@@ -79,15 +76,15 @@ class POSDatasetLoader(DatasetLoader):
         return data


-class TokenizeDatasetLoader(DatasetLoader):
+class TokenizeDataSetLoader(DataSetLoader):
     """
     Data set loader for tokenization data sets
     """

-    def __init__(self, data_path):
-        super(TokenizeDatasetLoader, self).__init__(data_path)
+    def __init__(self):
+        super(TokenizeDataSetLoader, self).__init__()

-    def load_pku(self, max_seq_len=32):
+    def load(self, data_path, max_seq_len=32):
         """
         load pku dataset for Chinese word segmentation
         CWS (Chinese Word Segmentation) pku training dataset format:
@@ -104,7 +101,7 @@ class TokenizeDatasetLoader(DatasetLoader):
         :return: three-level lists
         """
         assert isinstance(max_seq_len, int) and max_seq_len > 0
-        with open(self.data_path, "r", encoding="utf-8") as f:
+        with open(data_path, "r", encoding="utf-8") as f:
             sentences = f.readlines()
         data = []
         for sent in sentences:
@@ -135,15 +132,15 @@ class TokenizeDatasetLoader(DatasetLoader):
         return data


-class ClassDatasetLoader(DatasetLoader):
+class ClassDataSetLoader(DataSetLoader):
     """Loader for classification data sets"""

-    def __init__(self, data_path):
-        super(ClassDatasetLoader, self).__init__(data_path)
+    def __init__(self):
+        super(ClassDataSetLoader, self).__init__()

-    def load(self):
-        assert os.path.exists(self.data_path)
-        with open(self.data_path, "r", encoding="utf-8") as f:
+    def load(self, data_path):
+        assert os.path.exists(data_path)
+        with open(data_path, "r", encoding="utf-8") as f:
             lines = f.readlines()
         return self.parse(lines)
@@ -169,21 +166,21 @@ class ClassDatasetLoader(DatasetLoader):
         return dataset


-class ConllLoader(DatasetLoader):
+class ConllLoader(DataSetLoader):
     """loader for conll format files"""

     def __init__(self, data_path):
         """
         :param str data_path: the path to the conll data set
         """
-        super(ConllLoader, self).__init__(data_path)
-        self.data_set = self.parse(self.load())
+        super(ConllLoader, self).__init__()
+        self.data_set = self.parse(self.load(data_path))

-    def load(self):
+    def load(self, data_path):
         """
         :return: list lines: all lines in a conll file
         """
-        with open(self.data_path, "r", encoding="utf-8") as f:
+        with open(data_path, "r", encoding="utf-8") as f:
             lines = f.readlines()
         return lines
@@ -207,20 +204,21 @@ class ConllLoader(DatasetLoader):
         return sentences


-class LMDatasetLoader(DatasetLoader):
+class LMDataSetLoader(DataSetLoader):
     """Language Model Dataset Loader

     This loader produces data for language model training in a supervised way.
     That means it has X and Y.
     """

-    def __init__(self, data_path):
-        super(LMDatasetLoader, self).__init__(data_path)
-
-    def load(self):
-        if not os.path.exists(self.data_path):
-            raise FileNotFoundError("file {} not found.".format(self.data_path))
-        with open(self.data_path, "r", encoding="utf-8") as f:
+    def __init__(self):
+        super(LMDataSetLoader, self).__init__()
+
+    def load(self, data_path):
+        if not os.path.exists(data_path):
+            raise FileNotFoundError("file {} not found.".format(data_path))
+        with open(data_path, "r", encoding="utf-8") as f:
             text = " ".join(f.readlines())
         tokens = text.strip().split()
         return self.sentence_cut(tokens)
@@ -237,16 +235,17 @@ class LMDatasetLoader(DatasetLoader):
         data_set.append([x, y])
         return data_set


-class PeopleDailyCorpusLoader(DatasetLoader):
+class PeopleDailyCorpusLoader(DataSetLoader):
     """
     People Daily Corpus: Chinese word segmentation, POS tag, NER
     """

-    def __init__(self, data_path):
-        super(PeopleDailyCorpusLoader, self).__init__(data_path)
+    def __init__(self):
+        super(PeopleDailyCorpusLoader, self).__init__()

-    def load(self):
-        with open(self.data_path, "r", encoding="utf-8") as f:
+    def load(self, data_path):
+        with open(data_path, "r", encoding="utf-8") as f:
             sents = f.readlines()
         pos_tag_examples = []
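Every loader in this file now follows the same stateless contract: construct with no arguments, pass the path to `load`. One loader instance can therefore be reused across files (paths hypothetical):

    from fastNLP.loader.dataset_loader import POSDataSetLoader

    loader = POSDataSetLoader()                     # was: POSDataSetLoader(data_path)
    train_data = loader.load("data/pos_train.txt")  # three-level list of strings
    dev_data = loader.load("data/pos_dev.txt")      # same loader, different file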
@@ -36,11 +36,13 @@ class SeqLabeling(BaseModel):
         self.Crf = decoder.CRF.ConditionalRandomField(num_classes)
         self.mask = None

-    def forward(self, word_seq, word_seq_origin_len):
+    def forward(self, word_seq, word_seq_origin_len, truth=None):
         """
         :param word_seq: LongTensor, [batch_size, max_len]
         :param word_seq_origin_len: LongTensor, [batch_size,], the origin lengths of the sequences.
-        :return y: [batch_size, max_len, tag_size]
+        :param truth: LongTensor, [batch_size, max_len]
+        :return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting.
+                   If truth is not None, return loss, a scalar. Used in training.
         """
         self.mask = self.make_mask(word_seq, word_seq_origin_len)
@@ -50,9 +52,16 @@ class SeqLabeling(BaseModel):
         # [batch_size, max_len, hidden_size * direction]
         x = self.Linear(x)
         # [batch_size, max_len, num_classes]
-        return x
+        if truth is not None:
+            return self._internal_loss(x, truth)
+        else:
+            return self.decode(x)

     def loss(self, x, y):
+        """Since the loss has been computed in forward(), this function simply returns x."""
+        return x
+
+    def _internal_loss(self, x, y):
         """
         Negative log likelihood loss.
         :param x: Tensor, [batch_size, max_len, tag_size]
@@ -74,12 +83,19 @@ class SeqLabeling(BaseModel):
         mask = mask.to(x)
         return mask

-    def prediction(self, x):
+    def decode(self, x, pad=True):
         """
         :param x: FloatTensor, [batch_size, max_len, tag_size]
+        :param pad: pad the output sequence to equal lengths
         :return prediction: list of [decode path(list)]
         """
+        max_len = x.shape[1]
         tag_seq = self.Crf.viterbi_decode(x, self.mask)
+        # pad prediction to equal length
+        if pad is True:
+            for pred in tag_seq:
+                if len(pred) < max_len:
+                    pred += [0] * (max_len - len(pred))
         return tag_seq
@@ -106,11 +122,12 @@ class AdvSeqLabel(SeqLabeling): | |||||
self.Crf = decoder.CRF.ConditionalRandomField(num_classes) | self.Crf = decoder.CRF.ConditionalRandomField(num_classes) | ||||
def forward(self, word_seq, word_seq_origin_len): | |||||
def forward(self, word_seq, word_seq_origin_len, truth=None): | |||||
""" | """ | ||||
:param word_seq: LongTensor, [batch_size, max_len] | :param word_seq: LongTensor, [batch_size, max_len] | ||||
:param word_seq_origin_len: list of int. | :param word_seq_origin_len: list of int. | ||||
:return y: [batch_size, mex_len, tag_size] | |||||
:param truth: LongTensor, [batch_size, max_len] | |||||
:return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting. | |||||
If truth is not None, return loss, a scalar. Used in training. | |||||
""" | """ | ||||
self.mask = self.make_mask(word_seq, word_seq_origin_len) | self.mask = self.make_mask(word_seq, word_seq_origin_len) | ||||
@@ -129,4 +146,7 @@ class AdvSeqLabel(SeqLabeling): | |||||
x = self.Linear2(x) | x = self.Linear2(x) | ||||
x = x.view(batch_size, max_len, -1) | x = x.view(batch_size, max_len, -1) | ||||
# [batch_size, max_len, num_classes] | # [batch_size, max_len, num_classes] | ||||
return x | |||||
if truth is not None: | |||||
return self._internal_loss(x, truth) | |||||
else: | |||||
return self.decode(x) |
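With truth routed through forward(), one call now covers both training and inference. A sketch of the intended call pattern, assuming a dict-like config with the keys below (names, values, and shapes are illustrative, not confirmed by this diff):

    import torch
    from fastNLP.models.sequence_modeling import SeqLabeling

    model_args = {"vocab_size": 100, "num_classes": 5}  # assumed config keys; real configs carry more
    model = SeqLabeling(model_args)

    word_seq = torch.randint(0, 100, (2, 7))  # [batch_size, max_len] token ids
    seq_len = torch.LongTensor([7, 5])        # original sequence lengths
    truth = torch.randint(0, 5, (2, 7))       # gold tag ids, [batch_size, max_len]

    loss = model(word_seq, seq_len, truth)    # training: a scalar CRF negative log likelihood
    paths = model(word_seq, seq_len)          # inference: decode paths, zero-padded to max_len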
@@ -55,14 +55,13 @@ class SelfAttention(nn.Module): | |||||
input = input.contiguous() | input = input.contiguous() | ||||
size = input.size() # [bsz, len, nhid] | size = input.size() # [bsz, len, nhid] | ||||
input_origin = input_origin.expand(self.attention_hops, -1, -1) # [hops, bsz, len] | input_origin = input_origin.expand(self.attention_hops, -1, -1) # [hops, bsz, len] | ||||
input_origin = input_origin.transpose(0, 1).contiguous() # [baz, hops,len] | |||||
input_origin = input_origin.transpose(0, 1).contiguous() # [bsz, hops, len] | |||||
y1 = self.tanh(self.ws1(self.drop(input))) # [baz,len,dim] -->[bsz,len, attention-unit] | |||||
attention = self.ws2(y1).transpose(1,2).contiguous() # [bsz,len, attention-unit]--> [bsz, len, hop]--> [baz,hop,len] | |||||
y1 = self.tanh(self.ws1(self.drop(input))) # [bsz, len, nhid] --> [bsz, len, attention-unit] | |||||
attention = self.ws2(y1).transpose(1, | |||||
2).contiguous() # [bsz, len, attention-unit] --> [bsz, len, hop] --> [bsz, hop, len] | |||||
attention = attention + (-999999 * (input_origin == 0).float()) # remove the weight on padding token. | attention = attention + (-999999 * (input_origin == 0).float()) # remove the weight on padding token. | ||||
attention = F.softmax(attention,2) # [baz ,hop, len] | |||||
return torch.bmm(attention, input), self.penalization(attention) # output1 --> [baz ,hop ,nhid] | |||||
attention = F.softmax(attention, 2) # [bsz, hop, len] | |||||
return torch.bmm(attention, input), self.penalization(attention) # output1 --> [bsz, hop, nhid] |
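The masking line above is the core trick: adding a large negative number to padded positions before the softmax pushes their attention weight to (almost) zero. A standalone sketch of that idea, with illustrative shapes, independent of the module's API:

    import torch
    import torch.nn.functional as F

    scores = torch.randn(1, 2, 5)                        # [bsz, hop, len] raw attention scores
    tokens = torch.tensor([[[3, 7, 9, 0, 0]]])           # token ids, 0 = padding; broadcasts over hops
    scores = scores + (-999999 * (tokens == 0).float())  # padded positions become hugely negative
    weights = F.softmax(scores, 2)                       # ~0 weight on padding after softmax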
@@ -1,14 +1,14 @@ | |||||
from fastNLP.core.loss import Loss | from fastNLP.core.loss import Loss | ||||
from fastNLP.core.preprocess import Preprocessor | from fastNLP.core.preprocess import Preprocessor | ||||
from fastNLP.core.trainer import Trainer | from fastNLP.core.trainer import Trainer | ||||
from fastNLP.loader.dataset_loader import LMDatasetLoader | |||||
from fastNLP.loader.dataset_loader import LMDataSetLoader | |||||
from fastNLP.models.char_language_model import CharLM | from fastNLP.models.char_language_model import CharLM | ||||
PICKLE = "./save/" | PICKLE = "./save/" | ||||
def train(): | def train(): | ||||
loader = LMDatasetLoader("./train.txt") | |||||
loader = LMDataSetLoader() | |||||
train_data = loader.load() | train_data = loader.load() | ||||
pre = Preprocessor(label_is_seq=True, share_vocab=True) | pre = Preprocessor(label_is_seq=True, share_vocab=True) | ||||
@@ -4,7 +4,7 @@ from fastNLP.core.preprocess import ClassPreprocess as Preprocess | |||||
from fastNLP.core.trainer import ClassificationTrainer | from fastNLP.core.trainer import ClassificationTrainer | ||||
from fastNLP.loader.config_loader import ConfigLoader | from fastNLP.loader.config_loader import ConfigLoader | ||||
from fastNLP.loader.config_loader import ConfigSection | from fastNLP.loader.config_loader import ConfigSection | ||||
from fastNLP.loader.dataset_loader import ClassDatasetLoader as Dataset_loader | |||||
from fastNLP.loader.dataset_loader import ClassDataSetLoader as Dataset_loader | |||||
from fastNLP.models.base_model import BaseModel | from fastNLP.models.base_model import BaseModel | ||||
from fastNLP.modules.aggregator.self_attention import SelfAttention | from fastNLP.modules.aggregator.self_attention import SelfAttention | ||||
from fastNLP.modules.decoder.MLP import MLP | from fastNLP.modules.decoder.MLP import MLP | ||||
@@ -5,7 +5,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | ||||
from fastNLP.core.trainer import SeqLabelTrainer | from fastNLP.core.trainer import SeqLabelTrainer | ||||
from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader | |||||
from fastNLP.loader.dataset_loader import TokenizeDataSetLoader, BaseLoader | |||||
from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle | from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle | ||||
from fastNLP.saver.model_saver import ModelSaver | from fastNLP.saver.model_saver import ModelSaver | ||||
from fastNLP.loader.model_loader import ModelLoader | from fastNLP.loader.model_loader import ModelLoader | ||||
@@ -66,8 +66,8 @@ def train(): | |||||
ConfigLoader("good_path").load_config(cfgfile, {"train": train_args, "test": test_args}) | ConfigLoader("good_path").load_config(cfgfile, {"train": train_args, "test": test_args}) | ||||
# Data Loader | # Data Loader | ||||
loader = TokenizeDatasetLoader(cws_data_path) | |||||
train_data = loader.load_pku() | |||||
loader = TokenizeDataSetLoader() | |||||
train_data = loader.load() | |||||
# Preprocessor | # Preprocessor | ||||
preprocessor = SeqLabelPreprocess() | preprocessor = SeqLabelPreprocess() | ||||
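For reference, the test file later in this diff calls the migrated loader with a data directory and a max_seq_len cap; a sketch of that usage (the path is illustrative):

    from fastNLP.loader.dataset_loader import TokenizeDataSetLoader

    loader = TokenizeDataSetLoader()
    # max_seq_len appears to cap sequence length at load time; path and value are illustrative
    train_data = loader.load("data/cws_pku/", max_seq_len=32)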
@@ -66,7 +66,7 @@ def train(): | |||||
ConfigLoader("good_name").load_config(cfgfile, {"train": train_args, "test": test_args}) | ConfigLoader("good_name").load_config(cfgfile, {"train": train_args, "test": test_args}) | ||||
# Data Loader | # Data Loader | ||||
loader = PeopleDailyCorpusLoader(pos_tag_data_path) | |||||
loader = PeopleDailyCorpusLoader() | |||||
train_data, _ = loader.load() | train_data, _ = loader.load() | ||||
# Preprocessor | # Preprocessor | ||||
@@ -1,6 +1,6 @@ | |||||
import unittest | import unittest | ||||
from fastNLP.loader.dataset_loader import POSDatasetLoader, LMDatasetLoader, TokenizeDatasetLoader, \ | |||||
from fastNLP.loader.dataset_loader import POSDataSetLoader, LMDataSetLoader, TokenizeDataSetLoader, \ | |||||
PeopleDailyCorpusLoader, ConllLoader | PeopleDailyCorpusLoader, ConllLoader | ||||
@@ -8,29 +8,29 @@ class TestDatasetLoader(unittest.TestCase): | |||||
def test_case_1(self): | def test_case_1(self): | ||||
data = """Tom\tT\nand\tF\nJerry\tT\n.\tF\n\nHello\tT\nworld\tF\n!\tF""" | data = """Tom\tT\nand\tF\nJerry\tT\n.\tF\n\nHello\tT\nworld\tF\n!\tF""" | ||||
lines = data.split("\n") | lines = data.split("\n") | ||||
answer = POSDatasetLoader.parse(lines) | |||||
answer = POSDataSetLoader.parse(lines) | |||||
truth = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]], [["Hello", "world", "!"], ["T", "F", "F"]]] | truth = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]], [["Hello", "world", "!"], ["T", "F", "F"]]] | ||||
self.assertListEqual(answer, truth, "POS Dataset Loader") | self.assertListEqual(answer, truth, "POS Dataset Loader") | ||||
def test_case_TokenizeDatasetLoader(self): | def test_case_TokenizeDatasetLoader(self): | ||||
loader = TokenizeDatasetLoader("./test/data_for_tests/cws_pku_utf_8") | |||||
data = loader.load_pku(max_seq_len=32) | |||||
print("pass TokenizeDatasetLoader test!") | |||||
loader = TokenizeDataSetLoader() | |||||
data = loader.load("test/data_for_tests/", max_seq_len=32) | |||||
print("pass TokenizeDataSetLoader test!") | |||||
def test_case_POSDatasetLoader(self): | def test_case_POSDatasetLoader(self): | ||||
loader = POSDatasetLoader("./test/data_for_tests/people.txt") | |||||
loader = POSDataSetLoader() | |||||
data = loader.load() | data = loader.load() | ||||
datas = loader.load_lines() | datas = loader.load_lines() | ||||
print("pass POSDatasetLoader test!") | |||||
print("pass POSDataSetLoader test!") | |||||
def test_case_LMDatasetLoader(self): | def test_case_LMDatasetLoader(self): | ||||
loader = LMDatasetLoader("./test/data_for_tests/cws_pku_utf_8") | |||||
loader = LMDataSetLoader() | |||||
data = loader.load() | data = loader.load() | ||||
datas = loader.load_lines() | datas = loader.load_lines() | ||||
print("pass TokenizeDatasetLoader test!") | |||||
print("pass TokenizeDataSetLoader test!") | |||||
def test_PeopleDailyCorpusLoader(self): | def test_PeopleDailyCorpusLoader(self): | ||||
loader = PeopleDailyCorpusLoader("./test/data_for_tests/people_daily_raw.txt") | |||||
loader = PeopleDailyCorpusLoader() | |||||
_, _ = loader.load() | _, _ = loader.load() | ||||
def test_ConllLoader(self): | def test_ConllLoader(self): | ||||
@@ -4,14 +4,16 @@ sys.path.append("..") | |||||
import argparse | import argparse | ||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | ||||
from fastNLP.core.trainer import SeqLabelTrainer | from fastNLP.core.trainer import SeqLabelTrainer | ||||
from fastNLP.loader.dataset_loader import POSDatasetLoader, BaseLoader | |||||
from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle | |||||
from fastNLP.loader.dataset_loader import BaseLoader | |||||
from fastNLP.saver.model_saver import ModelSaver | from fastNLP.saver.model_saver import ModelSaver | ||||
from fastNLP.loader.model_loader import ModelLoader | from fastNLP.loader.model_loader import ModelLoader | ||||
from fastNLP.core.tester import SeqLabelTester | from fastNLP.core.tester import SeqLabelTester | ||||
from fastNLP.models.sequence_modeling import SeqLabeling | from fastNLP.models.sequence_modeling import SeqLabeling | ||||
from fastNLP.core.predictor import SeqLabelInfer | from fastNLP.core.predictor import SeqLabelInfer | ||||
from fastNLP.core.optimizer import Optimizer | from fastNLP.core.optimizer import Optimizer | ||||
from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target | |||||
from fastNLP.core.metrics import SeqLabelEvaluator | |||||
from fastNLP.core.preprocess import save_pickle, load_pickle | |||||
parser = argparse.ArgumentParser() | parser = argparse.ArgumentParser() | ||||
parser.add_argument("-s", "--save", type=str, default="./seq_label/", help="path to save pickle files") | parser.add_argument("-s", "--save", type=str, default="./seq_label/", help="path to save pickle files") | ||||
@@ -33,24 +35,27 @@ data_infer_path = args.infer | |||||
def infer(): | def infer(): | ||||
# Load infer configuration, the same as test | # Load infer configuration, the same as test | ||||
test_args = ConfigSection() | test_args = ConfigSection() | ||||
ConfigLoader("config.cfg").load_config(config_dir, {"POS_infer": test_args}) | |||||
ConfigLoader().load_config(config_dir, {"POS_infer": test_args}) | |||||
# fetch dictionary size and number of labels from pickle files | # fetch dictionary size and number of labels from pickle files | ||||
word2index = load_pickle(pickle_path, "word2id.pkl") | |||||
test_args["vocab_size"] = len(word2index) | |||||
index2label = load_pickle(pickle_path, "class2id.pkl") | |||||
test_args["num_classes"] = len(index2label) | |||||
word_vocab = load_pickle(pickle_path, "word2id.pkl") | |||||
label_vocab = load_pickle(pickle_path, "label2id.pkl") | |||||
test_args["vocab_size"] = len(word_vocab) | |||||
test_args["num_classes"] = len(label_vocab) | |||||
print("vocabularies loaded") | |||||
# Define the same model | # Define the same model | ||||
model = SeqLabeling(test_args) | model = SeqLabeling(test_args) | ||||
print("model defined") | |||||
# Dump trained parameters into the model | # Dump trained parameters into the model | ||||
ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name)) | ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name)) | ||||
print("model loaded!") | print("model loaded!") | ||||
# Data Loader | # Data Loader | ||||
raw_data_loader = BaseLoader(data_infer_path) | |||||
infer_data = raw_data_loader.load_lines() | |||||
infer_data = SeqLabelDataSet(loader=BaseLoader()) | |||||
infer_data.load(data_infer_path, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab}, infer=True) | |||||
print("data set prepared") | |||||
# Inference interface | # Inference interface | ||||
infer = SeqLabelInfer(pickle_path) | infer = SeqLabelInfer(pickle_path) | ||||
@@ -65,24 +70,18 @@ def train_and_test(): | |||||
# Config Loader | # Config Loader | ||||
trainer_args = ConfigSection() | trainer_args = ConfigSection() | ||||
model_args = ConfigSection() | model_args = ConfigSection() | ||||
ConfigLoader("config.cfg").load_config(config_dir, { | |||||
ConfigLoader().load_config(config_dir, { | |||||
"test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args}) | "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args}) | ||||
# Data Loader | |||||
pos_loader = POSDatasetLoader(data_path) | |||||
train_data = pos_loader.load_lines() | |||||
# Preprocessor | |||||
p = SeqLabelPreprocess() | |||||
data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5) | |||||
model_args["vocab_size"] = p.vocab_size | |||||
model_args["num_classes"] = p.num_classes | |||||
data_set = SeqLabelDataSet() | |||||
data_set.load(data_path) | |||||
train_set, dev_set = data_set.split(0.3, shuffle=True) | |||||
model_args["vocab_size"] = len(data_set.word_vocab) | |||||
model_args["num_classes"] = len(data_set.label_vocab) | |||||
# Trainer: two definition styles | |||||
# 1 | |||||
# trainer = SeqLabelTrainer(trainer_args.data) | |||||
save_pickle(data_set.word_vocab, pickle_path, "word2id.pkl") | |||||
save_pickle(data_set.label_vocab, pickle_path, "label2id.pkl") | |||||
# 2 | |||||
trainer = SeqLabelTrainer( | trainer = SeqLabelTrainer( | ||||
epochs=trainer_args["epochs"], | epochs=trainer_args["epochs"], | ||||
batch_size=trainer_args["batch_size"], | batch_size=trainer_args["batch_size"], | ||||
@@ -98,7 +97,7 @@ def train_and_test(): | |||||
model = SeqLabeling(model_args) | model = SeqLabeling(model_args) | ||||
# Start training | # Start training | ||||
trainer.train(model, data_train, data_dev) | |||||
trainer.train(model, train_set, dev_set) | |||||
print("Training finished!") | print("Training finished!") | ||||
# Saver | # Saver | ||||
@@ -106,7 +105,9 @@ def train_and_test(): | |||||
saver.save_pytorch(model) | saver.save_pytorch(model) | ||||
print("Model saved!") | print("Model saved!") | ||||
del model, trainer, pos_loader | |||||
del model, trainer | |||||
change_field_is_target(dev_set, "truth", True) | |||||
# Define the same model | # Define the same model | ||||
model = SeqLabeling(model_args) | model = SeqLabeling(model_args) | ||||
@@ -117,27 +118,21 @@ def train_and_test(): | |||||
# Load test configuration | # Load test configuration | ||||
tester_args = ConfigSection() | tester_args = ConfigSection() | ||||
ConfigLoader("config.cfg").load_config(config_dir, {"test_seq_label_tester": tester_args}) | |||||
ConfigLoader().load_config(config_dir, {"test_seq_label_tester": tester_args}) | |||||
# Tester | # Tester | ||||
tester = SeqLabelTester(save_output=False, | |||||
save_loss=True, | |||||
save_best_dev=False, | |||||
batch_size=4, | |||||
tester = SeqLabelTester(batch_size=4, | |||||
use_cuda=False, | use_cuda=False, | ||||
pickle_path=pickle_path, | pickle_path=pickle_path, | ||||
model_name="seq_label_in_test.pkl", | model_name="seq_label_in_test.pkl", | ||||
print_every_step=1 | |||||
evaluator=SeqLabelEvaluator() | |||||
) | ) | ||||
# Start testing with validation data | # Start testing with validation data | ||||
tester.test(model, data_dev) | |||||
# print test results | |||||
print(tester.show_metrics()) | |||||
tester.test(model, dev_set) | |||||
print("model tested!") | print("model tested!") | ||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
train_and_test() | train_and_test() | ||||
# infer() | |||||
infer() |
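Condensed, the new data flow this script sets up replaces the loader-plus-SeqLabelPreprocess pair with SeqLabelDataSet; a sketch with illustrative paths:

    from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target
    from fastNLP.core.preprocess import save_pickle

    data_set = SeqLabelDataSet()
    data_set.load("data/people.txt")            # builds word and label vocabularies internally
    train_set, dev_set = data_set.split(0.3, shuffle=True)

    # persist the vocabularies so infer() can rebuild the same index mappings
    save_pickle(data_set.word_vocab, "./seq_label/", "word2id.pkl")
    save_pickle(data_set.label_vocab, "./seq_label/", "label2id.pkl")

    # before testing, mark the "truth" field as the evaluation target
    change_field_is_target(dev_set, "truth", True)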
@@ -1,11 +1,13 @@ | |||||
import os | import os | ||||
from fastNLP.core.predictor import Predictor | |||||
from fastNLP.core.preprocess import Preprocessor, load_pickle | |||||
from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target | |||||
from fastNLP.core.metrics import SeqLabelEvaluator | |||||
from fastNLP.core.predictor import SeqLabelInfer | |||||
from fastNLP.core.preprocess import save_pickle, load_pickle | |||||
from fastNLP.core.tester import SeqLabelTester | from fastNLP.core.tester import SeqLabelTester | ||||
from fastNLP.core.trainer import SeqLabelTrainer | from fastNLP.core.trainer import SeqLabelTrainer | ||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | ||||
from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader | |||||
from fastNLP.loader.dataset_loader import TokenizeDataSetLoader, BaseLoader | |||||
from fastNLP.loader.model_loader import ModelLoader | from fastNLP.loader.model_loader import ModelLoader | ||||
from fastNLP.models.sequence_modeling import SeqLabeling | from fastNLP.models.sequence_modeling import SeqLabeling | ||||
from fastNLP.saver.model_saver import ModelSaver | from fastNLP.saver.model_saver import ModelSaver | ||||
@@ -19,12 +21,12 @@ config_path = "test/data_for_tests/config" | |||||
def infer(): | def infer(): | ||||
# Load infer configuration, the same as test | # Load infer configuration, the same as test | ||||
test_args = ConfigSection() | test_args = ConfigSection() | ||||
ConfigLoader("config.cfg").load_config(config_path, {"POS_infer": test_args}) | |||||
ConfigLoader().load_config(config_path, {"POS_infer": test_args}) | |||||
# fetch dictionary size and number of labels from pickle files | # fetch dictionary size and number of labels from pickle files | ||||
word2index = load_pickle(pickle_path, "word2id.pkl") | word2index = load_pickle(pickle_path, "word2id.pkl") | ||||
test_args["vocab_size"] = len(word2index) | test_args["vocab_size"] = len(word2index) | ||||
index2label = load_pickle(pickle_path, "class2id.pkl") | |||||
index2label = load_pickle(pickle_path, "label2id.pkl") | |||||
test_args["num_classes"] = len(index2label) | test_args["num_classes"] = len(index2label) | ||||
# Define the same model | # Define the same model | ||||
@@ -34,31 +36,29 @@ def infer(): | |||||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | ||||
print("model loaded!") | print("model loaded!") | ||||
# Data Loader | |||||
raw_data_loader = BaseLoader(data_infer_path) | |||||
infer_data = raw_data_loader.load_lines() | |||||
# Load infer data | |||||
infer_data = SeqLabelDataSet(loader=BaseLoader()) | |||||
infer_data.load(data_infer_path, vocabs={"word_vocab": word2index}, infer=True) | |||||
# Inference interface | |||||
infer = Predictor(pickle_path, "seq_label") | |||||
# inference | |||||
infer = SeqLabelInfer(pickle_path) | |||||
results = infer.predict(model, infer_data) | results = infer.predict(model, infer_data) | ||||
print(results) | print(results) | ||||
def train_test(): | def train_test(): | ||||
# Config Loader | # Config Loader | ||||
train_args = ConfigSection() | train_args = ConfigSection() | ||||
ConfigLoader("config.cfg").load_config(config_path, {"POS_infer": train_args}) | |||||
ConfigLoader().load_config(config_path, {"POS_infer": train_args}) | |||||
# Data Loader | |||||
loader = TokenizeDatasetLoader(cws_data_path) | |||||
train_data = loader.load_pku() | |||||
# define dataset | |||||
data_train = SeqLabelDataSet(loader=TokenizeDataSetLoader()) | |||||
data_train.load(cws_data_path) | |||||
train_args["vocab_size"] = len(data_train.word_vocab) | |||||
train_args["num_classes"] = len(data_train.label_vocab) | |||||
# Preprocessor | |||||
p = Preprocessor(label_is_seq=True) | |||||
data_train = p.run(train_data, pickle_path=pickle_path) | |||||
train_args["vocab_size"] = p.vocab_size | |||||
train_args["num_classes"] = p.num_classes | |||||
save_pickle(data_train.word_vocab, pickle_path, "word2id.pkl") | |||||
save_pickle(data_train.label_vocab, pickle_path, "label2id.pkl") | |||||
# Trainer | # Trainer | ||||
trainer = SeqLabelTrainer(**train_args.data) | trainer = SeqLabelTrainer(**train_args.data) | ||||
@@ -73,7 +73,7 @@ def train_test(): | |||||
saver = ModelSaver("./save/saved_model.pkl") | saver = ModelSaver("./save/saved_model.pkl") | ||||
saver.save_pytorch(model) | saver.save_pytorch(model) | ||||
del model, trainer, loader | |||||
del model, trainer | |||||
# Define the same model | # Define the same model | ||||
model = SeqLabeling(train_args) | model = SeqLabeling(train_args) | ||||
@@ -83,17 +83,16 @@ def train_test(): | |||||
# Load test configuration | # Load test configuration | ||||
test_args = ConfigSection() | test_args = ConfigSection() | ||||
ConfigLoader("config.cfg").load_config(config_path, {"POS_infer": test_args}) | |||||
ConfigLoader().load_config(config_path, {"POS_infer": test_args}) | |||||
test_args["evaluator"] = SeqLabelEvaluator() | |||||
# Tester | # Tester | ||||
tester = SeqLabelTester(**test_args.data) | tester = SeqLabelTester(**test_args.data) | ||||
# Start testing | # Start testing | ||||
change_field_is_target(data_train, "truth", True) | |||||
tester.test(model, data_train) | tester.test(model, data_train) | ||||
# print test results | |||||
print(tester.show_metrics()) | |||||
def test(): | def test(): | ||||
os.makedirs("save", exist_ok=True) | os.makedirs("save", exist_ok=True) | ||||
@@ -1,11 +1,12 @@ | |||||
import os | import os | ||||
from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target | |||||
from fastNLP.core.metrics import SeqLabelEvaluator | |||||
from fastNLP.core.optimizer import Optimizer | from fastNLP.core.optimizer import Optimizer | ||||
from fastNLP.core.preprocess import SeqLabelPreprocess | |||||
from fastNLP.core.preprocess import save_pickle | |||||
from fastNLP.core.tester import SeqLabelTester | from fastNLP.core.tester import SeqLabelTester | ||||
from fastNLP.core.trainer import SeqLabelTrainer | from fastNLP.core.trainer import SeqLabelTrainer | ||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | ||||
from fastNLP.loader.dataset_loader import POSDatasetLoader | |||||
from fastNLP.loader.model_loader import ModelLoader | from fastNLP.loader.model_loader import ModelLoader | ||||
from fastNLP.models.sequence_modeling import SeqLabeling | from fastNLP.models.sequence_modeling import SeqLabeling | ||||
from fastNLP.saver.model_saver import ModelSaver | from fastNLP.saver.model_saver import ModelSaver | ||||
@@ -21,18 +22,17 @@ def test_training(): | |||||
# Config Loader | # Config Loader | ||||
trainer_args = ConfigSection() | trainer_args = ConfigSection() | ||||
model_args = ConfigSection() | model_args = ConfigSection() | ||||
ConfigLoader("_").load_config(config_dir, { | |||||
ConfigLoader().load_config(config_dir, { | |||||
"test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args}) | "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args}) | ||||
# Data Loader | |||||
pos_loader = POSDatasetLoader(data_path) | |||||
train_data = pos_loader.load_lines() | |||||
data_set = SeqLabelDataSet() | |||||
data_set.load(data_path) | |||||
data_train, data_dev = data_set.split(0.3, shuffle=True) | |||||
model_args["vocab_size"] = len(data_set.word_vocab) | |||||
model_args["num_classes"] = len(data_set.label_vocab) | |||||
# Preprocessor | |||||
p = SeqLabelPreprocess() | |||||
data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5) | |||||
model_args["vocab_size"] = p.vocab_size | |||||
model_args["num_classes"] = p.num_classes | |||||
save_pickle(data_set.word_vocab, pickle_path, "word2id.pkl") | |||||
save_pickle(data_set.label_vocab, pickle_path, "label2id.pkl") | |||||
trainer = SeqLabelTrainer( | trainer = SeqLabelTrainer( | ||||
epochs=trainer_args["epochs"], | epochs=trainer_args["epochs"], | ||||
@@ -55,7 +55,7 @@ def test_training(): | |||||
saver = ModelSaver(os.path.join(pickle_path, model_name)) | saver = ModelSaver(os.path.join(pickle_path, model_name)) | ||||
saver.save_pytorch(model) | saver.save_pytorch(model) | ||||
del model, trainer, pos_loader | |||||
del model, trainer | |||||
# Define the same model | # Define the same model | ||||
model = SeqLabeling(model_args) | model = SeqLabeling(model_args) | ||||
@@ -65,21 +65,16 @@ def test_training(): | |||||
# Load test configuration | # Load test configuration | ||||
tester_args = ConfigSection() | tester_args = ConfigSection() | ||||
ConfigLoader("config.cfg").load_config(config_dir, {"test_seq_label_tester": tester_args}) | |||||
ConfigLoader().load_config(config_dir, {"test_seq_label_tester": tester_args}) | |||||
# Tester | # Tester | ||||
tester = SeqLabelTester(save_output=False, | |||||
save_loss=True, | |||||
save_best_dev=False, | |||||
batch_size=4, | |||||
tester = SeqLabelTester(batch_size=4, | |||||
use_cuda=False, | use_cuda=False, | ||||
pickle_path=pickle_path, | pickle_path=pickle_path, | ||||
model_name="seq_label_in_test.pkl", | model_name="seq_label_in_test.pkl", | ||||
print_every_step=1 | |||||
evaluator=SeqLabelEvaluator() | |||||
) | ) | ||||
# Start testing with validation data | # Start testing with validation data | ||||
change_field_is_target(data_dev, "truth", True) | |||||
tester.test(model, data_dev) | tester.test(model, data_dev) | ||||
loss, accuracy = tester.metrics | |||||
assert 0 < accuracy < 1 |
@@ -9,13 +9,14 @@ sys.path.append("..") | |||||
from fastNLP.core.predictor import ClassificationInfer | from fastNLP.core.predictor import ClassificationInfer | ||||
from fastNLP.core.trainer import ClassificationTrainer | from fastNLP.core.trainer import ClassificationTrainer | ||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | ||||
from fastNLP.loader.dataset_loader import ClassDatasetLoader | |||||
from fastNLP.loader.dataset_loader import ClassDataSetLoader | |||||
from fastNLP.loader.model_loader import ModelLoader | from fastNLP.loader.model_loader import ModelLoader | ||||
from fastNLP.core.preprocess import ClassPreprocess | |||||
from fastNLP.models.cnn_text_classification import CNNText | from fastNLP.models.cnn_text_classification import CNNText | ||||
from fastNLP.saver.model_saver import ModelSaver | from fastNLP.saver.model_saver import ModelSaver | ||||
from fastNLP.core.optimizer import Optimizer | from fastNLP.core.optimizer import Optimizer | ||||
from fastNLP.core.loss import Loss | from fastNLP.core.loss import Loss | ||||
from fastNLP.core.dataset import TextClassifyDataSet | |||||
from fastNLP.core.preprocess import save_pickle, load_pickle | |||||
parser = argparse.ArgumentParser() | parser = argparse.ArgumentParser() | ||||
parser.add_argument("-s", "--save", type=str, default="./test_classification/", help="path to save pickle files") | parser.add_argument("-s", "--save", type=str, default="./test_classification/", help="path to save pickle files") | ||||
@@ -34,21 +35,18 @@ config_dir = args.config | |||||
def infer(): | def infer(): | ||||
# load dataset | # load dataset | ||||
print("Loading data...") | print("Loading data...") | ||||
ds_loader = ClassDatasetLoader(train_data_dir) | |||||
data = ds_loader.load() | |||||
unlabeled_data = [x[0] for x in data] | |||||
word_vocab = load_pickle(save_dir, "word2id.pkl") | |||||
label_vocab = load_pickle(save_dir, "label2id.pkl") | |||||
print("vocabulary size:", len(word_vocab)) | |||||
print("number of classes:", len(label_vocab)) | |||||
# pre-process data | |||||
pre = ClassPreprocess() | |||||
data = pre.run(data, pickle_path=save_dir) | |||||
print("vocabulary size:", pre.vocab_size) | |||||
print("number of classes:", pre.num_classes) | |||||
infer_data = TextClassifyDataSet(loader=ClassDataSetLoader()) | |||||
infer_data.load(train_data_dir, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab}) | |||||
model_args = ConfigSection() | model_args = ConfigSection() | ||||
# TODO: load from config file | |||||
model_args["vocab_size"] = pre.vocab_size | |||||
model_args["num_classes"] = pre.num_classes | |||||
# ConfigLoader.load_config(config_dir, {"text_class_model": model_args}) | |||||
model_args["vocab_size"] = len(word_vocab) | |||||
model_args["num_classes"] = len(label_vocab) | |||||
ConfigLoader.load_config(config_dir, {"text_class_model": model_args}) | |||||
# construct model | # construct model | ||||
print("Building model...") | print("Building model...") | ||||
@@ -59,7 +57,7 @@ def infer(): | |||||
print("model loaded!") | print("model loaded!") | ||||
infer = ClassificationInfer(pickle_path=save_dir) | infer = ClassificationInfer(pickle_path=save_dir) | ||||
results = infer.predict(cnn, unlabeled_data) | |||||
results = infer.predict(cnn, infer_data) | |||||
print(results) | print(results) | ||||
@@ -69,32 +67,23 @@ def train(): | |||||
# load dataset | # load dataset | ||||
print("Loading data...") | print("Loading data...") | ||||
ds_loader = ClassDatasetLoader(train_data_dir) | |||||
data = ds_loader.load() | |||||
print(data[0]) | |||||
data = TextClassifyDataSet(loader=ClassDataSetLoader()) | |||||
data.load(train_data_dir) | |||||
# pre-process data | |||||
pre = ClassPreprocess() | |||||
data_train = pre.run(data, pickle_path=save_dir) | |||||
print("vocabulary size:", pre.vocab_size) | |||||
print("number of classes:", pre.num_classes) | |||||
print("vocabulary size:", len(data.word_vocab)) | |||||
print("number of classes:", len(data.label_vocab)) | |||||
save_pickle(data.word_vocab, save_dir, "word2id.pkl") | |||||
save_pickle(data.label_vocab, save_dir, "label2id.pkl") | |||||
model_args["num_classes"] = pre.num_classes | |||||
model_args["vocab_size"] = pre.vocab_size | |||||
model_args["num_classes"] = len(data.label_vocab) | |||||
model_args["vocab_size"] = len(data.word_vocab) | |||||
# construct model | # construct model | ||||
print("Building model...") | print("Building model...") | ||||
model = CNNText(model_args) | model = CNNText(model_args) | ||||
# ConfigSaver().save_config(config_dir, {"text_class_model": model_args}) | |||||
# train | # train | ||||
print("Training...") | print("Training...") | ||||
# 1 | |||||
# trainer = ClassificationTrainer(train_args) | |||||
# 2 | |||||
trainer = ClassificationTrainer(epochs=train_args["epochs"], | trainer = ClassificationTrainer(epochs=train_args["epochs"], | ||||
batch_size=train_args["batch_size"], | batch_size=train_args["batch_size"], | ||||
validate=train_args["validate"], | validate=train_args["validate"], | ||||
@@ -104,7 +93,7 @@ def train(): | |||||
model_name=model_name, | model_name=model_name, | ||||
loss=Loss("cross_entropy"), | loss=Loss("cross_entropy"), | ||||
optimizer=Optimizer("SGD", lr=0.001, momentum=0.9)) | optimizer=Optimizer("SGD", lr=0.001, momentum=0.9)) | ||||
trainer.train(model, data_train) | |||||
trainer.train(model, data) | |||||
print("Training finished!") | print("Training finished!") | ||||
@@ -115,4 +104,4 @@ def train(): | |||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
train() | train() | ||||
# infer() | |||||
infer() |
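The classification scripts follow the same pattern with TextClassifyDataSet: build and pickle the vocabularies during training, then reload them for inference. A condensed sketch with illustrative paths:

    from fastNLP.core.dataset import TextClassifyDataSet
    from fastNLP.core.preprocess import save_pickle, load_pickle
    from fastNLP.loader.dataset_loader import ClassDataSetLoader

    # training side: vocabularies are built while loading the labeled data
    data = TextClassifyDataSet(loader=ClassDataSetLoader())
    data.load("data/text_classify.txt")
    save_pickle(data.word_vocab, "./save/", "word2id.pkl")
    save_pickle(data.label_vocab, "./save/", "label2id.pkl")

    # inference side: reload them and index the (unlabeled) input the same way
    word_vocab = load_pickle("./save/", "word2id.pkl")
    label_vocab = load_pickle("./save/", "label2id.pkl")
    infer_data = TextClassifyDataSet(loader=ClassDataSetLoader())
    infer_data.load("data/text_classify.txt",
                    vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab})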