
add Field support in Predictor:

- use DataSet in Predictor; remove sub-predictors; add a "task" argument specifying which task to predict, as Trainer/Tester do (see the usage sketch after this list)
- remove Action class
- add helper functions for DataSet, to create DataSet objects easily
- more code comments
- clean up unnecessary code
- add unit tests for Batch, Predictor, Preprocessor, Trainer, Tester
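
A minimal sketch of the reworked inference call, mirroring test/core/test_predictor.py below (the pickle directory must already contain id2class.pkl and word2id.pkl, as that test sets up):

    from fastNLP.core.predictor import Predictor
    from fastNLP.models.sequence_modeling import SeqLabeling

    model = SeqLabeling({"vocab_size": 10, "word_emb_dim": 100,
                         "rnn_hidden_units": 100, "num_classes": 5})
    # "task" selects the prediction logic, as in Trainer/Tester
    predictor = Predictor(pickle_path="./save/", task="seq_label")
    # data is a two-level list of strings; Predictor builds a DataSet from it internally
    results = predictor.predict(network=model, data=[['a', 'b', 'c', 'd', 'e']])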
tags/v0.1.0
FengZiYjun 6 years ago
parent commit f2fc98b5e6
17 changed files with 489 additions and 315 deletions
  1. fastNLP/core/README.md (+0, -1)
  2. fastNLP/core/action.py (+1, -82)
  3. fastNLP/core/dataset.py (+69, -0)
  4. fastNLP/core/field.py (+4, -3)
  5. fastNLP/core/instance.py (+1, -1)
  6. fastNLP/core/predictor.py (+49, -128)
  7. fastNLP/core/tester.py (+8, -5)
  8. fastNLP/core/trainer.py (+13, -5)
  9. fastNLP/models/sequence_modeling.py (+2, -0)
  10. test/core/test_action.py (+0, -17)
  11. test/core/test_batch.py (+62, -0)
  12. test/core/test_predictor.py (+51, -0)
  13. test/core/test_preprocess.py (+49, -20)
  14. test/core/test_tester.py (+48, -30)
  15. test/core/test_trainer.py (+36, -15)
  16. test/model/test_charlm.py (+0, -8)
  17. test/model/test_seq_label.py (+96, -0)

fastNLP/core/README.md (+0, -1)

@@ -1 +0,0 @@


fastNLP/core/action.py (+1, -82)

@@ -4,88 +4,6 @@ import numpy as np
import torch


class Action(object):
"""Operations shared by Trainer, Tester, or Inference.

This is designed to reduce duplicated code.
- make_batch: produce a mini-batch of data. @staticmethod
- pad: padding method used in sequence modeling. @staticmethod
- mode: change network mode for either train or test. (for PyTorch) @staticmethod
"""

def __init__(self):
super(Action, self).__init__()

@staticmethod
def make_batch(iterator, use_cuda, output_length=True, max_len=None):
"""Batch and Pad data.

:param iterator: an iterator, (object that implements __next__ method) which returns the next sample.
:param use_cuda: bool, whether to use GPU
:param output_length: bool, whether to output the original length of the sequence before padding. (default: True)
:param max_len: int, maximum sequence length. Longer sequences will be clipped. (default: None)
:return :

if output_length is True,
(batch_x, seq_len): tuple of two elements
batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
seq_len: list. The length of the pre-padded sequence, if output_length is True.
batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]

if output_length is False,
batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
"""
for batch in iterator:
batch_x = [sample[0] for sample in batch]
batch_y = [sample[1] for sample in batch]

batch_x = Action.pad(batch_x)
# pad batch_y only if it is a 2-level list
if len(batch_y) > 0 and isinstance(batch_y[0], list):
batch_y = Action.pad(batch_y)

# convert list to tensor
batch_x = convert_to_torch_tensor(batch_x, use_cuda)
batch_y = convert_to_torch_tensor(batch_y, use_cuda)

# trim data to max_len
if max_len is not None and batch_x.size(1) > max_len:
batch_x = batch_x[:, :max_len]

if output_length:
seq_len = [len(x) for x in batch_x]
yield (batch_x, seq_len), batch_y
else:
yield batch_x, batch_y

@staticmethod
def pad(batch, fill=0):
""" Pad a mini-batch of sequence samples to maximum length of this batch.

:param batch: list of list
:param fill: word index to pad, default 0.
:return batch: a padded mini-batch
"""
max_length = max([len(x) for x in batch])
for idx, sample in enumerate(batch):
if len(sample) < max_length:
batch[idx] = sample + ([fill] * (max_length - len(sample)))
return batch

@staticmethod
def mode(model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently.

:param model: a PyTorch model
:param is_test: bool, whether in test mode or not.
"""
if is_test:
model.eval()
else:
model.train()


def convert_to_torch_tensor(data_list, use_cuda):
"""Convert lists into (cuda) Tensors.

@@ -224,6 +142,7 @@ class BucketBatchifier(Batchifier):
"""Partition all samples into multiple buckets, each of which contains sentences of approximately the same length.
In sampling, first random choose a bucket. Then sample data from it.
The number of buckets is decided dynamically by the variance of sentence lengths.
+ TODO: merge it into Batch
"""

def __init__(self, data_set, batch_size, num_buckets, drop_last=True, sampler=None):
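
The batching and padding logic removed above is now handled by the Batch class used throughout this commit (per the TODO added to BucketBatchifier). For reference, the removed Action.pad is equivalent to this standalone sketch:

    def pad(batch, fill=0):
        # pad every sample in the mini-batch to the length of its longest sample
        max_length = max(len(sample) for sample in batch)
        return [sample + [fill] * (max_length - len(sample)) for sample in batch]

    assert pad([[1, 2, 3], [4]]) == [[1, 2, 3], [4, 0, 0]]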


fastNLP/core/dataset.py (+69, -0)

@@ -1,8 +1,77 @@
from collections import defaultdict

from fastNLP.core.field import TextField
from fastNLP.core.instance import Instance


def create_dataset_from_lists(str_lists: list, word_vocab: dict, has_target: bool = False, label_vocab: dict = None):
if has_target is True:
if label_vocab is None:
raise RuntimeError("Must provide label vocabulary to transform labels.")
return create_labeled_dataset_from_lists(str_lists, word_vocab, label_vocab)
else:
return create_unlabeled_dataset_from_lists(str_lists, word_vocab)


def create_labeled_dataset_from_lists(str_lists, word_vocab, label_vocab):
"""Create an DataSet instance that contains labels.

:param str_lists: list of list of strings, [num_examples, 2, *].
::
[
[[word_11, word_12, ...], [label_11, label_12, ...]],
...
]

:param word_vocab: dict of (str: int), which means (word: index).
:param label_vocab: dict of (str: int), which means (label: index).
:return data_set: a DataSet instance.

"""
data_set = DataSet()
for example in str_lists:
word_seq, label_seq = example[0], example[1]
x = TextField(word_seq, is_target=False)
y = TextField(label_seq, is_target=True)
data_set.append(Instance(word_seq=x, label_seq=y))
data_set.index_field("word_seq", word_vocab)
data_set.index_field("label_seq", label_vocab)
return data_set


def create_unlabeled_dataset_from_lists(str_lists, word_vocab):
"""Create an DataSet instance that contains no labels.

:param str_lists: list of list of strings, [num_examples, *].
::
[
[word_11, word_12, ...],
...
]

:param word_vocab: dict of (str: int), which means (word: index).
:return data_set: a DataSet instance.

"""
data_set = DataSet()
for word_seq in str_lists:
x = TextField(word_seq, is_target=False)
data_set.append(Instance(word_seq=x))
data_set.index_field("word_seq", word_vocab)
return data_set


class DataSet(list):
"""A DataSet object is a list of Instance objects.

"""
def __init__(self, name="", instances=None):
"""

:param name: str, the name of the dataset. (default: "")
:param instances: list of Instance objects. (default: None)

"""
list.__init__([])
self.name = name
if instances is not None:
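
A usage sketch of the new helper, with toy vocabularies (shapes follow the docstrings above; has_target=True requires label_vocab):

    from fastNLP.core.dataset import create_dataset_from_lists

    word_vocab = {'Hello': 0, 'world': 1, '!': 2}
    label_vocab = {'a': 0, 'n': 1, '.': 2}

    # labeled data: each example is [word_seq, label_seq]
    train_set = create_dataset_from_lists(
        [[['Hello', 'world', '!'], ['a', 'n', '.']]],
        word_vocab, has_target=True, label_vocab=label_vocab)

    # unlabeled data: each example is a word_seq alone
    infer_set = create_dataset_from_lists(
        [['Hello', 'world', '!']], word_vocab, has_target=False)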


fastNLP/core/field.py (+4, -3)

@@ -20,9 +20,10 @@ class Field(object):


class TextField(Field):
- def __init__(self, text: list, is_target):
+ def __init__(self, text, is_target):
"""
- :param list text:
+ :param text: list of strings
:param is_target: bool
"""
super(TextField, self).__init__(is_target)
self.text = text
@@ -32,7 +33,7 @@ class TextField(Field):
if self._index is None:
self._index = [vocab[c] for c in self.text]
else:
- print('error')
+ raise RuntimeError("Replicate indexing of this field.")
return self._index

def get_length(self):
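
The hunk header truncates the enclosing method's name; assuming it is TextField.index(vocab), which matches DataSet.index_field above, the change turns a silent print into an explicit failure when a field is indexed twice:

    from fastNLP.core.field import TextField

    field = TextField(['Hello', 'world'], is_target=False)
    field.index({'Hello': 0, 'world': 1})  # caches and returns [0, 1]
    field.index({'Hello': 5, 'world': 6})  # now raises RuntimeError instead of printing 'error'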


fastNLP/core/instance.py (+1, -1)

@@ -41,7 +41,7 @@ class Instance(object):
:param padding_length: dict of (str: int), which means (field name: padding_length of this field)
:return tensor_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ])
tensor_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ])
If is_target is False for all fields, tensor_y would be an empty dict.
"""
tensor_x = {}
tensor_y = {}


fastNLP/core/predictor.py (+49, -128)

@@ -1,53 +1,10 @@
import numpy as np
import torch

- from fastNLP.core.action import Batchifier, SequentialSampler
- from fastNLP.core.action import convert_to_torch_tensor
- from fastNLP.core.preprocess import load_pickle, DEFAULT_UNKNOWN_LABEL
- from fastNLP.modules import utils


- def make_batch(iterator, use_cuda, output_length=False, max_len=None, min_len=None):
- """Batch and Pad data, only for Inference.

- :param iterator: An iterable object that returns a list of indices representing a mini-batch of samples.
- :param use_cuda: bool, whether to use GPU
- :param output_length: bool, whether to output the original length of the sequence before padding. (default: False)
- :param max_len: int, maximum sequence length. Longer sequences will be clipped. (default: None)
- :param min_len: int, minimum sequence length. Shorter sequences will be padded. (default: None)
- :return:
- """
- for batch_x in iterator:
- batch_x = pad(batch_x)
- # convert list to tensor
- batch_x = convert_to_torch_tensor(batch_x, use_cuda)

- # trim data to max_len
- if max_len is not None and batch_x.size(1) > max_len:
- batch_x = batch_x[:, :max_len]
- if min_len is not None and batch_x.size(1) < min_len:
- pad_tensor = torch.zeros(batch_x.size(0), min_len - batch_x.size(1)).to(batch_x)
- batch_x = torch.cat((batch_x, pad_tensor), 1)

- if output_length:
- seq_len = [len(x) for x in batch_x]
- yield tuple([batch_x, seq_len])
- else:
- yield batch_x


- def pad(batch, fill=0):
- """ Pad a mini-batch of sequence samples to maximum length of this batch.

- :param batch: list of list
- :param fill: word index to pad, default 0.
- :return batch: a padded mini-batch
- """
- max_length = max([len(x) for x in batch])
- for idx, sample in enumerate(batch):
- if len(sample) < max_length:
- batch[idx] = sample + ([fill] * (max_length - len(sample)))
- return batch
+ from fastNLP.core.action import SequentialSampler
+ from fastNLP.core.batch import Batch
+ from fastNLP.core.dataset import create_dataset_from_lists
+ from fastNLP.core.preprocess import load_pickle


class Predictor(object):
@@ -59,11 +16,17 @@ class Predictor(object):
Currently, Predictor does not support GPU.
"""

- def __init__(self, pickle_path):
+ def __init__(self, pickle_path, task):
"""

:param pickle_path: str, the path to the pickle files.
+ :param task: str, specify which task the predictor will perform. One of ("seq_label", "text_classify").

"""
self.batch_size = 1
self.batch_output = []
self.iterator = None
self.pickle_path = pickle_path
+ self._task = task  # one of ("seq_label", "text_classify")
self.index2label = load_pickle(self.pickle_path, "id2class.pkl")
self.word2index = load_pickle(self.pickle_path, "word2id.pkl")

@@ -71,19 +34,19 @@ class Predictor(object):
"""Perform inference using the trained model.

:param network: a PyTorch model (cpu)
- :param data: list of list of strings
+ :param data: list of list of strings, [num_examples, seq_len]
:return: list of list of strings, [num_examples, tag_seq_length]
"""
- # transform strings into indices
+ # transform strings into DataSet object
data = self.prepare_input(data)

# turn on the testing mode; clean up the history
self.mode(network, test=True)
self.batch_output.clear()

- data_iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False))
+ data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), use_cuda=False)

- for batch_x in self.make_batch(data_iterator, use_cuda=False):
+ for batch_x, _ in data_iterator:
with torch.no_grad():
prediction = self.data_forward(network, batch_x)

@@ -99,103 +62,61 @@ class Predictor(object):

def data_forward(self, network, x):
"""Forward through network."""
- raise NotImplementedError
- def make_batch(self, iterator, use_cuda):
- raise NotImplementedError
+ y = network(**x)
+ if self._task == "seq_label":
+ y = network.prediction(y)
+ return y

def prepare_input(self, data):
- """Transform two-level list of strings into that of index.
+ """Transform two-level list of strings into a DataSet object.
+ In the training pipeline, this is done by Preprocessor. But in inference time, we do not call Preprocessor.

- :param data:
+ :param data: list of list of strings.
::
[
[word_11, word_12, ...],
[word_21, word_22, ...],
...
]
- :return data_index: list of list of int.

+ :return data_set: a DataSet instance.
"""
assert isinstance(data, list)
- data_index = []
- default_unknown_index = self.word2index[DEFAULT_UNKNOWN_LABEL]
- for example in data:
- data_index.append([self.word2index.get(w, default_unknown_index) for w in example])
- return data_index
+ return create_dataset_from_lists(data, self.word2index, has_target=False)

def prepare_output(self, data):
"""Transform list of batch outputs into strings."""
- raise NotImplementedError


- class SeqLabelInfer(Predictor):
- """
- Inference on sequence labeling models.
- """

- def __init__(self, pickle_path):
- super(SeqLabelInfer, self).__init__(pickle_path)
+ if self._task == "seq_label":
+ return self._seq_label_prepare_output(data)
+ elif self._task == "text_classify":
+ return self._text_classify_prepare_output(data)
+ else:
+ raise NotImplementedError("Unknown task type {}".format(self._task))

- def data_forward(self, network, inputs):
- """
- This is only for sequence labeling with CRF decoder.
- :param network: a PyTorch model
- :param inputs: tuple of (x, seq_len)
- x: Tensor of shape [batch_size, max_len], where max_len is the maximum length of the mini-batch
- after padding.
- seq_len: list of int, the lengths of sequences before padding.
- :return prediction: Tensor of shape [batch_size, max_len]
- """
- if not isinstance(inputs[1], list) and isinstance(inputs[0], list):
- raise RuntimeError("output_length must be true for sequence modeling.")
- # unpack the returned value from make_batch
- x, seq_len = inputs[0], inputs[1]
- batch_size, max_len = x.size(0), x.size(1)
- mask = utils.seq_mask(seq_len, max_len)
- mask = mask.byte().view(batch_size, max_len)
- y = network(x)
- prediction = network.prediction(y, mask)
- return torch.Tensor(prediction)

- def make_batch(self, iterator, use_cuda):
- return make_batch(iterator, use_cuda, output_length=True)

- def prepare_output(self, batch_outputs):
- """Transform list of batch outputs into strings.

- :param batch_outputs: list of 2-D Tensor, shape [num_batch, batch-size, tag_seq_length].
- :return results: 2-D list of strings, shape [num_examples, tag_seq_length]
- """
+ def _seq_label_prepare_output(self, batch_outputs):
results = []
for batch in batch_outputs:
for example in np.array(batch):
results.append([self.index2label[int(x)] for x in example])
return results


- class ClassificationInfer(Predictor):
- """
- Inference on Classification models.
- """

- def __init__(self, pickle_path):
- super(ClassificationInfer, self).__init__(pickle_path)

- def data_forward(self, network, x):
- """Forward through network."""
- logits = network(x)
- return logits

- def make_batch(self, iterator, use_cuda):
- return make_batch(iterator, use_cuda, output_length=False, min_len=5)

- def prepare_output(self, batch_outputs):
- """
- Transform list of batch outputs into strings.
- :param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, num_classes].
- :return results: list of strings
- """
+ def _text_classify_prepare_output(self, batch_outputs):
results = []
for batch_out in batch_outputs:
idx = np.argmax(batch_out.detach().numpy(), axis=-1)
results.extend([self.index2label[i] for i in idx])
return results


+ class SeqLabelInfer(Predictor):
+ def __init__(self, pickle_path):
+ print(
+ "[FastNLP Warning] SeqLabelInfer will be deprecated. Please use Predictor with argument 'task'='seq_label'.")
+ super(SeqLabelInfer, self).__init__(pickle_path, "seq_label")


+ class ClassificationInfer(Predictor):
+ def __init__(self, pickle_path):
+ print(
+ "[FastNLP Warning] ClassificationInfer will be deprecated. Please use Predictor with argument 'task'='text_classify'.")
+ super(ClassificationInfer, self).__init__(pickle_path, "text_classify")

fastNLP/core/tester.py (+8, -5)

@@ -1,7 +1,6 @@
import numpy as np
import torch

- from fastNLP.core.action import Action
from fastNLP.core.action import RandomSampler
from fastNLP.core.batch import Batch
from fastNLP.saver.logger import create_logger
@@ -79,7 +78,7 @@ class BaseTester(object):
self._model = network

# turn on the testing mode; clean up the history
- self.mode(network, test=True)
+ self.mode(network, is_test=True)
self.eval_history.clear()
self.batch_output.clear()

@@ -102,13 +101,17 @@ class BaseTester(object):
print(self.make_eval_output(prediction, eval_results))
step += 1

- def mode(self, model, test):
+ def mode(self, model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently.

:param model: a PyTorch model
- :param test: bool, whether in test mode.
+ :param is_test: bool, whether in test mode or not.

"""
- Action.mode(model, test)
+ if is_test:
+ model.eval()
+ else:
+ model.train()

def data_forward(self, network, x):
"""A forward pass of the model. """


fastNLP/core/trainer.py (+13, -5)

@@ -6,7 +6,6 @@ from datetime import timedelta
import torch
from tensorboardX import SummaryWriter

- from fastNLP.core.action import Action
from fastNLP.core.action import RandomSampler
from fastNLP.core.batch import Batch
from fastNLP.core.loss import Loss
@@ -126,7 +125,7 @@ class BaseTrainer(object):
logger.info("training epoch {}".format(epoch))

# turn on network training mode
- self.mode(network, test=False)
+ self.mode(network, is_test=False)
# prepare mini-batch iterator
data_iterator = Batch(train_data, batch_size=self.batch_size, sampler=RandomSampler(),
use_cuda=self.use_cuda)
@@ -201,8 +200,17 @@ class BaseTrainer(object):
network_copy = copy.deepcopy(network)
self.train(network_copy, train_data_cv[i], dev_data_cv[i])

- def mode(self, network, test):
- Action.mode(network, test)
+ def mode(self, model, is_test=False):
+ """Train mode or Test mode. This is for PyTorch currently.

+ :param model: a PyTorch model
+ :param is_test: bool, whether in test mode or not.

+ """
+ if is_test:
+ model.eval()
+ else:
+ model.train()

def define_optimizer(self):
"""Define framework-specific optimizer specified by the models.
@@ -284,7 +292,7 @@ class BaseTrainer(object):
:param validator: a Tester instance
:return: bool, True means current results on dev set is the best.
"""
- loss, accuracy = validator.metrics()
+ loss, accuracy = validator.metrics
if accuracy > self._best_accuracy:
self._best_accuracy = accuracy
return True
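
The last change reads metrics as a property instead of calling a method, hence the dropped parentheses. A minimal sketch of that pattern (attribute names are illustrative, not from the diff):

    class Tester:
        def __init__(self):
            self._losses = [0.5, 0.3]
            self._accuracies = [0.8, 0.9]

        @property
        def metrics(self):
            # read as tester.metrics, with no call parentheses
            n = len(self._losses)
            return sum(self._losses) / n, sum(self._accuracies) / n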


fastNLP/models/sequence_modeling.py (+2, -0)

@@ -62,6 +62,8 @@ class SeqLabeling(BaseModel):
"""
x = x.float()
y = y.long()
+ assert x.shape[:2] == y.shape
+ assert y.shape == self.mask.shape
total_loss = self.Crf(x, y, self.mask)
return torch.mean(total_loss)
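
The new asserts document the loss function's shape contract. Inferring from them (the exact shapes are not spelled out in the diff), x carries per-tag scores while y and the mask are token-level:

    import torch

    batch_size, max_len, num_classes = 2, 5, 4
    x = torch.randn(batch_size, max_len, num_classes)  # tag scores
    y = torch.zeros(batch_size, max_len).long()        # gold tag ids
    mask = torch.ones(batch_size, max_len).byte()      # padding mask

    assert x.shape[:2] == y.shape
    assert y.shape == mask.shape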



test/core/test_action.py (+0, -17)

@@ -1,17 +0,0 @@
import unittest

from fastNLP.core.action import Action, Batchifier, SequentialSampler


class TestAction(unittest.TestCase):
def test_case_1(self):
x = [1, 2, 3, 4, 5, 6, 7, 8]
y = [1, 1, 1, 1, 2, 2, 2, 2]
data = []
for i in range(len(x)):
data.append([[x[i]], [y[i]]])
data = Batchifier(SequentialSampler(data), batch_size=2, drop_last=False)
action = Action()
for batch_x in action.make_batch(data, use_cuda=False, output_length=True, max_len=None):
print(batch_x)


test/core/test_batch.py (+62, -0)

@@ -0,0 +1,62 @@
import unittest

import torch

from fastNLP.core.batch import Batch
from fastNLP.core.dataset import DataSet, create_dataset_from_lists
from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance

raw_texts = ["i am a cat",
"this is a test of new batch",
"ha ha",
"I am a good boy .",
"This is the most beautiful girl ."
]
texts = [text.strip().split() for text in raw_texts]
labels = [0, 1, 0, 0, 1]

# prepare vocabulary
vocab = {}
for text in texts:
for tokens in text:
if tokens not in vocab:
vocab[tokens] = len(vocab)


class TestCase1(unittest.TestCase):
def test(self):
data = DataSet()
for text, label in zip(texts, labels):
x = TextField(text, is_target=False)
y = LabelField(label, is_target=True)
ins = Instance(text=x, label=y)
data.append(ins)

# use vocabulary to index data
data.index_field("text", vocab)

# define naive sampler for batch class
class SeqSampler:
def __call__(self, dataset):
return list(range(len(dataset)))

# use batch to iterate dataset
data_iterator = Batch(data, 2, SeqSampler(), False)
for batch_x, batch_y in data_iterator:
self.assertEqual(len(batch_x), 2)
self.assertTrue(isinstance(batch_x, dict))
self.assertTrue(isinstance(batch_x["text"], torch.LongTensor))
self.assertTrue(isinstance(batch_y, dict))
self.assertTrue(isinstance(batch_y["label"], torch.LongTensor))


class TestCase2(unittest.TestCase):
def test(self):
data = DataSet()
for text in texts:
x = TextField(text, is_target=False)
ins = Instance(text=x)
data.append(ins)
data_set = create_dataset_from_lists(texts, vocab, has_target=False)
self.assertTrue(type(data) == type(data_set))

test/core/test_predictor.py (+51, -0)

@@ -0,0 +1,51 @@
import os
import unittest

from fastNLP.core.predictor import Predictor
from fastNLP.core.preprocess import save_pickle
from fastNLP.models.sequence_modeling import SeqLabeling


class TestPredictor(unittest.TestCase):
def test_seq_label(self):
model_args = {
"vocab_size": 10,
"word_emb_dim": 100,
"rnn_hidden_units": 100,
"num_classes": 5
}

infer_data = [
['a', 'b', 'c', 'd', 'e'],
['a', '@', 'c', 'd', 'e'],
['a', 'b', '#', 'd', 'e'],
['a', 'b', 'c', '?', 'e'],
['a', 'b', 'c', 'd', '$'],
['!', 'b', 'c', 'd', 'e']
]
vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}

os.system("mkdir save")
save_pickle({0: "0", 1: "1", 2: "2", 3: "3", 4: "4"}, "./save/", "id2class.pkl")
save_pickle(vocab, "./save/", "word2id.pkl")

model = SeqLabeling(model_args)
predictor = Predictor("./save/", task="seq_label")

results = predictor.predict(network=model, data=infer_data)

self.assertTrue(isinstance(results, list))
self.assertGreater(len(results), 0)
for res in results:
self.assertTrue(isinstance(res, list))
self.assertEqual(len(res), 5)
self.assertTrue(isinstance(res[0], str))

os.system("rm -rf save")
print("pickle path deleted")


class TestPredictor2(unittest.TestCase):
def test_text_classify(self):
# TODO
pass

test/core/test_preprocess.py (+49, -20)

@@ -1,24 +1,25 @@
import os
import unittest

+ from fastNLP.core.dataset import DataSet
from fastNLP.core.preprocess import SeqLabelPreprocess

+ data = [
+ [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
+ [['Hello', 'world', '!'], ['a', 'n', '.']],
+ [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
+ [['Hello', 'world', '!'], ['a', 'n', '.']],
+ [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
+ [['Hello', 'world', '!'], ['a', 'n', '.']],
+ [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
+ [['Hello', 'world', '!'], ['a', 'n', '.']],
+ [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
+ [['Hello', 'world', '!'], ['a', 'n', '.']],
+ ]

- class TestSeqLabelPreprocess(unittest.TestCase):
- def test_case_1(self):
- data = [
- [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
- [['Hello', 'world', '!'], ['a', 'n', '.']],
- [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
- [['Hello', 'world', '!'], ['a', 'n', '.']],
- [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
- [['Hello', 'world', '!'], ['a', 'n', '.']],
- [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
- [['Hello', 'world', '!'], ['a', 'n', '.']],
- [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
- [['Hello', 'world', '!'], ['a', 'n', '.']],
- ]

+ class TestCase1(unittest.TestCase):
+ def test(self):
if os.path.exists("./save"):
for root, dirs, files in os.walk("./save", topdown=False):
for name in files:
@@ -27,17 +28,45 @@ class TestSeqLabelPreprocess(unittest.TestCase):
os.rmdir(os.path.join(root, name))
- result = SeqLabelPreprocess().run(train_dev_data=data, train_dev_split=0.4,
- pickle_path="./save")
result = SeqLabelPreprocess().run(train_dev_data=data, train_dev_split=0.4,
pickle_path="./save")
self.assertEqual(len(result), 2)
self.assertEqual(type(result[0]), DataSet)
self.assertEqual(type(result[1]), DataSet)

os.system("rm -rf save")
print("pickle path deleted")


class TestCase2(unittest.TestCase):
def test(self):
if os.path.exists("./save"):
for root, dirs, files in os.walk("./save", topdown=False):
for name in files:
os.remove(os.path.join(root, name))
for name in dirs:
os.rmdir(os.path.join(root, name))
- result = SeqLabelPreprocess().run(test_data=data, train_dev_data=data,
- pickle_path="./save", train_dev_split=0.4,
- cross_val=True)
result = SeqLabelPreprocess().run(test_data=data, train_dev_data=data,
pickle_path="./save", train_dev_split=0.4,
- cross_val=True)
+ cross_val=False)
self.assertEqual(len(result), 3)
self.assertEqual(type(result[0]), DataSet)
self.assertEqual(type(result[1]), DataSet)
self.assertEqual(type(result[2]), DataSet)

os.system("rm -rf save")
print("pickle path deleted")


class TestCase3(unittest.TestCase):
def test(self):
num_folds = 2
result = SeqLabelPreprocess().run(test_data=None, train_dev_data=data,
pickle_path="./save", train_dev_split=0.4,
cross_val=True, n_fold=num_folds)
self.assertEqual(len(result), 2)
self.assertEqual(len(result[0]), num_folds)
self.assertEqual(len(result[1]), num_folds)
for data_set in result[0] + result[1]:
self.assertEqual(type(data_set), DataSet)

os.system("rm -rf save")
print("pickle path deleted")

test/core/test_tester.py (+48, -30)

@@ -1,37 +1,55 @@
- from fastNLP.core.preprocess import SeqLabelPreprocess
+ import os
+ import unittest

+ from fastNLP.core.dataset import DataSet
+ from fastNLP.core.field import TextField
+ from fastNLP.core.instance import Instance
from fastNLP.core.tester import SeqLabelTester
- from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
- from fastNLP.loader.dataset_loader import TokenizeDatasetLoader
from fastNLP.models.sequence_modeling import SeqLabeling

- data_name = "pku_training.utf8"
- pickle_path = "data_for_tests"


- def foo():
- loader = TokenizeDatasetLoader("./data_for_tests/cws_pku_utf_8")
- train_data = loader.load_pku()

- train_args = ConfigSection()
- ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS": train_args})

- # Preprocessor
- p = SeqLabelPreprocess()
- train_data = p.run(train_data)
- train_args["vocab_size"] = p.vocab_size
- train_args["num_classes"] = p.num_classes

- model = SeqLabeling(train_args)

- valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
- "save_loss": True, "batch_size": 8, "pickle_path": "./data_for_tests/",
- "use_cuda": True}
- validator = SeqLabelTester(**valid_args)

- print("start validation.")
- validator.test(model, train_data)
- print(validator.show_metrics())


- if __name__ == "__main__":
- foo()
+ class TestTester(unittest.TestCase):
+ def test_case_1(self):
+ model_args = {
+ "vocab_size": 10,
+ "word_emb_dim": 100,
+ "rnn_hidden_units": 100,
+ "num_classes": 5
+ }
+ valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
+ "save_loss": True, "batch_size": 2, "pickle_path": "./save/",
+ "use_cuda": False, "print_every_step": 1}

+ train_data = [
+ [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
+ [['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
+ [['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
+ [['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']],
+ [['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']],
+ [['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
+ ]
+ vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
+ label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}

+ data_set = DataSet()
+ for example in train_data:
+ text, label = example[0], example[1]
+ x = TextField(text, False)
+ y = TextField(label, is_target=True)
+ ins = Instance(word_seq=x, label_seq=y)
+ data_set.append(ins)

+ data_set.index_field("word_seq", vocab)
+ data_set.index_field("label_seq", label_vocab)

+ model = SeqLabeling(model_args)

+ tester = SeqLabelTester(**valid_args)
+ tester.test(network=model, dev_data=data_set)
+ # If this can run, everything is OK.

+ os.system("rm -rf save")
+ print("pickle path deleted")

test/core/test_trainer.py (+36, -15)

@@ -1,33 +1,54 @@
import os

- import torch.nn as nn
import unittest

- from fastNLP.core.trainer import SeqLabelTrainer
+ from fastNLP.core.dataset import DataSet
+ from fastNLP.core.field import TextField
+ from fastNLP.core.instance import Instance
from fastNLP.core.loss import Loss
from fastNLP.core.optimizer import Optimizer
+ from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.models.sequence_modeling import SeqLabeling


class TestTrainer(unittest.TestCase):
def test_case_1(self):
args = {"epochs": 3, "batch_size": 8, "validate": True, "use_cuda": True, "pickle_path": "./save/",
args = {"epochs": 3, "batch_size": 2, "validate": True, "use_cuda": False, "pickle_path": "./save/",
"save_best_dev": True, "model_name": "default_model_name.pkl",
"loss": Loss(None),
"optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
"vocab_size": 20,
"vocab_size": 10,
"word_emb_dim": 100,
"rnn_hidden_units": 100,
"num_classes": 3
"num_classes": 5
}
- trainer = SeqLabelTrainer()
+ trainer = SeqLabelTrainer(**args)

train_data = [
- [[1, 2, 3, 4, 5, 6], [1, 0, 1, 0, 1, 2]],
- [[2, 3, 4, 5, 1, 6], [0, 1, 0, 1, 0, 2]],
- [[1, 4, 1, 4, 1, 6], [1, 0, 1, 0, 1, 2]],
- [[1, 2, 3, 4, 5, 6], [1, 0, 1, 0, 1, 2]],
- [[2, 3, 4, 5, 1, 6], [0, 1, 0, 1, 0, 2]],
- [[1, 4, 1, 4, 1, 6], [1, 0, 1, 0, 1, 2]],
+ [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
+ [['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
+ [['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
+ [['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']],
+ [['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']],
+ [['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
]
- dev_data = train_data
+ vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
+ label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}

+ data_set = DataSet()
+ for example in train_data:
+ text, label = example[0], example[1]
+ x = TextField(text, False)
+ y = TextField(label, is_target=True)
+ ins = Instance(word_seq=x, label_seq=y)
+ data_set.append(ins)

+ data_set.index_field("word_seq", vocab)
+ data_set.index_field("label_seq", label_vocab)

model = SeqLabeling(args)
- trainer.train(network=model, train_data=train_data, dev_data=dev_data)

+ trainer.train(network=model, train_data=data_set, dev_data=data_set)
+ # If this can run, everything is OK.

+ os.system("rm -rf save")
+ print("pickle path deleted")

test/model/test_charlm.py (+0, -8)

@@ -1,8 +0,0 @@


def test_charlm():
pass


if __name__ == "__main__":
test_charlm()

test/model/test_seq_label.py (+96, -0)

@@ -0,0 +1,96 @@
import argparse
import os

from fastNLP.core.optimizer import Optimizer
from fastNLP.core.preprocess import SeqLabelPreprocess
from fastNLP.core.tester import SeqLabelTester
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.dataset_loader import POSDatasetLoader
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.saver.model_saver import ModelSaver

parser = argparse.ArgumentParser()
parser.add_argument("-s", "--save", type=str, default="./seq_label/", help="path to save pickle files")
parser.add_argument("-t", "--train", type=str, default="../data_for_tests/people.txt",
help="path to the training data")
parser.add_argument("-c", "--config", type=str, default="../data_for_tests/config", help="path to the config file")
parser.add_argument("-m", "--model_name", type=str, default="seq_label_model.pkl", help="the name of the model")
parser.add_argument("-i", "--infer", type=str, default="../data_for_tests/people_infer.txt",
help="data used for inference")

args = parser.parse_args()
pickle_path = args.save
model_name = args.model_name
config_dir = args.config
data_path = args.train
data_infer_path = args.infer


def test_training():
# Config Loader
trainer_args = ConfigSection()
model_args = ConfigSection()
ConfigLoader("config.cfg").load_config(config_dir, {
"test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args})

# Data Loader
pos_loader = POSDatasetLoader(data_path)
train_data = pos_loader.load_lines()

# Preprocessor
p = SeqLabelPreprocess()
data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5)
model_args["vocab_size"] = p.vocab_size
model_args["num_classes"] = p.num_classes

trainer = SeqLabelTrainer(
epochs=trainer_args["epochs"],
batch_size=trainer_args["batch_size"],
validate=False,
use_cuda=False,
pickle_path=pickle_path,
save_best_dev=trainer_args["save_best_dev"],
model_name=model_name,
optimizer=Optimizer("SGD", lr=0.01, momentum=0.9),
)

# Model
model = SeqLabeling(model_args)

# Start training
trainer.train(model, data_train, data_dev)

# Saver
saver = ModelSaver(os.path.join(pickle_path, model_name))
saver.save_pytorch(model)

del model, trainer, pos_loader

# Define the same model
model = SeqLabeling(model_args)

# Dump trained parameters into the model
ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name))

# Load test configuration
tester_args = ConfigSection()
ConfigLoader("config.cfg").load_config(config_dir, {"test_seq_label_tester": tester_args})

# Tester
tester = SeqLabelTester(save_output=False,
save_loss=True,
save_best_dev=False,
batch_size=4,
use_cuda=False,
pickle_path=pickle_path,
model_name="seq_label_in_test.pkl",
print_every_step=1
)

# Start testing with validation data
tester.test(model, data_dev)

loss, accuracy = tester.metrics
assert 0 < accuracy < 1
