Browse Source

Merge Preprocessor and DataSet

tags/v0.1.0^2
FengZiYjun 6 years ago
parent
commit
0b86d7cf2b
27 changed files with 514 additions and 420 deletions
  1. +2
    -2
      docs/source/user/quickstart.rst
  2. +2
    -2
      examples/readme_example.py
  3. +4
    -17
      fastNLP/core/batch.py
  4. +176
    -2
      fastNLP/core/dataset.py
  5. +6
    -2
      fastNLP/core/field.py
  6. +5
    -2
      fastNLP/core/instance.py
  7. +12
    -1
      fastNLP/core/loss.py
  8. +45
    -0
      fastNLP/core/metrics.py
  9. +29
    -43
      fastNLP/core/predictor.py
  10. +0
    -1
      fastNLP/core/preprocess.py
  11. +20
    -111
      fastNLP/core/tester.py
  12. +16
    -31
      fastNLP/core/trainer.py
  13. +1
    -1
      fastNLP/core/vocabulary.py
  14. +8
    -13
      fastNLP/loader/base_loader.py
  15. +3
    -3
      fastNLP/loader/config_loader.py
  16. +41
    -42
      fastNLP/loader/dataset_loader.py
  17. +27
    -7
      fastNLP/models/sequence_modeling.py
  18. +6
    -7
      fastNLP/modules/aggregator/self_attention.py
  19. +2
    -2
      reproduction/Char-aware_NLM/main.py
  20. +1
    -1
      reproduction/LSTM+self_attention_sentiment_analysis/main.py
  21. +3
    -3
      reproduction/chinese_word_segment/run.py
  22. +1
    -1
      reproduction/pos_tag_model/train_pos_tag.py
  23. +10
    -10
      test/loader/test_dataset_loader.py
  24. +31
    -36
      test/model/seq_labeling.py
  25. +24
    -25
      test/model/test_cws.py
  26. +16
    -21
      test/model/test_seq_label.py
  27. +23
    -34
      test/model/text_classify.py

+ 2
- 2
docs/source/user/quickstart.rst View File

@@ -18,7 +18,7 @@ pre-processing data, constructing model and training model.
from fastNLP.modules import aggregation
from fastNLP.modules import decoder

from fastNLP.loader.dataset_loader import ClassDatasetLoader
from fastNLP.loader.dataset_loader import ClassDataSetLoader
from fastNLP.loader.preprocess import ClassPreprocess
from fastNLP.core.trainer import ClassificationTrainer
from fastNLP.core.inference import ClassificationInfer
@@ -50,7 +50,7 @@ pre-processing data, constructing model and training model.
train_path = 'test/data_for_tests/text_classify.txt' # training set file

# load dataset
ds_loader = ClassDatasetLoader("train", train_path)
ds_loader = ClassDataSetLoader("train", train_path)
data = ds_loader.load()

# pre-process dataset


+ 2
- 2
examples/readme_example.py View File

@@ -3,7 +3,7 @@ from fastNLP.core.optimizer import Optimizer
from fastNLP.core.predictor import ClassificationInfer
from fastNLP.core.preprocess import ClassPreprocess
from fastNLP.core.trainer import ClassificationTrainer
from fastNLP.loader.dataset_loader import ClassDatasetLoader
from fastNLP.loader.dataset_loader import ClassDataSetLoader
from fastNLP.models.base_model import BaseModel
from fastNLP.modules import aggregator
from fastNLP.modules import decoder
@@ -36,7 +36,7 @@ data_dir = 'save/' # directory to save data and model
train_path = './data_for_tests/text_classify.txt' # training set file

# load dataset
ds_loader = ClassDatasetLoader(train_path)
ds_loader = ClassDataSetLoader()
data = ds_loader.load()

# pre-process dataset


+ 4
- 17
fastNLP/core/batch.py View File

@@ -17,7 +17,7 @@ class Batch(object):
:param dataset: a DataSet object
:param batch_size: int, the size of the batch
:param sampler: a Sampler object
:param use_cuda: bool, whetjher to use GPU
:param use_cuda: bool, whether to use GPU

"""
self.dataset = dataset
@@ -37,15 +37,12 @@ class Batch(object):
"""

:return batch_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length])
batch_x also contains an item (str: list of int) about origin lengths,
which means ("field_name_origin_len": origin lengths).
E.g.
::
{'text': tensor([[ 0, 1, 2, 3, 0, 0, 0], 4, 5, 2, 6, 7, 8, 9]]), 'text_origin_len': [4, 7]})

batch_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length])
All tensors in both batch_x and batch_y will be cuda tensors if use_cuda is True.
The names of fields are defined in preprocessor's convert_to_dataset method.

"""
if self.curidx >= len(self.idx_list):
@@ -54,10 +51,9 @@ class Batch(object):
endidx = min(self.curidx + self.batch_size, len(self.idx_list))
padding_length = {field_name: max(field_length[self.curidx: endidx])
for field_name, field_length in self.lengths.items()}
origin_lengths = {field_name: field_length[self.curidx: endidx]
for field_name, field_length in self.lengths.items()}

batch_x, batch_y = defaultdict(list), defaultdict(list)

# transform index to tensor and do padding for sequences
for idx in range(self.curidx, endidx):
x, y = self.dataset.to_tensor(idx, padding_length)
for name, tensor in x.items():
@@ -65,8 +61,7 @@ class Batch(object):
for name, tensor in y.items():
batch_y[name].append(tensor)

batch_origin_length = {}
# combine instances into a batch
# combine instances to form a batch
for batch in (batch_x, batch_y):
for name, tensor_list in batch.items():
if self.use_cuda:
@@ -74,14 +69,6 @@ class Batch(object):
else:
batch[name] = torch.stack(tensor_list, dim=0)

# add origin lengths in batch_x
for name, tensor in batch_x.items():
if self.use_cuda:
batch_origin_length[name + "_origin_len"] = torch.LongTensor(origin_lengths[name]).cuda()
else:
batch_origin_length[name + "_origin_len"] = torch.LongTensor(origin_lengths[name])
batch_x.update(batch_origin_length)

self.curidx += endidx
return batch_x, batch_y


+ 176
- 2
fastNLP/core/dataset.py View File

@@ -1,7 +1,11 @@
import random
from collections import defaultdict
from copy import deepcopy

from fastNLP.core.field import TextField
from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.loader.dataset_loader import POSDataSetLoader, ClassDataSetLoader


def create_dataset_from_lists(str_lists: list, word_vocab: dict, has_target: bool = False, label_vocab: dict = None):
@@ -65,7 +69,8 @@ class DataSet(list):
"""A DataSet object is a list of Instance objects.

"""
def __init__(self, name="", instances=None):

def __init__(self, name="", instances=None, loader=None):
"""

:param name: str, the name of the dataset. (default: "")
@@ -76,6 +81,7 @@ class DataSet(list):
self.name = name
if instances is not None:
self.extend(instances)
self.dataset_loader = loader

def index_all(self, vocab):
for ins in self:
@@ -109,3 +115,171 @@ class DataSet(list):
for field_name, field_length in ins.get_length().items():
lengths[field_name].append(field_length)
return lengths

def convert(self, data):
"""Convert lists of strings into Instances with Fields"""
raise NotImplementedError

def convert_with_vocabs(self, data, vocabs):
"""Convert lists of strings into Instances with Fields, using existing Vocabulary. Useful in predicting."""
raise NotImplementedError

def convert_for_infer(self, data, vocabs):
"""Convert lists of strings into Instances with Fields."""

def load(self, data_path, vocabs=None, infer=False):
"""Load data from the given files.

:param data_path: str, the path to the data
:param infer: bool. If True, there is no label information in the data. Default: False.
:param vocabs: dict of (name: Vocabulary object), used to index data. If not provided, a new vocabulary will be constructed.

"""
raw_data = self.dataset_loader.load(data_path)
if infer is True:
self.convert_for_infer(raw_data, vocabs)
else:
if vocabs is not None:
self.convert_with_vocabs(raw_data, vocabs)
else:
self.convert(raw_data)

def split(self, ratio, shuffle=True):
"""Train/dev splitting

:param ratio: float, between 0 and 1. The ratio of development set in origin data set.
:param shuffle: bool, whether shuffle the data set before splitting. Default: True.
:return train_set: a DataSet object, representing the training set
dev_set: a DataSet object, representing the validation set

"""
assert 0 < ratio < 1
if shuffle:
random.shuffle(self)
split_idx = int(len(self) * ratio)
dev_set = deepcopy(self)
train_set = deepcopy(self)
del train_set[:split_idx]
del dev_set[split_idx:]
return train_set, dev_set


class SeqLabelDataSet(DataSet):
def __init__(self, instances=None, loader=POSDataSetLoader()):
super(SeqLabelDataSet, self).__init__(name="", instances=instances, loader=loader)
self.word_vocab = Vocabulary()
self.label_vocab = Vocabulary()

def convert(self, data):
"""Convert lists of strings into Instances with Fields.

:param data: 3-level lists. Entries are strings.
"""
for example in data:
word_seq, label_seq = example[0], example[1]
# list, list
self.word_vocab.update(word_seq)
self.label_vocab.update(label_seq)
x = TextField(word_seq, is_target=False)
x_len = LabelField(len(word_seq), is_target=False)
y = TextField(label_seq, is_target=False)
instance = Instance()
instance.add_field("word_seq", x)
instance.add_field("truth", y)
instance.add_field("word_seq_origin_len", x_len)
self.append(instance)
self.index_field("word_seq", self.word_vocab)
self.index_field("truth", self.label_vocab)
# no need to index "word_seq_origin_len"

def convert_with_vocabs(self, data, vocabs):
for example in data:
word_seq, label_seq = example[0], example[1]
# list, list
x = TextField(word_seq, is_target=False)
x_len = LabelField(len(word_seq), is_target=False)
y = TextField(label_seq, is_target=False)
instance = Instance()
instance.add_field("word_seq", x)
instance.add_field("truth", y)
instance.add_field("word_seq_origin_len", x_len)
self.append(instance)
self.index_field("word_seq", vocabs["word_vocab"])
self.index_field("truth", vocabs["label_vocab"])
# no need to index "word_seq_origin_len"

def convert_for_infer(self, data, vocabs):
for word_seq in data:
# list
x = TextField(word_seq, is_target=False)
x_len = LabelField(len(word_seq), is_target=False)
instance = Instance()
instance.add_field("word_seq", x)
instance.add_field("word_seq_origin_len", x_len)
self.append(instance)
self.index_field("word_seq", vocabs["word_vocab"])
# no need to index "word_seq_origin_len"


class TextClassifyDataSet(DataSet):
def __init__(self, instances=None, loader=ClassDataSetLoader()):
super(TextClassifyDataSet, self).__init__(name="", instances=instances, loader=loader)
self.word_vocab = Vocabulary()
self.label_vocab = Vocabulary(need_default=False)

def convert(self, data):
for example in data:
word_seq, label = example[0], example[1]
# list, str
self.word_vocab.update(word_seq)
self.label_vocab.update(label)
x = TextField(word_seq, is_target=False)
y = LabelField(label, is_target=True)
instance = Instance()
instance.add_field("word_seq", x)
instance.add_field("label", y)
self.append(instance)
self.index_field("word_seq", self.word_vocab)
self.index_field("label", self.label_vocab)

def convert_with_vocabs(self, data, vocabs):
for example in data:
word_seq, label = example[0], example[1]
# list, str
x = TextField(word_seq, is_target=False)
y = LabelField(label, is_target=True)
instance = Instance()
instance.add_field("word_seq", x)
instance.add_field("label", y)
self.append(instance)
self.index_field("word_seq", vocabs["word_vocab"])
self.index_field("label", vocabs["label_vocab"])

def convert_for_infer(self, data, vocabs):
for word_seq in data:
# list
x = TextField(word_seq, is_target=False)
instance = Instance()
instance.add_field("word_seq", x)
self.append(instance)
self.index_field("word_seq", vocabs["word_vocab"])


def change_field_is_target(data_set, field_name, new_target):
"""Change the flag of is_target in a field.

:param data_set: a DataSet object
:param field_name: str, the name of the field
:param new_target: one of (True, False, None), representing this field is batch_x / is batch_y / neither.

"""
for inst in data_set:
inst.fields[field_name].is_target = new_target


if __name__ == "__main__":
data_set = SeqLabelDataSet()
data_set.load("../../test/data_for_tests/people.txt")
a, b = data_set.split(0.3)
print(type(data_set), type(a), type(b))
print(len(data_set), len(a), len(b))

+ 6
- 2
fastNLP/core/field.py View File

@@ -59,6 +59,9 @@ class TextField(Field):


class LabelField(Field):
"""The Field representing a single label. Can be a string or integer.

"""
def __init__(self, label, is_target=True):
super(LabelField, self).__init__(is_target)
self.label = label
@@ -73,13 +76,14 @@ class LabelField(Field):

def index(self, vocab):
if self._index is None:
self._index = vocab[self.label]
if isinstance(self.label, str):
self._index = vocab[self.label]
return self._index

def to_tensor(self, padding_length):
if self._index is None:
if isinstance(self.label, int):
return torch.LongTensor([self.label])
return torch.tensor(self.label)
elif isinstance(self.label, str):
raise RuntimeError("Field {} not indexed. Call index method.".format(self.label))
else:


+ 5
- 2
fastNLP/core/instance.py View File

@@ -46,8 +46,11 @@ class Instance(object):
tensor_x = {}
tensor_y = {}
for name, field in self.fields.items():
if field.is_target:
if field.is_target is True:
tensor_y[name] = field.to_tensor(padding_length[name])
else:
elif field.is_target is False:
tensor_x[name] = field.to_tensor(padding_length[name])
else:
# is_target is None
continue
return tensor_x, tensor_y

+ 12
- 1
fastNLP/core/loss.py View File

@@ -39,8 +39,19 @@ class Loss(object):

:return loss: a PyTorch loss
"""

class InnerCrossEntropy:
"""A simple wrapper to guarantee input shapes."""

def __init__(self):
self.f = torch.nn.CrossEntropyLoss()

def __call__(self, predict, truth):
truth = truth.view(-1, )
return self.f(predict, truth)

if loss_name == "cross_entropy":
return torch.nn.CrossEntropyLoss()
return InnerCrossEntropy()
elif loss_name == 'nll':
return torch.nn.NLLLoss()
else:


+ 45
- 0
fastNLP/core/metrics.py View File

@@ -4,6 +4,51 @@ import numpy as np
import torch


class Evaluator(object):
def __init__(self):
pass

def __call__(self, predict, truth):
"""

:param predict: list of tensors, the network outputs from all batches.
:param truth: list of dict, the ground truths from all batch_y.
:return:
"""
raise NotImplementedError


class ClassifyEvaluator(Evaluator):
def __init__(self):
super(ClassifyEvaluator, self).__init__()

def __call__(self, predict, truth):
y_prob = [torch.nn.functional.softmax(y_logit, dim=-1) for y_logit in predict]
y_prob = torch.cat(y_prob, dim=0)
y_pred = torch.argmax(y_prob, dim=-1)
y_true = torch.cat(truth, dim=0)
acc = float(torch.sum(y_pred == y_true)) / len(y_true)
return {"accuracy": acc}


class SeqLabelEvaluator(Evaluator):
def __init__(self):
super(SeqLabelEvaluator, self).__init__()

def __call__(self, predict, truth):
"""

:param predict: list of tensors, the network outputs from all batches.
:param truth: list of dict, the ground truths from all batch_y.
:return accuracy:
"""
truth = [item["truth"] for item in truth]
truth = torch.cat(truth).view(-1, )
results = torch.Tensor(predict).view(-1, )
accuracy = torch.sum(results.to(truth) == truth).to(torch.float) / results.shape[0]
return {"accuracy": float(accuracy)}


def _conver_numpy(x):
"""convert input data to numpy array



+ 29
- 43
fastNLP/core/predictor.py View File

@@ -16,18 +16,18 @@ class Predictor(object):
Currently, Predictor does not support GPU.
"""

def __init__(self, pickle_path, task):
def __init__(self, pickle_path, post_processor):
"""

:param pickle_path: str, the path to the pickle files.
:param task: str, specify which task the predictor will perform. One of ("seq_label", "text_classify").
:param post_processor: a function or callable object, that takes list of batch outputs as input

"""
self.batch_size = 1
self.batch_output = []
self.pickle_path = pickle_path
self._task = task # one of ("seq_label", "text_classify")
self.label_vocab = load_pickle(self.pickle_path, "class2id.pkl")
self._post_processor = post_processor
self.label_vocab = load_pickle(self.pickle_path, "label2id.pkl")
self.word_vocab = load_pickle(self.pickle_path, "word2id.pkl")

def predict(self, network, data):
@@ -38,21 +38,20 @@ class Predictor(object):
:return: list of list of strings, [num_examples, tag_seq_length]
"""
# transform strings into DataSet object
data = self.prepare_input(data)
# data = self.prepare_input(data)

# turn on the testing mode; clean up the history
self.mode(network, test=True)
self.batch_output.clear()
batch_output = []

data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), use_cuda=False)

for batch_x, _ in data_iterator:
with torch.no_grad():
prediction = self.data_forward(network, batch_x)
batch_output.append(prediction)

self.batch_output.append(prediction)

return self.prepare_output(self.batch_output)
return self._post_processor(batch_output, self.label_vocab)

def mode(self, network, test=True):
if test:
@@ -62,13 +61,7 @@ class Predictor(object):

def data_forward(self, network, x):
"""Forward through network."""
if self._task == "seq_label":
y = network(x["word_seq"], x["word_seq_origin_len"])
y = network.prediction(y)
elif self._task == "text_classify":
y = network(x["word_seq"])
else:
raise NotImplementedError("Unknown task type {}.".format(self._task))
y = network(**x)
return y

def prepare_input(self, data):
@@ -88,39 +81,32 @@ class Predictor(object):
assert isinstance(data, list)
return create_dataset_from_lists(data, self.word_vocab, has_target=False)

def prepare_output(self, data):
"""Transform list of batch outputs into strings."""
if self._task == "seq_label":
return self._seq_label_prepare_output(data)
elif self._task == "text_classify":
return self._text_classify_prepare_output(data)
else:
raise NotImplementedError("Unknown task type {}".format(self._task))

def _seq_label_prepare_output(self, batch_outputs):
results = []
for batch in batch_outputs:
for example in np.array(batch):
results.append([self.label_vocab.to_word(int(x)) for x in example])
return results

def _text_classify_prepare_output(self, batch_outputs):
results = []
for batch_out in batch_outputs:
idx = np.argmax(batch_out.detach().numpy(), axis=-1)
results.extend([self.label_vocab.to_word(i) for i in idx])
return results


class SeqLabelInfer(Predictor):
def __init__(self, pickle_path):
print(
"[FastNLP Warning] SeqLabelInfer will be deprecated. Please use Predictor with argument 'task'='seq_label'.")
super(SeqLabelInfer, self).__init__(pickle_path, "seq_label")
"[FastNLP Warning] SeqLabelInfer will be deprecated. Please use Predictor directly.")
super(SeqLabelInfer, self).__init__(pickle_path, seq_label_post_processor)


class ClassificationInfer(Predictor):
def __init__(self, pickle_path):
print(
"[FastNLP Warning] ClassificationInfer will be deprecated. Please use Predictor with argument 'task'='text_classify'.")
super(ClassificationInfer, self).__init__(pickle_path, "text_classify")
"[FastNLP Warning] ClassificationInfer will be deprecated. Please use Predictor directly.")
super(ClassificationInfer, self).__init__(pickle_path, text_classify_post_processor)


def seq_label_post_processor(batch_outputs, label_vocab):
results = []
for batch in batch_outputs:
for example in np.array(batch):
results.append([label_vocab.to_word(int(x)) for x in example])
return results


def text_classify_post_processor(batch_outputs, label_vocab):
results = []
for batch_out in batch_outputs:
idx = np.argmax(batch_out.detach().numpy(), axis=-1)
results.extend([label_vocab.to_word(i) for i in idx])
return results

+ 0
- 1
fastNLP/core/preprocess.py View File

@@ -114,7 +114,6 @@ class Preprocessor(object):
If train_dev_split > 0, return one more dataset - the dev set. If cross_val is True, each dataset
is a list of DataSet objects; Otherwise, each dataset is a DataSet object.
"""

if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"):
self.data_vocab = load_pickle(pickle_path, "word2id.pkl")
self.label_vocab = load_pickle(pickle_path, "class2id.pkl")


+ 20
- 111
fastNLP/core/tester.py View File

@@ -1,7 +1,7 @@
import numpy as np
import torch

from fastNLP.core.batch import Batch
from fastNLP.core.metrics import Evaluator
from fastNLP.core.sampler import RandomSampler
from fastNLP.saver.logger import create_logger

@@ -22,28 +22,23 @@ class Tester(object):
"kwargs" must have the same type as "default_args" on corresponding keys.
Otherwise, error will raise.
"""
default_args = {"save_output": True, # collect outputs of validation set
"save_loss": True, # collect losses in validation
"save_best_dev": False, # save best model during validation
"batch_size": 8,
default_args = {"batch_size": 8,
"use_cuda": False,
"pickle_path": "./save/",
"model_name": "dev_best_model.pkl",
"print_every_step": 1,
"evaluator": Evaluator()
}
"""
"required_args" is the collection of arguments that users must pass to Trainer explicitly.
This is used to warn users of essential settings in the training.
Specially, "required_args" does not have default value, so they have nothing to do with "default_args".
"""
required_args = {"task" # one of ("seq_label", "text_classify")
}
required_args = {}

for req_key in required_args:
if req_key not in kwargs:
logger.error("Tester lacks argument {}".format(req_key))
raise ValueError("Tester lacks argument {}".format(req_key))
self._task = kwargs["task"]

for key in default_args:
if key in kwargs:
@@ -59,17 +54,13 @@ class Tester(object):
pass
print(default_args)

self.save_output = default_args["save_output"]
self.save_best_dev = default_args["save_best_dev"]
self.save_loss = default_args["save_loss"]
self.batch_size = default_args["batch_size"]
self.pickle_path = default_args["pickle_path"]
self.use_cuda = default_args["use_cuda"]
self.print_every_step = default_args["print_every_step"]
self._evaluator = default_args["evaluator"]

self._model = None
self.eval_history = [] # evaluation results of all batches
self.batch_output = [] # outputs of all batches

def test(self, network, dev_data):
if torch.cuda.is_available() and self.use_cuda:
@@ -80,26 +71,18 @@ class Tester(object):
# turn on the testing mode; clean up the history
self.mode(network, is_test=True)
self.eval_history.clear()
self.batch_output.clear()
output_list = []
truth_list = []

data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda)
step = 0

for batch_x, batch_y in data_iterator:
with torch.no_grad():
prediction = self.data_forward(network, batch_x)
eval_results = self.evaluate(prediction, batch_y)

if self.save_output:
self.batch_output.append(prediction)
if self.save_loss:
self.eval_history.append(eval_results)

print_output = "[test step {}] {}".format(step, eval_results)
logger.info(print_output)
if self.print_every_step > 0 and step % self.print_every_step == 0:
print(self.make_eval_output(prediction, eval_results))
step += 1
output_list.append(prediction)
truth_list.append(batch_y)
eval_results = self.evaluate(output_list, truth_list)
print("[tester] {}".format(self.print_eval_results(eval_results)))

def mode(self, model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently.
@@ -121,104 +104,30 @@ class Tester(object):
def evaluate(self, predict, truth):
"""Compute evaluation metrics.

:param predict: Tensor
:param truth: Tensor
:param predict: list of Tensor
:param truth: list of dict
:return eval_results: can be anything. It will be stored in self.eval_history
"""
if "label_seq" in truth:
truth = truth["label_seq"]
elif "label" in truth:
truth = truth["label"]
else:
raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys()))
return self._evaluator(predict, truth)

if self._task == "seq_label":
return self._seq_label_evaluate(predict, truth)
elif self._task == "text_classify":
return self._text_classify_evaluate(predict, truth)
else:
raise NotImplementedError("Unknown task type {}.".format(self._task))

def _seq_label_evaluate(self, predict, truth):
batch_size, max_len = predict.size(0), predict.size(1)
loss = self._model.loss(predict, truth) / batch_size
prediction = self._model.prediction(predict)
# pad prediction to equal length
for pred in prediction:
if len(pred) < max_len:
pred += [0] * (max_len - len(pred))
results = torch.Tensor(prediction).view(-1, )

# make sure "results" is in the same device as "truth"
results = results.to(truth)
accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0]
return [float(loss), float(accuracy)]

def _text_classify_evaluate(self, y_logit, y_true):
y_prob = torch.nn.functional.softmax(y_logit, dim=-1)
return [y_prob, y_true]

@property
def metrics(self):
"""Compute and return metrics.
Use self.eval_history to compute metrics over the whole dev set.
Please refer to metrics.py for common metric functions.

:return : variable number of outputs
"""
if self._task == "seq_label":
return self._seq_label_metrics
elif self._task == "text_classify":
return self._text_classify_metrics
else:
raise NotImplementedError("Unknown task type {}.".format(self._task))

@property
def _seq_label_metrics(self):
batch_loss = np.mean([x[0] for x in self.eval_history])
batch_accuracy = np.mean([x[1] for x in self.eval_history])
return batch_loss, batch_accuracy

@property
def _text_classify_metrics(self):
y_prob, y_true = zip(*self.eval_history)
y_prob = torch.cat(y_prob, dim=0)
y_pred = torch.argmax(y_prob, dim=-1)
y_true = torch.cat(y_true, dim=0)
acc = float(torch.sum(y_pred == y_true)) / len(y_true)
return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc

def show_metrics(self):
"""Customize evaluation outputs in Trainer.
Called by Trainer to print evaluation results on dev set during training.
Use self.metrics to fetch available metrics.

:return print_str: str
"""
loss, accuracy = self.metrics
return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy)
def print_eval_results(self, results):
"""Override this method to support more print formats.

def make_eval_output(self, predictions, eval_results):
"""Customize Tester outputs.
:param results: dict, (str: float) is (metrics name: value)

:param predictions: Tensor
:param eval_results: Tensor
:return: str, to be printed.
"""
return self.show_metrics()
return ", ".join([str(key) + "=" + str(value) for key, value in results.items()])


class SeqLabelTester(Tester):
def __init__(self, **test_args):
test_args.update({"task": "seq_label"})
print(
"[FastNLP Warning] SeqLabelTester will be deprecated. Please use Tester with argument 'task'='seq_label'.")
"[FastNLP Warning] SeqLabelTester will be deprecated. Please use Tester directly.")
super(SeqLabelTester, self).__init__(**test_args)


class ClassificationTester(Tester):
def __init__(self, **test_args):
test_args.update({"task": "text_classify"})
print(
"[FastNLP Warning] ClassificationTester will be deprecated. Please use Tester with argument 'task'='text_classify'.")
"[FastNLP Warning] ClassificationTester will be deprecated. Please use Tester directly.")
super(ClassificationTester, self).__init__(**test_args)

+ 16
- 31
fastNLP/core/trainer.py View File

@@ -8,6 +8,7 @@ from tensorboardX import SummaryWriter

from fastNLP.core.batch import Batch
from fastNLP.core.loss import Loss
from fastNLP.core.metrics import Evaluator
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.tester import SeqLabelTester, ClassificationTester
@@ -43,21 +44,20 @@ class Trainer(object):
default_args = {"epochs": 1, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/",
"save_best_dev": False, "model_name": "default_model_name.pkl", "print_every_step": 1,
"loss": Loss(None), # used to pass type check
"optimizer": Optimizer("Adam", lr=0.001, weight_decay=0)
"optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
"evaluator": Evaluator()
}
"""
"required_args" is the collection of arguments that users must pass to Trainer explicitly.
This is used to warn users of essential settings in the training.
Specially, "required_args" does not have default value, so they have nothing to do with "default_args".
"""
required_args = {"task" # one of ("seq_label", "text_classify")
}
required_args = {}

for req_key in required_args:
if req_key not in kwargs:
logger.error("Trainer lacks argument {}".format(req_key))
raise ValueError("Trainer lacks argument {}".format(req_key))
self._task = kwargs["task"]

for key in default_args:
if key in kwargs:
@@ -86,6 +86,7 @@ class Trainer(object):
self._loss_func = default_args["loss"].get() # return a pytorch loss function or None
self._optimizer = None
self._optimizer_proto = default_args["optimizer"]
self._evaluator = default_args["evaluator"]
self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs')
self._graph_summaried = False
self._best_accuracy = 0.0
@@ -106,9 +107,8 @@ class Trainer(object):

# define Tester over dev data
if self.validate:
default_valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
"save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path,
"use_cuda": self.use_cuda, "print_every_step": 0}
default_valid_args = {"batch_size": self.batch_size, "pickle_path": self.pickle_path,
"use_cuda": self.use_cuda, "evaluator": self._evaluator}
validator = self._create_validator(default_valid_args)
logger.info("validator defined as {}".format(str(validator)))

@@ -229,18 +229,9 @@ class Trainer(object):
self._optimizer.step()

def data_forward(self, network, x):
if self._task == "seq_label":
y = network(x["word_seq"], x["word_seq_origin_len"])
elif self._task == "text_classify" or self._task == "language_model":
y = network(x["word_seq"])
else:
raise NotImplementedError("Unknown task type {}.".format(self._task))

y = network(**x)
if not self._graph_summaried:
if self._task == "seq_label":
self._summary_writer.add_graph(network, (x["word_seq"], x["word_seq_origin_len"]), verbose=False)
elif self._task == "text_classify" or self._task == "language_model":
self._summary_writer.add_graph(network, x["word_seq"], verbose=False)
# self._summary_writer.add_graph(network, x, verbose=False)
self._graph_summaried = True
return y

@@ -261,13 +252,9 @@ class Trainer(object):
:param truth: ground truth label vector
:return: a scalar
"""
if "label_seq" in truth:
truth = truth["label_seq"]
elif "label" in truth:
truth = truth["label"]
truth = truth.view((-1,))
else:
raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys()))
if len(truth) > 1:
raise NotImplementedError("Not ready to handle multi-labels.")
truth = list(truth.values())[0] if len(truth) > 0 else None
return self._loss_func(predict, truth)

def define_loss(self):
@@ -278,8 +265,8 @@ class Trainer(object):
These two losses cannot be defined at the same time.
Trainer does not handle loss definition or choose default losses.
"""
if hasattr(self._model, "loss") and self._loss_func is not None:
raise ValueError("Both the model and Trainer define loss. Please take out your loss.")
# if hasattr(self._model, "loss") and self._loss_func is not None:
# raise ValueError("Both the model and Trainer define loss. Please take out your loss.")

if hasattr(self._model, "loss"):
self._loss_func = self._model.loss
@@ -322,9 +309,8 @@ class SeqLabelTrainer(Trainer):

"""
def __init__(self, **kwargs):
kwargs.update({"task": "seq_label"})
print(
"[FastNLP Warning] SeqLabelTrainer will be deprecated. Please use Trainer with argument 'task'='seq_label'.")
"[FastNLP Warning] SeqLabelTrainer will be deprecated. Please use Trainer directly.")
super(SeqLabelTrainer, self).__init__(**kwargs)

def _create_validator(self, valid_args):
@@ -335,9 +321,8 @@ class ClassificationTrainer(Trainer):
"""Trainer for text classification."""

def __init__(self, **train_args):
train_args.update({"task": "text_classify"})
print(
"[FastNLP Warning] ClassificationTrainer will be deprecated. Please use Trainer with argument 'task'='text_classify'.")
"[FastNLP Warning] ClassificationTrainer will be deprecated. Please use Trainer directly.")
super(ClassificationTrainer, self).__init__(**train_args)

def _create_validator(self, valid_args):


+ 1
- 1
fastNLP/core/vocabulary.py View File

@@ -54,7 +54,7 @@ class Vocabulary(object):
def update(self, word):
"""add word or list of words into Vocabulary
:param word: a list of str or str
:param word: a list of string or a single string
"""
if not isinstance(word, str) and isiterable(word):
# it's a nested list


+ 8
- 13
fastNLP/loader/base_loader.py View File

@@ -1,23 +1,18 @@
class BaseLoader(object):
"""docstring for BaseLoader"""

def __init__(self, data_path):
def __init__(self):
super(BaseLoader, self).__init__()
self.data_path = data_path

def load(self):
"""
:return: string
"""
with open(self.data_path, "r", encoding="utf-8") as f:
text = f.read()
return text

def load_lines(self):
with open(self.data_path, "r", encoding="utf=8") as f:
def load_lines(self, data_path):
with open(data_path, "r", encoding="utf=8") as f:
text = f.readlines()
return [line.strip() for line in text]

def load(self, data_path):
with open(data_path, "r", encoding="utf-8") as f:
text = f.readlines()
return [[word for word in sent.strip()] for sent in text]


class ToyLoader0(BaseLoader):
"""


+ 3
- 3
fastNLP/loader/config_loader.py View File

@@ -8,9 +8,9 @@ from fastNLP.loader.base_loader import BaseLoader
class ConfigLoader(BaseLoader):
"""loader for configuration files"""

def __int__(self, data_name, data_path):
super(ConfigLoader, self).__init__(data_path)
self.config = self.parse(super(ConfigLoader, self).load())
def __int__(self, data_path):
super(ConfigLoader, self).__init__()
self.config = self.parse(super(ConfigLoader, self).load(data_path))

@staticmethod
def parse(string):


+ 41
- 42
fastNLP/loader/dataset_loader.py View File

@@ -3,14 +3,17 @@ import os
from fastNLP.loader.base_loader import BaseLoader


class DatasetLoader(BaseLoader):
class DataSetLoader(BaseLoader):
""""loader for data sets"""

def __init__(self, data_path):
super(DatasetLoader, self).__init__(data_path)
def __init__(self):
super(DataSetLoader, self).__init__()

def load(self, path):
raise NotImplementedError

class POSDatasetLoader(DatasetLoader):

class POSDataSetLoader(DataSetLoader):
"""Dataset Loader for POS Tag datasets.

In these datasets, each line are divided by '\t'
@@ -31,16 +34,10 @@ class POSDatasetLoader(DatasetLoader):
to label5.
"""

def __init__(self, data_path):
super(POSDatasetLoader, self).__init__(data_path)

def load(self):
assert os.path.exists(self.data_path)
with open(self.data_path, "r", encoding="utf-8") as f:
line = f.read()
return line
def __init__(self):
super(POSDataSetLoader, self).__init__()

def load_lines(self):
def load(self, data_path):
"""
:return data: three-level list
[
@@ -49,7 +46,7 @@ class POSDatasetLoader(DatasetLoader):
...
]
"""
with open(self.data_path, "r", encoding="utf-8") as f:
with open(data_path, "r", encoding="utf-8") as f:
lines = f.readlines()
return self.parse(lines)

@@ -79,15 +76,15 @@ class POSDatasetLoader(DatasetLoader):
return data


class TokenizeDatasetLoader(DatasetLoader):
class TokenizeDataSetLoader(DataSetLoader):
"""
Data set loader for tokenization data sets
"""

def __init__(self, data_path):
super(TokenizeDatasetLoader, self).__init__(data_path)
def __init__(self):
super(TokenizeDataSetLoader, self).__init__()

def load_pku(self, max_seq_len=32):
def load(self, data_path, max_seq_len=32):
"""
load pku dataset for Chinese word segmentation
CWS (Chinese Word Segmentation) pku training dataset format:
@@ -104,7 +101,7 @@ class TokenizeDatasetLoader(DatasetLoader):
:return: three-level lists
"""
assert isinstance(max_seq_len, int) and max_seq_len > 0
with open(self.data_path, "r", encoding="utf-8") as f:
with open(data_path, "r", encoding="utf-8") as f:
sentences = f.readlines()
data = []
for sent in sentences:
@@ -135,15 +132,15 @@ class TokenizeDatasetLoader(DatasetLoader):
return data


class ClassDatasetLoader(DatasetLoader):
class ClassDataSetLoader(DataSetLoader):
"""Loader for classification data sets"""

def __init__(self, data_path):
super(ClassDatasetLoader, self).__init__(data_path)
def __init__(self):
super(ClassDataSetLoader, self).__init__()

def load(self):
assert os.path.exists(self.data_path)
with open(self.data_path, "r", encoding="utf-8") as f:
def load(self, data_path):
assert os.path.exists(data_path)
with open(data_path, "r", encoding="utf-8") as f:
lines = f.readlines()
return self.parse(lines)

@@ -169,21 +166,21 @@ class ClassDatasetLoader(DatasetLoader):
return dataset


class ConllLoader(DatasetLoader):
class ConllLoader(DataSetLoader):
"""loader for conll format files"""

def __int__(self, data_path):
"""
:param str data_path: the path to the conll data set
"""
super(ConllLoader, self).__init__(data_path)
self.data_set = self.parse(self.load())
super(ConllLoader, self).__init__()
self.data_set = self.parse(self.load(data_path))

def load(self):
def load(self, data_path):
"""
:return: list lines: all lines in a conll file
"""
with open(self.data_path, "r", encoding="utf-8") as f:
with open(data_path, "r", encoding="utf-8") as f:
lines = f.readlines()
return lines

@@ -207,20 +204,21 @@ class ConllLoader(DatasetLoader):
return sentences


class LMDatasetLoader(DatasetLoader):
class LMDataSetLoader(DataSetLoader):
"""Language Model Dataset Loader

This loader produces data for language model training in a supervised way.
That means it has X and Y.

"""
def __init__(self, data_path):
super(LMDatasetLoader, self).__init__(data_path)

def load(self):
if not os.path.exists(self.data_path):
raise FileNotFoundError("file {} not found.".format(self.data_path))
with open(self.data_path, "r", encoding="utf=8") as f:
def __init__(self):
super(LMDataSetLoader, self).__init__()

def load(self, data_path):
if not os.path.exists(data_path):
raise FileNotFoundError("file {} not found.".format(data_path))
with open(data_path, "r", encoding="utf=8") as f:
text = " ".join(f.readlines())
tokens = text.strip().split()
return self.sentence_cut(tokens)
@@ -237,16 +235,17 @@ class LMDatasetLoader(DatasetLoader):
data_set.append([x, y])
return data_set

class PeopleDailyCorpusLoader(DatasetLoader):

class PeopleDailyCorpusLoader(DataSetLoader):
"""
People Daily Corpus: Chinese word segmentation, POS tag, NER
"""

def __init__(self, data_path):
super(PeopleDailyCorpusLoader, self).__init__(data_path)
def __init__(self):
super(PeopleDailyCorpusLoader, self).__init__()

def load(self):
with open(self.data_path, "r", encoding="utf-8") as f:
def load(self, data_path):
with open(data_path, "r", encoding="utf-8") as f:
sents = f.readlines()

pos_tag_examples = []


+ 27
- 7
fastNLP/models/sequence_modeling.py View File

@@ -36,11 +36,13 @@ class SeqLabeling(BaseModel):
self.Crf = decoder.CRF.ConditionalRandomField(num_classes)
self.mask = None

def forward(self, word_seq, word_seq_origin_len):
def forward(self, word_seq, word_seq_origin_len, truth=None):
"""
:param word_seq: LongTensor, [batch_size, mex_len]
:param word_seq_origin_len: LongTensor, [batch_size,], the origin lengths of the sequences.
:return y: [batch_size, mex_len, tag_size]
:param truth: LongTensor, [batch_size, max_len]
:return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting.
If truth is not None, return loss, a scalar. Used in training.
"""
self.mask = self.make_mask(word_seq, word_seq_origin_len)

@@ -50,9 +52,16 @@ class SeqLabeling(BaseModel):
# [batch_size, max_len, hidden_size * direction]
x = self.Linear(x)
# [batch_size, max_len, num_classes]
return x
if truth is not None:
return self._internal_loss(x, truth)
else:
return self.decode(x)

def loss(self, x, y):
""" Since the loss has been computed in forward(), this function simply returns x."""
return x

def _internal_loss(self, x, y):
"""
Negative log likelihood loss.
:param x: Tensor, [batch_size, max_len, tag_size]
@@ -74,12 +83,19 @@ class SeqLabeling(BaseModel):
mask = mask.to(x)
return mask

def prediction(self, x):
def decode(self, x, pad=True):
"""
:param x: FloatTensor, [batch_size, max_len, tag_size]
:param pad: pad the output sequence to equal lengths
:return prediction: list of [decode path(list)]
"""
max_len = x.shape[1]
tag_seq = self.Crf.viterbi_decode(x, self.mask)
# pad prediction to equal length
if pad is True:
for pred in tag_seq:
if len(pred) < max_len:
pred += [0] * (max_len - len(pred))
return tag_seq


@@ -106,11 +122,12 @@ class AdvSeqLabel(SeqLabeling):

self.Crf = decoder.CRF.ConditionalRandomField(num_classes)

def forward(self, word_seq, word_seq_origin_len):
def forward(self, word_seq, word_seq_origin_len, truth=None):
"""
:param word_seq: LongTensor, [batch_size, mex_len]
:param word_seq_origin_len: list of int.
:return y: [batch_size, mex_len, tag_size]
:param truth: LongTensor, [batch_size, max_len]
:return y:
"""
self.mask = self.make_mask(word_seq, word_seq_origin_len)

@@ -129,4 +146,7 @@ class AdvSeqLabel(SeqLabeling):
x = self.Linear2(x)
x = x.view(batch_size, max_len, -1)
# [batch_size, max_len, num_classes]
return x
if truth is not None:
return self._internal_loss(x, truth)
else:
return self.decode(x)

+ 6
- 7
fastNLP/modules/aggregator/self_attention.py View File

@@ -55,14 +55,13 @@ class SelfAttention(nn.Module):
input = input.contiguous()
size = input.size() # [bsz, len, nhid]


input_origin = input_origin.expand(self.attention_hops, -1, -1) # [hops,baz, len]
input_origin = input_origin.transpose(0, 1).contiguous() # [baz, hops,len]
input_origin = input_origin.transpose(0, 1).contiguous() # [baz, hops,len]

y1 = self.tanh(self.ws1(self.drop(input))) # [baz,len,dim] -->[bsz,len, attention-unit]
attention = self.ws2(y1).transpose(1,2).contiguous() # [bsz,len, attention-unit]--> [bsz, len, hop]--> [baz,hop,len]
y1 = self.tanh(self.ws1(self.drop(input))) # [baz,len,dim] -->[bsz,len, attention-unit]
attention = self.ws2(y1).transpose(1,
2).contiguous() # [bsz,len, attention-unit]--> [bsz, len, hop]--> [baz,hop,len]

attention = attention + (-999999 * (input_origin == 0).float()) # remove the weight on padding token.
attention = F.softmax(attention,2) # [baz ,hop, len]
return torch.bmm(attention, input), self.penalization(attention) # output1 --> [baz ,hop ,nhid]

attention = F.softmax(attention, 2) # [baz ,hop, len]
return torch.bmm(attention, input), self.penalization(attention) # output1 --> [baz ,hop ,nhid]

+ 2
- 2
reproduction/Char-aware_NLM/main.py View File

@@ -1,14 +1,14 @@
from fastNLP.core.loss import Loss
from fastNLP.core.preprocess import Preprocessor
from fastNLP.core.trainer import Trainer
from fastNLP.loader.dataset_loader import LMDatasetLoader
from fastNLP.loader.dataset_loader import LMDataSetLoader
from fastNLP.models.char_language_model import CharLM

PICKLE = "./save/"


def train():
loader = LMDatasetLoader("./train.txt")
loader = LMDataSetLoader()
train_data = loader.load()

pre = Preprocessor(label_is_seq=True, share_vocab=True)


+ 1
- 1
reproduction/LSTM+self_attention_sentiment_analysis/main.py View File

@@ -4,7 +4,7 @@ from fastNLP.core.preprocess import ClassPreprocess as Preprocess
from fastNLP.core.trainer import ClassificationTrainer
from fastNLP.loader.config_loader import ConfigLoader
from fastNLP.loader.config_loader import ConfigSection
from fastNLP.loader.dataset_loader import ClassDatasetLoader as Dataset_loader
from fastNLP.loader.dataset_loader import ClassDataSetLoader as Dataset_loader
from fastNLP.models.base_model import BaseModel
from fastNLP.modules.aggregator.self_attention import SelfAttention
from fastNLP.modules.decoder.MLP import MLP


+ 3
- 3
reproduction/chinese_word_segment/run.py View File

@@ -5,7 +5,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader
from fastNLP.loader.dataset_loader import TokenizeDataSetLoader, BaseLoader
from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.loader.model_loader import ModelLoader
@@ -66,8 +66,8 @@ def train():
ConfigLoader("good_path").load_config(cfgfile, {"train": train_args, "test": test_args})

# Data Loader
loader = TokenizeDatasetLoader(cws_data_path)
train_data = loader.load_pku()
loader = TokenizeDataSetLoader()
train_data = loader.load()

# Preprocessor
preprocessor = SeqLabelPreprocess()


+ 1
- 1
reproduction/pos_tag_model/train_pos_tag.py View File

@@ -66,7 +66,7 @@ def train():
ConfigLoader("good_name").load_config(cfgfile, {"train": train_args, "test": test_args})

# Data Loader
loader = PeopleDailyCorpusLoader(pos_tag_data_path)
loader = PeopleDailyCorpusLoader()
train_data, _ = loader.load()

# Preprocessor


+ 10
- 10
test/loader/test_dataset_loader.py View File

@@ -1,6 +1,6 @@
import unittest

from fastNLP.loader.dataset_loader import POSDatasetLoader, LMDatasetLoader, TokenizeDatasetLoader, \
from fastNLP.loader.dataset_loader import POSDataSetLoader, LMDataSetLoader, TokenizeDataSetLoader, \
PeopleDailyCorpusLoader, ConllLoader


@@ -8,29 +8,29 @@ class TestDatasetLoader(unittest.TestCase):
def test_case_1(self):
data = """Tom\tT\nand\tF\nJerry\tT\n.\tF\n\nHello\tT\nworld\tF\n!\tF"""
lines = data.split("\n")
answer = POSDatasetLoader.parse(lines)
answer = POSDataSetLoader.parse(lines)
truth = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]], [["Hello", "world", "!"], ["T", "F", "F"]]]
self.assertListEqual(answer, truth, "POS Dataset Loader")

def test_case_TokenizeDatasetLoader(self):
loader = TokenizeDatasetLoader("./test/data_for_tests/cws_pku_utf_8")
data = loader.load_pku(max_seq_len=32)
print("pass TokenizeDatasetLoader test!")
loader = TokenizeDataSetLoader()
data = loader.load("test/data_for_tests/", max_seq_len=32)
print("pass TokenizeDataSetLoader test!")

def test_case_POSDatasetLoader(self):
loader = POSDatasetLoader("./test/data_for_tests/people.txt")
loader = POSDataSetLoader()
data = loader.load()
datas = loader.load_lines()
print("pass POSDatasetLoader test!")
print("pass POSDataSetLoader test!")

def test_case_LMDatasetLoader(self):
loader = LMDatasetLoader("./test/data_for_tests/cws_pku_utf_8")
loader = LMDataSetLoader()
data = loader.load()
datas = loader.load_lines()
print("pass TokenizeDatasetLoader test!")
print("pass TokenizeDataSetLoader test!")

def test_PeopleDailyCorpusLoader(self):
loader = PeopleDailyCorpusLoader("./test/data_for_tests/people_daily_raw.txt")
loader = PeopleDailyCorpusLoader()
_, _ = loader.load()

def test_ConllLoader(self):


+ 31
- 36
test/model/seq_labeling.py View File

@@ -4,14 +4,16 @@ sys.path.append("..")
import argparse
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.dataset_loader import POSDatasetLoader, BaseLoader
from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle
from fastNLP.loader.dataset_loader import BaseLoader
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.tester import SeqLabelTester
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.core.predictor import SeqLabelInfer
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target
from fastNLP.core.metrics import SeqLabelEvaluator
from fastNLP.core.preprocess import save_pickle, load_pickle

parser = argparse.ArgumentParser()
parser.add_argument("-s", "--save", type=str, default="./seq_label/", help="path to save pickle files")
@@ -33,24 +35,27 @@ data_infer_path = args.infer
def infer():
# Load infer configuration, the same as test
test_args = ConfigSection()
ConfigLoader("config.cfg").load_config(config_dir, {"POS_infer": test_args})
ConfigLoader().load_config(config_dir, {"POS_infer": test_args})

# fetch dictionary size and number of labels from pickle files
word2index = load_pickle(pickle_path, "word2id.pkl")
test_args["vocab_size"] = len(word2index)
index2label = load_pickle(pickle_path, "class2id.pkl")
test_args["num_classes"] = len(index2label)
word_vocab = load_pickle(pickle_path, "word2id.pkl")
label_vocab = load_pickle(pickle_path, "label2id.pkl")
test_args["vocab_size"] = len(word_vocab)
test_args["num_classes"] = len(label_vocab)
print("vocabularies loaded")

# Define the same model
model = SeqLabeling(test_args)
print("model defined")

# Dump trained parameters into the model
ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name))
print("model loaded!")

# Data Loader
raw_data_loader = BaseLoader(data_infer_path)
infer_data = raw_data_loader.load_lines()
infer_data = SeqLabelDataSet(loader=BaseLoader())
infer_data.load(data_infer_path, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab}, infer=True)
print("data set prepared")

# Inference interface
infer = SeqLabelInfer(pickle_path)
@@ -65,24 +70,18 @@ def train_and_test():
# Config Loader
trainer_args = ConfigSection()
model_args = ConfigSection()
ConfigLoader("config.cfg").load_config(config_dir, {
ConfigLoader().load_config(config_dir, {
"test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args})

# Data Loader
pos_loader = POSDatasetLoader(data_path)
train_data = pos_loader.load_lines()

# Preprocessor
p = SeqLabelPreprocess()
data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5)
model_args["vocab_size"] = p.vocab_size
model_args["num_classes"] = p.num_classes
data_set = SeqLabelDataSet()
data_set.load(data_path)
train_set, dev_set = data_set.split(0.3, shuffle=True)
model_args["vocab_size"] = len(data_set.word_vocab)
model_args["num_classes"] = len(data_set.label_vocab)

# Trainer: two definition styles
# 1
# trainer = SeqLabelTrainer(trainer_args.data)
save_pickle(data_set.word_vocab, pickle_path, "word2id.pkl")
save_pickle(data_set.label_vocab, pickle_path, "label2id.pkl")

# 2
trainer = SeqLabelTrainer(
epochs=trainer_args["epochs"],
batch_size=trainer_args["batch_size"],
@@ -98,7 +97,7 @@ def train_and_test():
model = SeqLabeling(model_args)

# Start training
trainer.train(model, data_train, data_dev)
trainer.train(model, train_set, dev_set)
print("Training finished!")

# Saver
@@ -106,7 +105,9 @@ def train_and_test():
saver.save_pytorch(model)
print("Model saved!")

del model, trainer, pos_loader
del model, trainer

change_field_is_target(dev_set, "truth", True)

# Define the same model
model = SeqLabeling(model_args)
@@ -117,27 +118,21 @@ def train_and_test():

# Load test configuration
tester_args = ConfigSection()
ConfigLoader("config.cfg").load_config(config_dir, {"test_seq_label_tester": tester_args})
ConfigLoader().load_config(config_dir, {"test_seq_label_tester": tester_args})

# Tester
tester = SeqLabelTester(save_output=False,
save_loss=True,
save_best_dev=False,
batch_size=4,
tester = SeqLabelTester(batch_size=4,
use_cuda=False,
pickle_path=pickle_path,
model_name="seq_label_in_test.pkl",
print_every_step=1
evaluator=SeqLabelEvaluator()
)

# Start testing with validation data
tester.test(model, data_dev)

# print test results
print(tester.show_metrics())
tester.test(model, dev_set)
print("model tested!")


if __name__ == "__main__":
train_and_test()
# infer()
infer()

+ 24
- 25
test/model/test_cws.py View File

@@ -1,11 +1,13 @@
import os

from fastNLP.core.predictor import Predictor
from fastNLP.core.preprocess import Preprocessor, load_pickle
from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target
from fastNLP.core.metrics import SeqLabelEvaluator
from fastNLP.core.predictor import SeqLabelInfer
from fastNLP.core.preprocess import save_pickle, load_pickle
from fastNLP.core.tester import SeqLabelTester
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader
from fastNLP.loader.dataset_loader import TokenizeDataSetLoader, BaseLoader
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.saver.model_saver import ModelSaver
@@ -19,12 +21,12 @@ config_path = "test/data_for_tests/config"
def infer():
# Load infer configuration, the same as test
test_args = ConfigSection()
ConfigLoader("config.cfg").load_config(config_path, {"POS_infer": test_args})
ConfigLoader().load_config(config_path, {"POS_infer": test_args})

# fetch dictionary size and number of labels from pickle files
word2index = load_pickle(pickle_path, "word2id.pkl")
test_args["vocab_size"] = len(word2index)
index2label = load_pickle(pickle_path, "class2id.pkl")
index2label = load_pickle(pickle_path, "label2id.pkl")
test_args["num_classes"] = len(index2label)

# Define the same model
@@ -34,31 +36,29 @@ def infer():
ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
print("model loaded!")

# Data Loader
raw_data_loader = BaseLoader(data_infer_path)
infer_data = raw_data_loader.load_lines()
# Load infer data
infer_data = SeqLabelDataSet(loader=BaseLoader())
infer_data.load(data_infer_path, vocabs={"word_vocab": word2index}, infer=True)

# Inference interface
infer = Predictor(pickle_path, "seq_label")
# inference
infer = SeqLabelInfer(pickle_path)
results = infer.predict(model, infer_data)

print(results)


def train_test():
# Config Loader
train_args = ConfigSection()
ConfigLoader("config.cfg").load_config(config_path, {"POS_infer": train_args})
ConfigLoader().load_config(config_path, {"POS_infer": train_args})

# Data Loader
loader = TokenizeDatasetLoader(cws_data_path)
train_data = loader.load_pku()
# define dataset
data_train = SeqLabelDataSet(loader=TokenizeDataSetLoader())
data_train.load(cws_data_path)
train_args["vocab_size"] = len(data_train.word_vocab)
train_args["num_classes"] = len(data_train.label_vocab)

# Preprocessor
p = Preprocessor(label_is_seq=True)
data_train = p.run(train_data, pickle_path=pickle_path)
train_args["vocab_size"] = p.vocab_size
train_args["num_classes"] = p.num_classes
save_pickle(data_train.word_vocab, pickle_path, "word2id.pkl")
save_pickle(data_train.label_vocab, pickle_path, "label2id.pkl")

# Trainer
trainer = SeqLabelTrainer(**train_args.data)
@@ -73,7 +73,7 @@ def train_test():
saver = ModelSaver("./save/saved_model.pkl")
saver.save_pytorch(model)

del model, trainer, loader
del model, trainer

# Define the same model
model = SeqLabeling(train_args)
@@ -83,17 +83,16 @@ def train_test():

# Load test configuration
test_args = ConfigSection()
ConfigLoader("config.cfg").load_config(config_path, {"POS_infer": test_args})
ConfigLoader().load_config(config_path, {"POS_infer": test_args})
test_args["evaluator"] = SeqLabelEvaluator()

# Tester
tester = SeqLabelTester(**test_args.data)

# Start testing
change_field_is_target(data_train, "truth", True)
tester.test(model, data_train)

# print test results
print(tester.show_metrics())


def test():
os.makedirs("save", exist_ok=True)


+ 16
- 21
test/model/test_seq_label.py View File

@@ -1,11 +1,12 @@
import os

from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target
from fastNLP.core.metrics import SeqLabelEvaluator
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.preprocess import SeqLabelPreprocess
from fastNLP.core.preprocess import save_pickle
from fastNLP.core.tester import SeqLabelTester
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.dataset_loader import POSDatasetLoader
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.saver.model_saver import ModelSaver
@@ -21,18 +22,17 @@ def test_training():
# Config Loader
trainer_args = ConfigSection()
model_args = ConfigSection()
ConfigLoader("_").load_config(config_dir, {
ConfigLoader().load_config(config_dir, {
"test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args})

# Data Loader
pos_loader = POSDatasetLoader(data_path)
train_data = pos_loader.load_lines()
data_set = SeqLabelDataSet()
data_set.load(data_path)
data_train, data_dev = data_set.split(0.3, shuffle=True)
model_args["vocab_size"] = len(data_set.word_vocab)
model_args["num_classes"] = len(data_set.label_vocab)

# Preprocessor
p = SeqLabelPreprocess()
data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5)
model_args["vocab_size"] = p.vocab_size
model_args["num_classes"] = p.num_classes
save_pickle(data_set.word_vocab, pickle_path, "word2id.pkl")
save_pickle(data_set.label_vocab, pickle_path, "label2id.pkl")

trainer = SeqLabelTrainer(
epochs=trainer_args["epochs"],
@@ -55,7 +55,7 @@ def test_training():
saver = ModelSaver(os.path.join(pickle_path, model_name))
saver.save_pytorch(model)

del model, trainer, pos_loader
del model, trainer

# Define the same model
model = SeqLabeling(model_args)
@@ -65,21 +65,16 @@ def test_training():

# Load test configuration
tester_args = ConfigSection()
ConfigLoader("config.cfg").load_config(config_dir, {"test_seq_label_tester": tester_args})
ConfigLoader().load_config(config_dir, {"test_seq_label_tester": tester_args})

# Tester
tester = SeqLabelTester(save_output=False,
save_loss=True,
save_best_dev=False,
batch_size=4,
tester = SeqLabelTester(batch_size=4,
use_cuda=False,
pickle_path=pickle_path,
model_name="seq_label_in_test.pkl",
print_every_step=1
evaluator=SeqLabelEvaluator()
)

# Start testing with validation data
change_field_is_target(data_dev, "truth", True)
tester.test(model, data_dev)

loss, accuracy = tester.metrics
assert 0 < accuracy < 1

+ 23
- 34
test/model/text_classify.py View File

@@ -9,13 +9,14 @@ sys.path.append("..")
from fastNLP.core.predictor import ClassificationInfer
from fastNLP.core.trainer import ClassificationTrainer
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.dataset_loader import ClassDatasetLoader
from fastNLP.loader.dataset_loader import ClassDataSetLoader
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.preprocess import ClassPreprocess
from fastNLP.models.cnn_text_classification import CNNText
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.loss import Loss
from fastNLP.core.dataset import TextClassifyDataSet
from fastNLP.core.preprocess import save_pickle, load_pickle

parser = argparse.ArgumentParser()
parser.add_argument("-s", "--save", type=str, default="./test_classification/", help="path to save pickle files")
@@ -34,21 +35,18 @@ config_dir = args.config
def infer():
# load dataset
print("Loading data...")
ds_loader = ClassDatasetLoader(train_data_dir)
data = ds_loader.load()
unlabeled_data = [x[0] for x in data]
word_vocab = load_pickle(save_dir, "word2id.pkl")
label_vocab = load_pickle(save_dir, "label2id.pkl")
print("vocabulary size:", len(word_vocab))
print("number of classes:", len(label_vocab))

# pre-process data
pre = ClassPreprocess()
data = pre.run(data, pickle_path=save_dir)
print("vocabulary size:", pre.vocab_size)
print("number of classes:", pre.num_classes)
infer_data = TextClassifyDataSet(loader=ClassDataSetLoader())
infer_data.load(train_data_dir, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab})

model_args = ConfigSection()
# TODO: load from config file
model_args["vocab_size"] = pre.vocab_size
model_args["num_classes"] = pre.num_classes
# ConfigLoader.load_config(config_dir, {"text_class_model": model_args})
model_args["vocab_size"] = len(word_vocab)
model_args["num_classes"] = len(label_vocab)
ConfigLoader.load_config(config_dir, {"text_class_model": model_args})

# construct model
print("Building model...")
@@ -59,7 +57,7 @@ def infer():
print("model loaded!")

infer = ClassificationInfer(pickle_path=save_dir)
results = infer.predict(cnn, unlabeled_data)
results = infer.predict(cnn, infer_data)
print(results)


@@ -69,32 +67,23 @@ def train():

# load dataset
print("Loading data...")
ds_loader = ClassDatasetLoader(train_data_dir)
data = ds_loader.load()
print(data[0])
data = TextClassifyDataSet(loader=ClassDataSetLoader())
data.load(train_data_dir)

# pre-process data
pre = ClassPreprocess()
data_train = pre.run(data, pickle_path=save_dir)
print("vocabulary size:", pre.vocab_size)
print("number of classes:", pre.num_classes)
print("vocabulary size:", len(data.word_vocab))
print("number of classes:", len(data.label_vocab))
save_pickle(data.word_vocab, save_dir, "word2id.pkl")
save_pickle(data.label_vocab, save_dir, "label2id.pkl")

model_args["num_classes"] = pre.num_classes
model_args["vocab_size"] = pre.vocab_size
model_args["num_classes"] = len(data.label_vocab)
model_args["vocab_size"] = len(data.word_vocab)

# construct model
print("Building model...")
model = CNNText(model_args)

# ConfigSaver().save_config(config_dir, {"text_class_model": model_args})

# train
print("Training...")

# 1
# trainer = ClassificationTrainer(train_args)

# 2
trainer = ClassificationTrainer(epochs=train_args["epochs"],
batch_size=train_args["batch_size"],
validate=train_args["validate"],
@@ -104,7 +93,7 @@ def train():
model_name=model_name,
loss=Loss("cross_entropy"),
optimizer=Optimizer("SGD", lr=0.001, momentum=0.9))
trainer.train(model, data_train)
trainer.train(model, data)

print("Training finished!")

@@ -115,4 +104,4 @@ def train():

if __name__ == "__main__":
train()
# infer()
infer()

Loading…
Cancel
Save