
Merge pull request #5 from fastnlp/master

update
Tag: v0.1.0
Author: lyhuang18 (GitHub), 6 years ago
Commit: f2850766b8
GPG Key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
37 changed files with 1305 additions and 778 deletions
  1. fastNLP/core/README.md (+0, -1)
  2. fastNLP/core/action.py (+6, -115)
  3. fastNLP/core/batch.py (+126, -0)
  4. fastNLP/core/dataset.py (+111, -0)
  5. fastNLP/core/field.py (+93, -0)
  6. fastNLP/core/instance.py (+53, -0)
  7. fastNLP/core/loss.py (+2, -0)
  8. fastNLP/core/predictor.py (+49, -128)
  9. fastNLP/core/preprocess.py (+117, -162)
  10. fastNLP/core/tester.py (+81, -116)
  11. fastNLP/core/trainer.py (+56, -96)
  12. fastNLP/models/cnn_text_classification.py (+6, -2)
  13. fastNLP/models/sequence_modeling.py (+42, -14)
  14. fastNLP/modules/aggregation/self_attention.py (+40, -11)
  15. fastNLP/modules/decoder/CRF.py (+4, -3)
  16. fastNLP/modules/decoder/MLP.py (+3, -3)
  17. fastNLP/modules/encoder/char_embedding.py (+7, -4)
  18. fastNLP/modules/encoder/conv.py (+4, -2)
  19. fastNLP/modules/encoder/conv_maxpool.py (+4, -2)
  20. fastNLP/modules/encoder/linear.py (+3, -3)
  21. fastNLP/modules/encoder/lstm.py (+5, -3)
  22. fastNLP/modules/encoder/masked_rnn.py (+3, -3)
  23. fastNLP/modules/encoder/variational_rnn.py (+5, -4)
  24. fastNLP/modules/utils.py (+47, -2)
  25. reproduction/LSTM+self_attention_sentiment_analysis/config.cfg (+13, -0)
  26. reproduction/LSTM+self_attention_sentiment_analysis/main.py (+80, -0)
  27. setup.py (+4, -4)
  28. test/core/test_action.py (+0, -17)
  29. test/core/test_batch.py (+62, -0)
  30. test/core/test_predictor.py (+51, -0)
  31. test/core/test_preprocess.py (+49, -20)
  32. test/core/test_tester.py (+48, -30)
  33. test/core/test_trainer.py (+36, -15)
  34. test/model/seq_labeling.py (+7, -7)
  35. test/model/test_charlm.py (+0, -8)
  36. test/model/test_seq_label.py (+85, -0)
  37. test/model/text_classify.py (+3, -3)

fastNLP/core/README.md (+0, -1)

@@ -1 +0,0 @@


fastNLP/core/action.py (+6, -115)

@@ -4,88 +4,6 @@ import numpy as np
import torch




class Action(object):
"""Operations shared by Trainer, Tester, or Inference.

This is designed for reducing replicate codes.
- make_batch: produce a min-batch of data. @staticmethod
- pad: padding method used in sequence modeling. @staticmethod
- mode: change network mode for either train or test. (for PyTorch) @staticmethod
"""

def __init__(self):
super(Action, self).__init__()

@staticmethod
def make_batch(iterator, use_cuda, output_length=True, max_len=None):
"""Batch and Pad data.

:param iterator: an iterator, (object that implements __next__ method) which returns the next sample.
:param use_cuda: bool, whether to use GPU
:param output_length: bool, whether to output the original length of the sequence before padding. (default: True)
:param max_len: int, maximum sequence length. Longer sequences will be clipped. (default: None)
:return :

if output_length is True,
(batch_x, seq_len): tuple of two elements
batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
seq_len: list. The length of the pre-padded sequence, if output_length is True.
batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]

if output_length is False,
batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
"""
for batch in iterator:
batch_x = [sample[0] for sample in batch]
batch_y = [sample[1] for sample in batch]

batch_x = Action.pad(batch_x)
# pad batch_y only if it is a 2-level list
if len(batch_y) > 0 and isinstance(batch_y[0], list):
batch_y = Action.pad(batch_y)

# convert list to tensor
batch_x = convert_to_torch_tensor(batch_x, use_cuda)
batch_y = convert_to_torch_tensor(batch_y, use_cuda)

# trim data to max_len
if max_len is not None and batch_x.size(1) > max_len:
batch_x = batch_x[:, :max_len]

if output_length:
seq_len = [len(x) for x in batch_x]
yield (batch_x, seq_len), batch_y
else:
yield batch_x, batch_y

@staticmethod
def pad(batch, fill=0):
""" Pad a mini-batch of sequence samples to maximum length of this batch.

:param batch: list of list
:param fill: word index to pad, default 0.
:return batch: a padded mini-batch
"""
max_length = max([len(x) for x in batch])
for idx, sample in enumerate(batch):
if len(sample) < max_length:
batch[idx] = sample + ([fill] * (max_length - len(sample)))
return batch

@staticmethod
def mode(model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently.

:param model: a PyTorch model
:param is_test: bool, whether in test mode or not.
"""
if is_test:
model.eval()
else:
model.train()


def convert_to_torch_tensor(data_list, use_cuda):
"""Convert lists into (cuda) Tensors.


@@ -168,19 +86,7 @@ class BaseSampler(object):


""" """


def __init__(self, data_set):
"""

:param data_set: multi-level list, of shape [num_example, *]

"""
self.data_set_length = len(data_set)
self.data = data_set

def __len__(self):
return self.data_set_length

def __iter__(self):
def __call__(self, *args, **kwargs):
raise NotImplementedError




@@ -189,16 +95,8 @@ class SequentialSampler(BaseSampler):


""" """


def __init__(self, data_set):
"""

:param data_set: multi-level list

"""
super(SequentialSampler, self).__init__(data_set)

def __iter__(self):
return iter(self.data)
def __call__(self, data_set):
return list(range(len(data_set)))




class RandomSampler(BaseSampler):
@@ -206,17 +104,9 @@ class RandomSampler(BaseSampler):


""" """


def __init__(self, data_set):
"""

:param data_set: multi-level list
def __call__(self, data_set):
return list(np.random.permutation(len(data_set)))


"""
super(RandomSampler, self).__init__(data_set)
self.order = np.random.permutation(self.data_set_length)

def __iter__(self):
return iter((self.data[idx] for idx in self.order))




class Batchifier(object):
@@ -252,6 +142,7 @@ class BucketBatchifier(Batchifier):
"""Partition all samples into multiple buckets, each of which contains sentences of approximately the same length. """Partition all samples into multiple buckets, each of which contains sentences of approximately the same length.
In sampling, first random choose a bucket. Then sample data from it. In sampling, first random choose a bucket. Then sample data from it.
The number of buckets is decided dynamically by the variance of sentence lengths. The number of buckets is decided dynamically by the variance of sentence lengths.
TODO: merge it into Batch
""" """


def __init__(self, data_set, batch_size, num_buckets, drop_last=True, sampler=None):
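
Under the convention introduced here, a sampler no longer wraps the data set itself; it is a callable that maps a data set to an ordering of indices, which Batch (below) consumes. A minimal sketch of the assumed usage with this commit's fastNLP.core.action:

    from fastNLP.core.action import SequentialSampler, RandomSampler

    dataset = [([1, 2, 3], 0), ([4, 5], 1), ([6], 0)]   # any sized container works
    print(SequentialSampler()(dataset))   # [0, 1, 2]
    print(RandomSampler()(dataset))       # a random permutation, e.g. [2, 0, 1]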


fastNLP/core/batch.py (+126, -0)

@@ -0,0 +1,126 @@
from collections import defaultdict

import torch

from fastNLP.core.dataset import DataSet
from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance


class Batch(object):
"""Batch is an iterable object which iterates over mini-batches.

::
for batch_x, batch_y in Batch(data_set):

"""

def __init__(self, dataset, batch_size, sampler, use_cuda):
self.dataset = dataset
self.batch_size = batch_size
self.sampler = sampler
self.use_cuda = use_cuda
self.idx_list = None
self.curidx = 0

def __iter__(self):
self.idx_list = self.sampler(self.dataset)
self.curidx = 0
self.lengths = self.dataset.get_length()
return self

def __next__(self):
"""

:return batch_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length])
batch_x also contains an item (str: list of int) about origin lengths,
which means ("field_name_origin_len": origin lengths).
E.g.
::
{'text': tensor([[ 0, 1, 2, 3, 0, 0, 0], [ 4, 5, 2, 6, 7, 8, 9]]), 'text_origin_len': [4, 7]}

batch_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length])
All tensors in both batch_x and batch_y will be cuda tensors if use_cuda is True.
The names of fields are defined in preprocessor's convert_to_dataset method.

"""
if self.curidx >= len(self.idx_list):
raise StopIteration
else:
endidx = min(self.curidx + self.batch_size, len(self.idx_list))
padding_length = {field_name: max(field_length[self.curidx: endidx])
for field_name, field_length in self.lengths.items()}
origin_lengths = {field_name: field_length[self.curidx: endidx]
for field_name, field_length in self.lengths.items()}

batch_x, batch_y = defaultdict(list), defaultdict(list)
for idx in range(self.curidx, endidx):
x, y = self.dataset.to_tensor(idx, padding_length)
for name, tensor in x.items():
batch_x[name].append(tensor)
for name, tensor in y.items():
batch_y[name].append(tensor)

batch_origin_length = {}
# combine instances into a batch
for batch in (batch_x, batch_y):
for name, tensor_list in batch.items():
if self.use_cuda:
batch[name] = torch.stack(tensor_list, dim=0).cuda()
else:
batch[name] = torch.stack(tensor_list, dim=0)

# add origin lengths in batch_x
for name, tensor in batch_x.items():
if self.use_cuda:
batch_origin_length[name + "_origin_len"] = torch.LongTensor(origin_lengths[name]).cuda()
else:
batch_origin_length[name + "_origin_len"] = torch.LongTensor(origin_lengths[name])
batch_x.update(batch_origin_length)

self.curidx = endidx
return batch_x, batch_y


if __name__ == "__main__":
"""simple running example
"""
texts = ["i am a cat",
"this is a test of new batch",
"haha"
]
labels = [0, 1, 0]

# prepare vocabulary
vocab = {}
for text in texts:
for tokens in text.split():
if tokens not in vocab:
vocab[tokens] = len(vocab)
print("vocabulary: ", vocab)

# prepare input dataset
data = DataSet()
for text, label in zip(texts, labels):
x = TextField(text.split(), False)
y = LabelField(label, is_target=True)
ins = Instance(text=x, label=y)
data.append(ins)

# use vocabulary to index data
data.index_field("text", vocab)


# define naive sampler for batch class
class SeqSampler:
def __call__(self, dataset):
return list(range(len(dataset)))


# use batch to iterate dataset
data_iterator = Batch(data, 2, SeqSampler(), False)
for epoch in range(1):
for batch_x, batch_y in data_iterator:
print(batch_x)
print(batch_y)
# do stuff

fastNLP/core/dataset.py (+111, -0)

@@ -0,0 +1,111 @@
from collections import defaultdict

from fastNLP.core.field import TextField
from fastNLP.core.instance import Instance


def create_dataset_from_lists(str_lists: list, word_vocab: dict, has_target: bool = False, label_vocab: dict = None):
if has_target is True:
if label_vocab is None:
raise RuntimeError("Must provide label vocabulary to transform labels.")
return create_labeled_dataset_from_lists(str_lists, word_vocab, label_vocab)
else:
return create_unlabeled_dataset_from_lists(str_lists, word_vocab)


def create_labeled_dataset_from_lists(str_lists, word_vocab, label_vocab):
"""Create an DataSet instance that contains labels.

:param str_lists: list of list of strings, [num_examples, 2, *].
::
[
[[word_11, word_12, ...], [label_11, label_12, ...]],
...
]

:param word_vocab: dict of (str: int), which means (word: index).
:param label_vocab: dict of (str: int), which means (word: index).
:return data_set: a DataSet instance.

"""
data_set = DataSet()
for example in str_lists:
word_seq, label_seq = example[0], example[1]
x = TextField(word_seq, is_target=False)
y = TextField(label_seq, is_target=True)
data_set.append(Instance(word_seq=x, label_seq=y))
data_set.index_field("word_seq", word_vocab)
data_set.index_field("label_seq", label_vocab)
return data_set


def create_unlabeled_dataset_from_lists(str_lists, word_vocab):
"""Create an DataSet instance that contains no labels.

:param str_lists: list of list of strings, [num_examples, *].
::
[
[word_11, word_12, ...],
...
]

:param word_vocab: dict of (str: int), which means (word: index).
:return data_set: a DataSet instance.

"""
data_set = DataSet()
for word_seq in str_lists:
x = TextField(word_seq, is_target=False)
data_set.append(Instance(word_seq=x))
data_set.index_field("word_seq", word_vocab)
return data_set


class DataSet(list):
"""A DataSet object is a list of Instance objects.

"""
def __init__(self, name="", instances=None):
"""

:param name: str, the name of the dataset. (default: "")
:param instances: list of Instance objects. (default: None)

"""
list.__init__([])
self.name = name
if instances is not None:
self.extend(instances)

def index_all(self, vocab):
for ins in self:
ins.index_all(vocab)

def index_field(self, field_name, vocab):
for ins in self:
ins.index_field(field_name, vocab)

def to_tensor(self, idx: int, padding_length: dict):
"""Convert an instance in a dataset to tensor.

:param idx: int, the index of the instance in the dataset.
:param padding_length: dict of (str: int), which means (field name: padding length of this field)
:return tensor_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ])
tensor_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ])

"""
ins = self[idx]
return ins.to_tensor(padding_length)

def get_length(self):
"""Fetch lengths of all fields in all instances in a dataset.

:return lengths: dict of (str: list). The str is the field name.
The list contains lengths of this field in all instances.

"""
lengths = defaultdict(list)
for ins in self:
for field_name, field_length in ins.get_length().items():
lengths[field_name].append(field_length)
return lengths
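
A brief usage sketch of the helpers above. The vocabularies are made-up placeholders; every token and label must already be present in them, because TextField.index does a direct dictionary lookup:

    from fastNLP.core.dataset import create_dataset_from_lists

    word_vocab = {"<pad>": 0, "<unk>": 1, "I": 2, "love": 3, "NLP": 4}
    label_vocab = {"<pad>": 0, "<unk>": 1, "O": 2, "B": 3}
    str_lists = [[["I", "love", "NLP"], ["O", "O", "B"]]]

    data_set = create_dataset_from_lists(str_lists, word_vocab,
                                         has_target=True, label_vocab=label_vocab)
    print(len(data_set))          # 1
    print(data_set.get_length())  # field name -> lengths, e.g. {'word_seq': [3], 'label_seq': [3]}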

fastNLP/core/field.py (+93, -0)

@@ -0,0 +1,93 @@
import torch


class Field(object):
"""A field defines a data type.

"""

def __init__(self, is_target: bool):
self.is_target = is_target

def index(self, vocab):
raise NotImplementedError

def get_length(self):
raise NotImplementedError

def to_tensor(self, padding_length):
raise NotImplementedError


class TextField(Field):
def __init__(self, text, is_target):
"""
:param text: list of strings
:param is_target: bool
"""
super(TextField, self).__init__(is_target)
self.text = text
self._index = None

def index(self, vocab):
if self._index is None:
self._index = [vocab[c] for c in self.text]
else:
raise RuntimeError("Replicate indexing of this field.")
return self._index

def get_length(self):
"""Fetch the length of the text field.

:return length: int, the length of the text.

"""
return len(self.text)

def to_tensor(self, padding_length: int):
"""Convert text field to tensor.

:param padding_length: int
:return tensor: torch.LongTensor, of shape [padding_length, ]
"""
pads = []
if self._index is None:
raise RuntimeError("Indexing not done before to_tensor in TextField.")
if padding_length > self.get_length():
pads = [0] * (padding_length - self.get_length())
return torch.LongTensor(self._index + pads)


class LabelField(Field):
def __init__(self, label, is_target=True):
super(LabelField, self).__init__(is_target)
self.label = label
self._index = None

def get_length(self):
"""Fetch the length of the label field.

:return length: int, the length of the label, always 1.
"""
return 1

def index(self, vocab):
if self._index is None:
self._index = vocab[self.label]
return self._index

def to_tensor(self, padding_length):
if self._index is None:
if isinstance(self.label, int):
return torch.LongTensor([self.label])
elif isinstance(self.label, str):
raise RuntimeError("Field {} not indexed. Call index method.".format(self.label))
else:
raise RuntimeError(
"Not support type for LabelField. Expect str or int, got {}.".format(type(self.label)))
else:
return torch.LongTensor([self._index])


if __name__ == "__main__":
tf = TextField("test the code".split(), is_target=False)
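
A short sketch of the Field API above, assuming index() is called before to_tensor() and that padding uses index 0:

    from fastNLP.core.field import TextField, LabelField

    vocab = {"test": 1, "the": 2, "code": 3}
    tf = TextField("test the code".split(), is_target=False)
    tf.index(vocab)
    print(tf.to_tensor(padding_length=5))   # tensor([1, 2, 3, 0, 0])

    lf = LabelField("positive", is_target=True)
    print(lf.index({"positive": 0, "negative": 1}))   # 0
    print(lf.to_tensor(padding_length=1))             # tensor([0])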

fastNLP/core/instance.py (+53, -0)

@@ -0,0 +1,53 @@
class Instance(object):
"""An instance which consists of Fields is an example in the DataSet.

"""

def __init__(self, **fields):
self.fields = fields
self.has_index = False
self.indexes = {}

def add_field(self, field_name, field):
self.fields[field_name] = field

def get_length(self):
"""Fetch the length of all fields in the instance.

:return length: dict of (str: int), which means (field name: field length).

"""
length = {name: field.get_length() for name, field in self.fields.items()}
return length

def index_field(self, field_name, vocab):
"""use `vocab` to index certain field
"""
self.indexes[field_name] = self.fields[field_name].index(vocab)

def index_all(self, vocab):
"""use `vocab` to index all fields
"""
if self.has_index:
print("error")
return self.indexes
indexes = {name: field.index(vocab) for name, field in self.fields.items()}
self.indexes = indexes
return indexes

def to_tensor(self, padding_length: dict):
"""Convert instance to tensor.

:param padding_length: dict of (str: int), which means (field name: padding_length of this field)
:return tensor_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ])
tensor_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ])
If is_target is False for all fields, tensor_y would be an empty dict.
"""
tensor_x = {}
tensor_y = {}
for name, field in self.fields.items():
if field.is_target:
tensor_y[name] = field.to_tensor(padding_length[name])
else:
tensor_x[name] = field.to_tensor(padding_length[name])
return tensor_x, tensor_y
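
A minimal sketch of how an Instance ties Fields together (the vocabulary and field names here are placeholders):

    from fastNLP.core.field import TextField, LabelField
    from fastNLP.core.instance import Instance

    vocab = {"a": 0, "small": 1, "example": 2}
    ins = Instance(text=TextField(["a", "small", "example"], is_target=False),
                   label=LabelField(1, is_target=True))
    ins.index_field("text", vocab)

    print(ins.get_length())                    # {'text': 3, 'label': 1}
    x, y = ins.to_tensor({"text": 4, "label": 1})
    print(x["text"])                           # tensor([0, 1, 2, 0])
    print(y["label"])                          # tensor([1])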

fastNLP/core/loss.py (+2, -0)

@@ -37,5 +37,7 @@ class Loss(object):
""" """
if loss_name == "cross_entropy": if loss_name == "cross_entropy":
return torch.nn.CrossEntropyLoss() return torch.nn.CrossEntropyLoss()
elif loss_name == 'nll':
return torch.nn.NLLLoss()
else:
raise NotImplementedError

fastNLP/core/predictor.py (+49, -128)

@@ -1,53 +1,10 @@
import numpy as np
import torch


from fastNLP.core.action import Batchifier, SequentialSampler
from fastNLP.core.action import convert_to_torch_tensor
from fastNLP.core.preprocess import load_pickle, DEFAULT_UNKNOWN_LABEL
from fastNLP.modules import utils


def make_batch(iterator, use_cuda, output_length=False, max_len=None, min_len=None):
"""Batch and Pad data, only for Inference.

:param iterator: An iterable object that returns a list of indices representing a mini-batch of samples.
:param use_cuda: bool, whether to use GPU
:param output_length: bool, whether to output the original length of the sequence before padding. (default: False)
:param max_len: int, maximum sequence length. Longer sequences will be clipped. (default: None)
:param min_len: int, minimum sequence length. Shorter sequences will be padded. (default: None)
:return:
"""
for batch_x in iterator:
batch_x = pad(batch_x)
# convert list to tensor
batch_x = convert_to_torch_tensor(batch_x, use_cuda)

# trim data to max_len
if max_len is not None and batch_x.size(1) > max_len:
batch_x = batch_x[:, :max_len]
if min_len is not None and batch_x.size(1) < min_len:
pad_tensor = torch.zeros(batch_x.size(0), min_len - batch_x.size(1)).to(batch_x)
batch_x = torch.cat((batch_x, pad_tensor), 1)

if output_length:
seq_len = [len(x) for x in batch_x]
yield tuple([batch_x, seq_len])
else:
yield batch_x


def pad(batch, fill=0):
""" Pad a mini-batch of sequence samples to maximum length of this batch.

:param batch: list of list
:param fill: word index to pad, default 0.
:return batch: a padded mini-batch
"""
max_length = max([len(x) for x in batch])
for idx, sample in enumerate(batch):
if len(sample) < max_length:
batch[idx] = sample + ([fill] * (max_length - len(sample)))
return batch
from fastNLP.core.action import SequentialSampler
from fastNLP.core.batch import Batch
from fastNLP.core.dataset import create_dataset_from_lists
from fastNLP.core.preprocess import load_pickle




class Predictor(object):
@@ -59,11 +16,17 @@ class Predictor(object):
Currently, Predictor does not support GPU.
"""


def __init__(self, pickle_path):
def __init__(self, pickle_path, task):
"""

:param pickle_path: str, the path to the pickle files.
:param task: str, specify which task the predictor will perform. One of ("seq_label", "text_classify").

"""
self.batch_size = 1
self.batch_output = []
self.iterator = None
self.pickle_path = pickle_path
self._task = task  # one of ("seq_label", "text_classify")
self.index2label = load_pickle(self.pickle_path, "id2class.pkl")
self.word2index = load_pickle(self.pickle_path, "word2id.pkl")


@@ -71,19 +34,19 @@ class Predictor(object):
"""Perform inference using the trained model. """Perform inference using the trained model.


:param network: a PyTorch model (cpu) :param network: a PyTorch model (cpu)
:param data: list of list of strings
:param data: list of list of strings, [num_examples, seq_len]
:return: list of list of strings, [num_examples, tag_seq_length] :return: list of list of strings, [num_examples, tag_seq_length]
""" """
# transform strings into indices
# transform strings into DataSet object
data = self.prepare_input(data) data = self.prepare_input(data)


# turn on the testing mode; clean up the history # turn on the testing mode; clean up the history
self.mode(network, test=True) self.mode(network, test=True)
self.batch_output.clear() self.batch_output.clear()


data_iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False))
data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), use_cuda=False)


for batch_x in self.make_batch(data_iterator, use_cuda=False):
for batch_x, _ in data_iterator:
with torch.no_grad():
prediction = self.data_forward(network, batch_x)


@@ -99,103 +62,61 @@ class Predictor(object):


def data_forward(self, network, x):
"""Forward through network."""
raise NotImplementedError
def make_batch(self, iterator, use_cuda):
raise NotImplementedError
y = network(**x)
if self._task == "seq_label":
y = network.prediction(y)
return y


def prepare_input(self, data):
"""Transform two-level list of strings into that of index.
"""Transform two-level list of strings into a DataSet object.
In the training pipeline, this is done by Preprocessor. But in inference time, we do not call Preprocessor.


:param data:
:param data: list of list of strings.
::
[
[word_11, word_12, ...],
[word_21, word_22, ...],
...
]
:return data_index: list of list of int.

:return data_set: a DataSet instance.
""" """
assert isinstance(data, list) assert isinstance(data, list)
data_index = []
default_unknown_index = self.word2index[DEFAULT_UNKNOWN_LABEL]
for example in data:
data_index.append([self.word2index.get(w, default_unknown_index) for w in example])
return data_index
return create_dataset_from_lists(data, self.word2index, has_target=False)


def prepare_output(self, data):
"""Transform list of batch outputs into strings."""
raise NotImplementedError


class SeqLabelInfer(Predictor):
"""
Inference on sequence labeling models.
"""

def __init__(self, pickle_path):
super(SeqLabelInfer, self).__init__(pickle_path)
if self._task == "seq_label":
return self._seq_label_prepare_output(data)
elif self._task == "text_classify":
return self._text_classify_prepare_output(data)
else:
raise NotImplementedError("Unknown task type {}".format(self._task))


def data_forward(self, network, inputs):
"""
This is only for sequence labeling with CRF decoder.
:param network: a PyTorch model
:param inputs: tuple of (x, seq_len)
x: Tensor of shape [batch_size, max_len], where max_len is the maximum length of the mini-batch
after padding.
seq_len: list of int, the lengths of sequences before padding.
:return prediction: Tensor of shape [batch_size, max_len]
"""
if not isinstance(inputs[1], list) and isinstance(inputs[0], list):
raise RuntimeError("output_length must be true for sequence modeling.")
# unpack the returned value from make_batch
x, seq_len = inputs[0], inputs[1]
batch_size, max_len = x.size(0), x.size(1)
mask = utils.seq_mask(seq_len, max_len)
mask = mask.byte().view(batch_size, max_len)
y = network(x)
prediction = network.prediction(y, mask)
return torch.Tensor(prediction)

def make_batch(self, iterator, use_cuda):
return make_batch(iterator, use_cuda, output_length=True)

def prepare_output(self, batch_outputs):
"""Transform list of batch outputs into strings.

:param batch_outputs: list of 2-D Tensor, shape [num_batch, batch-size, tag_seq_length].
:return results: 2-D list of strings, shape [num_examples, tag_seq_length]
"""
def _seq_label_prepare_output(self, batch_outputs):
results = []
for batch in batch_outputs:
for example in np.array(batch):
results.append([self.index2label[int(x)] for x in example])
return results



class ClassificationInfer(Predictor):
"""
Inference on Classification models.
"""

def __init__(self, pickle_path):
super(ClassificationInfer, self).__init__(pickle_path)

def data_forward(self, network, x):
"""Forward through network."""
logits = network(x)
return logits

def make_batch(self, iterator, use_cuda):
return make_batch(iterator, use_cuda, output_length=False, min_len=5)

def prepare_output(self, batch_outputs):
"""
Transform list of batch outputs into strings.
:param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, num_classes].
:return results: list of strings
"""
def _text_classify_prepare_output(self, batch_outputs):
results = []
for batch_out in batch_outputs:
idx = np.argmax(batch_out.detach().numpy(), axis=-1)
results.extend([self.index2label[i] for i in idx])
return results


class SeqLabelInfer(Predictor):
def __init__(self, pickle_path):
print(
"[FastNLP Warning] SeqLabelInfer will be deprecated. Please use Predictor with argument 'task'='seq_label'.")
super(SeqLabelInfer, self).__init__(pickle_path, "seq_label")


class ClassificationInfer(Predictor):
def __init__(self, pickle_path):
print(
"[FastNLP Warning] ClassificationInfer will be deprecated. Please use Predictor with argument 'task'='text_classify'.")
super(ClassificationInfer, self).__init__(pickle_path, "text_classify")
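
A hedged usage sketch of the task-based Predictor. The pickle directory and the trained model are placeholders, the method name predict is assumed from the docstring above, and every input token must exist in the saved word2id vocabulary:

    from fastNLP.core.predictor import Predictor

    predictor = Predictor(pickle_path="./save/", task="seq_label")
    data = [["FastNLP", "is", "a", "toolkit"]]            # list of token lists
    tags = predictor.predict(network=trained_model, data=data)
    # tags is a list of string tag sequences produced by prepare_output()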

fastNLP/core/preprocess.py (+117, -162)

@@ -3,6 +3,10 @@ import os


import numpy as np import numpy as np


from fastNLP.core.dataset import DataSet
from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance

DEFAULT_PADDING_LABEL = '<pad>'  # dict index = 0
DEFAULT_UNKNOWN_LABEL = '<unk>'  # dict index = 1
DEFAULT_RESERVED_LABEL = ['<reserved-2>',
@@ -84,7 +88,7 @@ class BasePreprocess(object):
return len(self.label2index)


def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=0, cross_val=False, n_fold=10):
"""Main preprocessing pipeline.
"""Main pre-processing pipeline.


:param train_dev_data: three-level list, with either single label or multiple labels in a sample.
:param test_data: three-level list, with either single label or multiple labels in a sample. (optional)
@@ -92,7 +96,9 @@ class BasePreprocess(object):
:param train_dev_split: float, between [0, 1]. The ratio of training data used as validation set.
:param cross_val: bool, whether to do cross validation.
:param n_fold: int, the number of folds of cross validation. Only useful when cross_val is True.
:return results: a tuple of datasets after preprocessing.
:return results: multiple datasets after pre-processing. If test_data is provided, return one more dataset.
If train_dev_split > 0, return one more dataset - the dev set. If cross_val is True, each dataset
is a list of DataSet objects; Otherwise, each dataset is a DataSet object.
"""

if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"):
@@ -111,68 +117,87 @@ class BasePreprocess(object):
index2label = self.build_reverse_dict(self.label2index)
save_pickle(index2label, pickle_path, "id2class.pkl")


data_train = []
data_dev = []
train_set = []
dev_set = []
if not cross_val:
if not pickle_exist(pickle_path, "data_train.pkl"):
data_train.extend(self.to_index(train_dev_data))
if train_dev_split > 0 and not pickle_exist(pickle_path, "data_dev.pkl"):
split = int(len(data_train) * train_dev_split)
data_dev = data_train[: split]
data_train = data_train[split:]
save_pickle(data_dev, pickle_path, "data_dev.pkl")
split = int(len(train_dev_data) * train_dev_split)
data_dev = train_dev_data[: split]
data_train = train_dev_data[split:]
train_set = self.convert_to_dataset(data_train, self.word2index, self.label2index)
dev_set = self.convert_to_dataset(data_dev, self.word2index, self.label2index)

save_pickle(dev_set, pickle_path, "data_dev.pkl")
print("{} of the training data is split for validation. ".format(train_dev_split)) print("{} of the training data is split for validation. ".format(train_dev_split))
save_pickle(data_train, pickle_path, "data_train.pkl")
else:
train_set = self.convert_to_dataset(train_dev_data, self.word2index, self.label2index)
save_pickle(train_set, pickle_path, "data_train.pkl")
else:
data_train = load_pickle(pickle_path, "data_train.pkl")
train_set = load_pickle(pickle_path, "data_train.pkl")
if pickle_exist(pickle_path, "data_dev.pkl"):
data_dev = load_pickle(pickle_path, "data_dev.pkl")
dev_set = load_pickle(pickle_path, "data_dev.pkl")
else:
# cross_val is True
if not pickle_exist(pickle_path, "data_train_0.pkl"):
# cross validation
data_idx = self.to_index(train_dev_data)
data_cv = self.cv_split(data_idx, n_fold)
data_cv = self.cv_split(train_dev_data, n_fold)
for i, (data_train_cv, data_dev_cv) in enumerate(data_cv):
data_train_cv = self.convert_to_dataset(data_train_cv, self.word2index, self.label2index)
data_dev_cv = self.convert_to_dataset(data_dev_cv, self.word2index, self.label2index)
save_pickle(
data_train_cv, pickle_path,
"data_train_{}.pkl".format(i))
save_pickle(
data_dev_cv, pickle_path,
"data_dev_{}.pkl".format(i))
data_train.append(data_train_cv)
data_dev.append(data_dev_cv)
train_set.append(data_train_cv)
dev_set.append(data_dev_cv)
print("{}-fold cross validation.".format(n_fold)) print("{}-fold cross validation.".format(n_fold))
else: else:
for i in range(n_fold): for i in range(n_fold):
data_train_cv = load_pickle(pickle_path, "data_train_{}.pkl".format(i)) data_train_cv = load_pickle(pickle_path, "data_train_{}.pkl".format(i))
data_dev_cv = load_pickle(pickle_path, "data_dev_{}.pkl".format(i)) data_dev_cv = load_pickle(pickle_path, "data_dev_{}.pkl".format(i))
data_train.append(data_train_cv)
data_dev.append(data_dev_cv)
train_set.append(data_train_cv)
dev_set.append(data_dev_cv)


# prepare test data if provided
data_test = []
test_set = []
if test_data is not None:
if not pickle_exist(pickle_path, "data_test.pkl"):
data_test = self.to_index(test_data)
save_pickle(data_test, pickle_path, "data_test.pkl")
test_set = self.convert_to_dataset(test_data, self.word2index, self.label2index)
save_pickle(test_set, pickle_path, "data_test.pkl")


# return preprocessed results
results = [data_train]
results = [train_set]
if cross_val or train_dev_split > 0:
results.append(data_dev)
results.append(dev_set)
if test_data:
results.append(data_test)
results.append(test_set)
if len(results) == 1:
return results[0]
else:
return tuple(results)


def build_dict(self, data):
raise NotImplementedError
label2index = DEFAULT_WORD_TO_INDEX.copy()
word2index = DEFAULT_WORD_TO_INDEX.copy()
for example in data:
for word in example[0]:
if word not in word2index:
word2index[word] = len(word2index)
label = example[1]
if isinstance(label, str):
# label is a string
if label not in label2index:
label2index[label] = len(label2index)
elif isinstance(label, list):
# label is a list of strings
for single_label in label:
if single_label not in label2index:
label2index[single_label] = len(label2index)
return word2index, label2index


def to_index(self, data):
raise NotImplementedError


def build_reverse_dict(self, word_dict):
id2word = {word_dict[w]: w for w in word_dict}
@@ -186,11 +211,23 @@ class BasePreprocess(object):
return data_train, data_dev


def cv_split(self, data, n_fold):
"""Split data for cross validation."""
"""Split data for cross validation.

:param data: list of string
:param n_fold: int
:return data_cv:

::
[
(data_train, data_dev), # 1st fold
(data_train, data_dev), # 2nd fold
...
]

"""
data_copy = data.copy()
np.random.shuffle(data_copy)
fold_size = round(len(data_copy) / n_fold)

data_cv = []
for i in range(n_fold - 1):
start = i * fold_size
@@ -202,154 +239,72 @@ class BasePreprocess(object):
data_dev = data_copy[start:]
data_train = data_copy[:start]
data_cv.append((data_train, data_dev))

return data_cv


def convert_to_dataset(self, data, vocab, label_vocab):
"""Convert list of indices into a DataSet object.


class SeqLabelPreprocess(BasePreprocess):
"""Preprocess pipeline, including building mapping from words to index, from index to words,
from labels/classes to index, from index to labels/classes.
data of three-level list which have multiple labels in each sample.
::

[
[ [word_11, word_12, ...], [label_1, label_1, ...] ],
[ [word_21, word_22, ...], [label_2, label_1, ...] ],
...
]

"""

def __init__(self):
super(SeqLabelPreprocess, self).__init__()

def build_dict(self, data):
"""Add new words with indices into self.word_dict, new labels with indices into self.label_dict.

:param data: three-level list
::

[
[ [word_11, word_12, ...], [label_1, label_1, ...] ],
[ [word_21, word_22, ...], [label_2, label_1, ...] ],
...
]

:return word2index: dict of {str, int}
label2index: dict of {str, int}
:param data: list. Entries are strings.
:param vocab: a dict, mapping string (token) to index (int).
:param label_vocab: a dict, mapping string (label) to index (int).
:return data_set: a DataSet object
""" """
# In seq labeling, both word seq and label seq need to be padded to the same length in a mini-batch.
label2index = DEFAULT_WORD_TO_INDEX.copy()
word2index = DEFAULT_WORD_TO_INDEX.copy()
for example in data:
for word, label in zip(example[0], example[1]):
if word not in word2index:
word2index[word] = len(word2index)
if label not in label2index:
label2index[label] = len(label2index)
return word2index, label2index

def to_index(self, data):
"""Convert word strings and label strings into indices.

:param data: three-level list
::
use_word_seq = False
use_label_seq = False
use_label_str = False


[
[ [word_11, word_12, ...], [label_1, label_1, ...] ],
[ [word_21, word_22, ...], [label_2, label_1, ...] ],
...
]

:return data_index: the same shape as data, but each string is replaced by its corresponding index
"""
data_index = []
# construct a DataSet object and fill it with Instances
data_set = DataSet()
for example in data:
word_list = []
label_list = []
for word, label in zip(example[0], example[1]):
word_list.append(self.word2index.get(word, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL]))
label_list.append(self.label2index.get(label, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL]))
data_index.append([word_list, label_list])
return data_index
words, label = example[0], example[1]
instance = Instance()


if isinstance(words, list):
x = TextField(words, is_target=False)
instance.add_field("word_seq", x)
use_word_seq = True
else:
raise NotImplementedError("words is a {}".format(type(words)))

if isinstance(label, list):
y = TextField(label, is_target=True)
instance.add_field("label_seq", y)
use_label_seq = True
elif isinstance(label, str):
y = LabelField(label, is_target=True)
instance.add_field("label", y)
use_label_str = True
else:
raise NotImplementedError("label is a {}".format(type(label)))
data_set.append(instance)


class ClassPreprocess(BasePreprocess):
""" Preprocess pipeline for classification datasets.
Preprocess pipeline, including building mapping from words to index, from index to words,
from labels/classes to index, from index to labels/classes.
design for data of three-level list which has a single label in each sample.
::
# convert strings to indices
if use_word_seq:
data_set.index_field("word_seq", vocab)
if use_label_seq:
data_set.index_field("label_seq", label_vocab)
if use_label_str:
data_set.index_field("label", label_vocab)


[
[ [word_11, word_12, ...], label_1 ],
[ [word_21, word_22, ...], label_2 ],
...
]
return data_set


"""


class SeqLabelPreprocess(BasePreprocess):
def __init__(self):
super(ClassPreprocess, self).__init__()

def build_dict(self, data):
"""Build vocabulary."""


# build vocabulary from scratch if nothing exists
word2index = DEFAULT_WORD_TO_INDEX.copy()
label2index = DEFAULT_WORD_TO_INDEX.copy()

# collect every word and label
for sent, label in data:
if len(sent) <= 1:
continue

if label not in label2index:
label2index[label] = len(label2index)

for word in sent:
if word not in word2index:
word2index[word] = len(word2index)
return word2index, label2index

def to_index(self, data):
"""Convert word strings and label strings into indices.

:param data: three-level list
::

[
[ [word_11, word_12, ...], label_1 ],
[ [word_21, word_22, ...], label_2 ],
...
]
super(SeqLabelPreprocess, self).__init__()


:return data_index: the same shape as data, but each string is replaced by its corresponding index
"""
data_index = []
for example in data:
word_list = []
# example[0] is the word list, example[1] is the single label
for word in example[0]:
word_list.append(self.word2index.get(word, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL]))
label_index = self.label2index.get(example[1], DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL])
data_index.append([word_list, label_index])
return data_index




def infer_preprocess(pickle_path, data):
"""Preprocess over inference data. Transform three-level list of strings into that of index.
::
class ClassPreprocess(BasePreprocess):
def __init__(self):
super(ClassPreprocess, self).__init__()


[
[word_11, word_12, ...],
[word_21, word_22, ...],
...
]


"""
word2index = load_pickle(pickle_path, "word2id.pkl")
data_index = []
for example in data:
data_index.append([word2index.get(w, DEFAULT_UNKNOWN_LABEL) for w in example])
return data_index
if __name__ == "__main__":
p = BasePreprocess()
train_dev_data = [[["I", "am", "a", "good", "student", "."], "0"],
[["You", "are", "pretty", "."], "1"]
]
training_set = p.run(train_dev_data)
print(training_set)
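
A sketch of the documented return behaviour of run(), assuming a clean working directory with no previously saved pickles: with a dev split and test data it returns a (train_set, dev_set, test_set) tuple of DataSet objects. The sentences and labels below are placeholders:

    p = BasePreprocess()
    train_dev = [[["I", "am", "happy", "."], "1"],
                 [["You", "are", "sad", "."], "0"],
                 [["We", "are", "fine", "."], "1"],
                 [["They", "are", "ok", "."], "0"]]
    test = [[["You", "are", "happy", "."], "0"]]   # test tokens must appear in the training vocabulary
    train_set, dev_set, test_set = p.run(train_dev, test_data=test, train_dev_split=0.25)
    print(len(train_set), len(dev_set), len(test_set))   # 3 1 1 (with no pickles cached from a previous run)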

fastNLP/core/tester.py (+81, -116)

@@ -1,9 +1,8 @@
import numpy as np
import torch


from fastNLP.core.action import Action
from fastNLP.core.action import RandomSampler, Batchifier
from fastNLP.modules import utils
from fastNLP.core.action import RandomSampler
from fastNLP.core.batch import Batch
from fastNLP.saver.logger import create_logger


logger = create_logger(__name__, "./train_test.log")
@@ -35,16 +34,16 @@ class BaseTester(object):
""" """
"required_args" is the collection of arguments that users must pass to Trainer explicitly. "required_args" is the collection of arguments that users must pass to Trainer explicitly.
This is used to warn users of essential settings in the training. This is used to warn users of essential settings in the training.
Obviously, "required_args" is the subset of "default_args".
The value in "default_args" to the keys in "required_args" is simply for type check.
Specially, "required_args" does not have default value, so they have nothing to do with "default_args".
""" """
# add required arguments here
required_args = {}
required_args = {"task" # one of ("seq_label", "text_classify")
}


for req_key in required_args:
if req_key not in kwargs:
logger.error("Tester lacks argument {}".format(req_key))
raise ValueError("Tester lacks argument {}".format(req_key))
self._task = kwargs["task"]


for key in default_args:
if key in kwargs:
@@ -79,14 +78,14 @@ class BaseTester(object):
self._model = network

# turn on the testing mode; clean up the history
self.mode(network, test=True)
self.mode(network, is_test=True)
self.eval_history.clear()
self.batch_output.clear()

iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=False))
data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda)
step = 0

for batch_x, batch_y in self.make_batch(iterator):
for batch_x, batch_y in data_iterator:
with torch.no_grad():
prediction = self.data_forward(network, batch_x)
eval_results = self.evaluate(prediction, batch_y)
@@ -102,17 +101,22 @@ class BaseTester(object):
print(self.make_eval_output(prediction, eval_results))
step += 1


def mode(self, model, test):
def mode(self, model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently.

:param model: a PyTorch model
:param test: bool, whether in test mode.
:param is_test: bool, whether in test mode or not.

"""
Action.mode(model, test)
if is_test:
model.eval()
else:
model.train()

def data_forward(self, network, x):
"""A forward pass of the model. """
raise NotImplementedError
y = network(**x)
return y

def evaluate(self, predict, truth):
"""Compute evaluation metrics.
@@ -121,7 +125,38 @@ class BaseTester(object):
:param truth: Tensor
:return eval_results: can be anything. It will be stored in self.eval_history
"""
raise NotImplementedError
if "label_seq" in truth:
truth = truth["label_seq"]
elif "label" in truth:
truth = truth["label"]
else:
raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys()))

if self._task == "seq_label":
return self._seq_label_evaluate(predict, truth)
elif self._task == "text_classify":
return self._text_classify_evaluate(predict, truth)
else:
raise NotImplementedError("Unknown task type {}.".format(self._task))

def _seq_label_evaluate(self, predict, truth):
batch_size, max_len = predict.size(0), predict.size(1)
loss = self._model.loss(predict, truth) / batch_size
prediction = self._model.prediction(predict)
# pad prediction to equal length
for pred in prediction:
if len(pred) < max_len:
pred += [0] * (max_len - len(pred))
results = torch.Tensor(prediction).view(-1, )

# make sure "results" is in the same device as "truth"
results = results.to(truth)
accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0]
return [float(loss), float(accuracy)]

def _text_classify_evaluate(self, y_logit, y_true):
y_prob = torch.nn.functional.softmax(y_logit, dim=-1)
return [y_prob, y_true]


@property
def metrics(self):
@@ -131,7 +166,27 @@ class BaseTester(object):

:return : variable number of outputs
"""
raise NotImplementedError
if self._task == "seq_label":
return self._seq_label_metrics
elif self._task == "text_classify":
return self._text_classify_metrics
else:
raise NotImplementedError("Unknown task type {}.".format(self._task))

@property
def _seq_label_metrics(self):
batch_loss = np.mean([x[0] for x in self.eval_history])
batch_accuracy = np.mean([x[1] for x in self.eval_history])
return batch_loss, batch_accuracy

@property
def _text_classify_metrics(self):
y_prob, y_true = zip(*self.eval_history)
y_prob = torch.cat(y_prob, dim=0)
y_pred = torch.argmax(y_prob, dim=-1)
y_true = torch.cat(y_true, dim=0)
acc = float(torch.sum(y_pred == y_true)) / len(y_true)
return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc


def show_metrics(self):
"""Customize evaluation outputs in Trainer.
@@ -140,10 +195,8 @@ class BaseTester(object):

:return print_str: str
"""
raise NotImplementedError

def make_batch(self, iterator):
raise NotImplementedError
loss, accuracy = self.metrics
return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy)


def make_eval_output(self, predictions, eval_results):
"""Customize Tester outputs.
@@ -152,108 +205,20 @@ class BaseTester(object):
:param eval_results: Tensor
:return: str, to be printed.
"""
raise NotImplementedError
return self.show_metrics()


class SeqLabelTester(BaseTester):
"""Tester for sequence labeling.

"""


class SeqLabelTester(BaseTester):
def __init__(self, **test_args):
"""
:param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]"
"""
test_args.update({"task": "seq_label"})
print(
"[FastNLP Warning] SeqLabelTester will be deprecated. Please use Tester with argument 'task'='seq_label'.")
super(SeqLabelTester, self).__init__(**test_args)
self.max_len = None
self.mask = None
self.seq_len = None

def data_forward(self, network, inputs):
"""This is only for sequence labeling with CRF decoder.

:param network: a PyTorch model
:param inputs: tuple of (x, seq_len)
x: Tensor of shape [batch_size, max_len], where max_len is the maximum length of the mini-batch
after padding.
seq_len: list of int, the lengths of sequences before padding.
:return y: Tensor of shape [batch_size, max_len]
"""
if not isinstance(inputs, tuple):
raise RuntimeError("output_length must be true for sequence modeling.")
# unpack the returned value from make_batch
x, seq_len = inputs[0], inputs[1]
batch_size, max_len = x.size(0), x.size(1)
mask = utils.seq_mask(seq_len, max_len)
mask = mask.byte().view(batch_size, max_len)
if torch.cuda.is_available() and self.use_cuda:
mask = mask.cuda()
self.mask = mask
self.seq_len = seq_len
y = network(x)
return y

def evaluate(self, predict, truth):
"""Compute metrics (or loss).

:param predict: Tensor, [batch_size, max_len, tag_size]
:param truth: Tensor, [batch_size, max_len]
:return:
"""
batch_size, max_len = predict.size(0), predict.size(1)
loss = self._model.loss(predict, truth, self.mask) / batch_size

prediction = self._model.prediction(predict, self.mask)
results = torch.Tensor(prediction).view(-1, )
# make sure "results" is in the same device as "truth"
results = results.to(truth)
accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0]
return [float(loss), float(accuracy)]

def metrics(self):
batch_loss = np.mean([x[0] for x in self.eval_history])
batch_accuracy = np.mean([x[1] for x in self.eval_history])
return batch_loss, batch_accuracy

def show_metrics(self):
"""This is called by Trainer to print evaluation on dev set.

:return print_str: str
"""
loss, accuracy = self.metrics()
return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy)

def make_batch(self, iterator):
return Action.make_batch(iterator, use_cuda=self.use_cuda, output_length=True)




class ClassificationTester(BaseTester):
"""Tester for classification."""

def __init__(self, **test_args):
"""
:param test_args: a dict-like object that has __getitem__ method.
can be accessed by "test_args["key_str"]"
"""
test_args.update({"task": "text_classify"})
print(
"[FastNLP Warning] ClassificationTester will be deprecated. Please use Tester with argument 'task'='text_classify'.")
super(ClassificationTester, self).__init__(**test_args)

def make_batch(self, iterator, max_len=None):
return Action.make_batch(iterator, use_cuda=self.use_cuda, max_len=max_len)

def data_forward(self, network, x):
"""Forward through network."""
logits = network(x)
return logits

def evaluate(self, y_logit, y_true):
"""Return y_pred and y_true."""
y_prob = torch.nn.functional.softmax(y_logit, dim=-1)
return [y_prob, y_true]

def metrics(self):
"""Compute accuracy."""
y_prob, y_true = zip(*self.eval_history)
y_prob = torch.cat(y_prob, dim=0)
y_pred = torch.argmax(y_prob, dim=-1)
y_true = torch.cat(y_true, dim=0)
acc = float(torch.sum(y_pred == y_true)) / len(y_true)
return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc
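
A hedged sketch of how the unified, task-dispatching tester is meant to be used. Keyword arguments other than "task", and the name of the evaluation entry point, are not shown in this diff, so batch_size, use_cuda, pickle_path and test() below are assumptions:

    from fastNLP.core.tester import SeqLabelTester

    tester = SeqLabelTester(batch_size=8, use_cuda=False, pickle_path="./save/")
    # equivalent to BaseTester(task="seq_label", ...); prints a deprecation warning
    tester.test(network=trained_model, dev_data=dev_set)   # trained_model, dev_set: placeholders
    print(tester.show_metrics())                           # "dev loss=..., accuracy=..."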

fastNLP/core/trainer.py (+56, -96)

@@ -4,15 +4,13 @@ import time
from datetime import timedelta

import torch
import tensorboardX
from tensorboardX import SummaryWriter

from fastNLP.core.action import Action
from fastNLP.core.action import RandomSampler, Batchifier
from fastNLP.core.action import RandomSampler
from fastNLP.core.batch import Batch
from fastNLP.core.loss import Loss
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.tester import SeqLabelTester, ClassificationTester
from fastNLP.modules import utils
from fastNLP.saver.logger import create_logger
from fastNLP.saver.model_saver import ModelSaver


@@ -50,16 +48,16 @@ class BaseTrainer(object):
""" """
"required_args" is the collection of arguments that users must pass to Trainer explicitly. "required_args" is the collection of arguments that users must pass to Trainer explicitly.
This is used to warn users of essential settings in the training. This is used to warn users of essential settings in the training.
Obviously, "required_args" is the subset of "default_args".
The value in "default_args" to the keys in "required_args" is simply for type check.
Specially, "required_args" does not have default value, so they have nothing to do with "default_args".
""" """
# add required arguments here
required_args = {}
required_args = {"task" # one of ("seq_label", "text_classify")
}


for req_key in required_args: for req_key in required_args:
if req_key not in kwargs: if req_key not in kwargs:
logger.error("Trainer lacks argument {}".format(req_key)) logger.error("Trainer lacks argument {}".format(req_key))
raise ValueError("Trainer lacks argument {}".format(req_key)) raise ValueError("Trainer lacks argument {}".format(req_key))
self._task = kwargs["task"]


for key in default_args: for key in default_args:
if key in kwargs: if key in kwargs:
@@ -90,13 +88,14 @@ class BaseTrainer(object):
self._optimizer_proto = default_args["optimizer"]
self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs')
self._graph_summaried = False
self._best_accuracy = 0.0


def train(self, network, train_data, dev_data=None):
"""General Training Procedure

:param network: a model
:param train_data: three-level list, the training set.
:param dev_data: three-level list, the validation data (optional)
:param train_data: a DataSet instance, the training data
:param dev_data: a DataSet instance, the validation data (optional)
""" """
# transfer model to gpu if available # transfer model to gpu if available
if torch.cuda.is_available() and self.use_cuda: if torch.cuda.is_available() and self.use_cuda:
@@ -126,9 +125,10 @@ class BaseTrainer(object):
logger.info("training epoch {}".format(epoch)) logger.info("training epoch {}".format(epoch))


# turn on network training mode # turn on network training mode
self.mode(network, test=False)
self.mode(network, is_test=False)
# prepare mini-batch iterator # prepare mini-batch iterator
data_iterator = iter(Batchifier(RandomSampler(train_data), self.batch_size, drop_last=False))
data_iterator = Batch(train_data, batch_size=self.batch_size, sampler=RandomSampler(),
use_cuda=self.use_cuda)
logger.info("prepared data iterator") logger.info("prepared data iterator")


# one forward and backward pass # one forward and backward pass
@@ -157,7 +157,7 @@ class BaseTrainer(object):
- epoch: int,
"""
step = 0
for batch_x, batch_y in self.make_batch(data_iterator):
for batch_x, batch_y in data_iterator:


prediction = self.data_forward(network, batch_x)


@@ -166,10 +166,6 @@ class BaseTrainer(object):
self.update()
self._summary_writer.add_scalar("loss", loss.item(), global_step=step)


if not self._graph_summaried:
self._summary_writer.add_graph(network, batch_x)
self._graph_summaried = True

if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0: if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0:
end = time.time() end = time.time()
diff = timedelta(seconds=round(end - kwargs["start"])) diff = timedelta(seconds=round(end - kwargs["start"]))
@@ -204,11 +200,17 @@ class BaseTrainer(object):
network_copy = copy.deepcopy(network)
self.train(network_copy, train_data_cv[i], dev_data_cv[i])


def make_batch(self, iterator):
raise NotImplementedError
def mode(self, model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently.

:param model: a PyTorch model
:param is_test: bool, whether in test mode or not.


def mode(self, network, test):
Action.mode(network, test)
"""
if is_test:
model.eval()
else:
model.train()


def define_optimizer(self): def define_optimizer(self):
"""Define framework-specific optimizer specified by the models. """Define framework-specific optimizer specified by the models.
@@ -224,7 +226,20 @@ class BaseTrainer(object):
self._optimizer.step() self._optimizer.step()


def data_forward(self, network, x): def data_forward(self, network, x):
raise NotImplementedError
if self._task == "seq_label":
y = network(x["word_seq"], x["word_seq_origin_len"])
elif self._task == "text_classify":
y = network(x["word_seq"])
else:
raise NotImplementedError("Unknown task type {}.".format(self._task))

if not self._graph_summaried:
if self._task == "seq_label":
self._summary_writer.add_graph(network, (x["word_seq"], x["word_seq_origin_len"]), verbose=False)
elif self._task == "text_classify":
self._summary_writer.add_graph(network, x["word_seq"], verbose=False)
self._graph_summaried = True
return y


def grad_backward(self, loss): def grad_backward(self, loss):
"""Compute gradient with link rules. """Compute gradient with link rules.
@@ -243,6 +258,13 @@ class BaseTrainer(object):
:param truth: ground truth label vector :param truth: ground truth label vector
:return: a scalar :return: a scalar
""" """
if "label_seq" in truth:
truth = truth["label_seq"]
elif "label" in truth:
truth = truth["label"]
truth = truth.view((-1,))
else:
raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys()))
return self._loss_func(predict, truth) return self._loss_func(predict, truth)


def define_loss(self): def define_loss(self):
@@ -270,7 +292,12 @@ class BaseTrainer(object):
:param validator: a Tester instance :param validator: a Tester instance
:return: bool, True means current results on dev set is the best. :return: bool, True means current results on dev set is the best.
""" """
raise NotImplementedError
loss, accuracy = validator.metrics
if accuracy > self._best_accuracy:
self._best_accuracy = accuracy
return True
else:
return False


def save_model(self, network, model_name): def save_model(self, network, model_name):
"""Save this model with such a name. """Save this model with such a name.
@@ -291,55 +318,11 @@ class SeqLabelTrainer(BaseTrainer):
"""Trainer for Sequence Labeling """Trainer for Sequence Labeling


""" """

def __init__(self, **kwargs): def __init__(self, **kwargs):
kwargs.update({"task": "seq_label"})
print(
"[FastNLP Warning] SeqLabelTrainer will be deprecated. Please use Trainer with argument 'task'='seq_label'.")
super(SeqLabelTrainer, self).__init__(**kwargs) super(SeqLabelTrainer, self).__init__(**kwargs)
# self.vocab_size = kwargs["vocab_size"]
# self.num_classes = kwargs["num_classes"]
self.max_len = None
self.mask = None
self.best_accuracy = 0.0

def data_forward(self, network, inputs):
if not isinstance(inputs, tuple):
raise RuntimeError("output_length must be true for sequence modeling. Receive {}".format(type(inputs[0])))
# unpack the returned value from make_batch
x, seq_len = inputs[0], inputs[1]

batch_size, max_len = x.size(0), x.size(1)
mask = utils.seq_mask(seq_len, max_len)
mask = mask.byte().view(batch_size, max_len)

if torch.cuda.is_available() and self.use_cuda:
mask = mask.cuda()
self.mask = mask

y = network(x)
return y

def get_loss(self, predict, truth):
"""Compute loss given prediction and ground truth.

:param predict: prediction label vector, [batch_size, max_len, tag_size]
:param truth: ground truth label vector, [batch_size, max_len]
:return loss: a scalar
"""
batch_size, max_len = predict.size(0), predict.size(1)
assert truth.shape == (batch_size, max_len)

loss = self._model.loss(predict, truth, self.mask)
return loss

def best_eval_result(self, validator):
loss, accuracy = validator.metrics()
if accuracy > self.best_accuracy:
self.best_accuracy = accuracy
return True
else:
return False

def make_batch(self, iterator):
return Action.make_batch(iterator, output_length=True, use_cuda=self.use_cuda)


def _create_validator(self, valid_args): def _create_validator(self, valid_args):
return SeqLabelTester(**valid_args) return SeqLabelTester(**valid_args)
@@ -349,33 +332,10 @@ class ClassificationTrainer(BaseTrainer):
"""Trainer for text classification.""" """Trainer for text classification."""


def __init__(self, **train_args): def __init__(self, **train_args):
train_args.update({"task": "text_classify"})
print(
"[FastNLP Warning] ClassificationTrainer will be deprecated. Please use Trainer with argument 'task'='text_classify'.")
super(ClassificationTrainer, self).__init__(**train_args) super(ClassificationTrainer, self).__init__(**train_args)


self.iterator = None
self.loss_func = None
self.optimizer = None
self.best_accuracy = 0

def data_forward(self, network, x):
"""Forward through network."""
logits = network(x)
return logits

def make_batch(self, iterator):
return Action.make_batch(iterator, output_length=False, use_cuda=self.use_cuda)

def get_acc(self, y_logit, y_true):
"""Compute accuracy."""
y_pred = torch.argmax(y_logit, dim=-1)
return int(torch.sum(y_true == y_pred)) / len(y_true)

def best_eval_result(self, validator):
_, _, accuracy = validator.metrics()
if accuracy > self.best_accuracy:
self.best_accuracy = accuracy
return True
else:
return False

def _create_validator(self, valid_args): def _create_validator(self, valid_args):
return ClassificationTester(**valid_args) return ClassificationTester(**valid_args)
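With both trainer subclasses reduced to thin wrappers that set the task argument, a run is configured entirely through keyword arguments and DataSet objects. A minimal sketch for the sequence-labeling task, with argument values mirroring test/core/test_trainer.py further down this diff:

from fastNLP.core.loss import Loss
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.models.sequence_modeling import SeqLabeling

args = {"epochs": 3, "batch_size": 2, "validate": True, "use_cuda": False,
        "pickle_path": "./save/", "save_best_dev": True, "model_name": "default_model_name.pkl",
        "loss": Loss(None), "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
        "vocab_size": 10, "word_emb_dim": 100, "rnn_hidden_units": 100, "num_classes": 5}

trainer = SeqLabelTrainer(**args)  # sets task="seq_label" and prints the deprecation notice
model = SeqLabeling(args)          # forward(word_seq, word_seq_origin_len), defined below

# data_set is a DataSet whose instances hold a "word_seq" TextField and a "label_seq" TextField,
# built exactly as in test/core/test_trainer.py; get_loss() then reads the "label_seq" target.
# trainer.train(network=model, train_data=data_set, dev_data=data_set)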

+ 6
- 2
fastNLP/models/cnn_text_classification.py View File

@@ -35,8 +35,12 @@ class CNNText(torch.nn.Module):
self.dropout = nn.Dropout(drop_prob) self.dropout = nn.Dropout(drop_prob)
self.fc = encoder.linear.Linear(sum(kernel_nums), num_classes) self.fc = encoder.linear.Linear(sum(kernel_nums), num_classes)


def forward(self, x):
x = self.embed(x) # [N,L] -> [N,L,C]
def forward(self, word_seq):
"""
:param word_seq: torch.LongTensor, [batch_size, seq_len]
:return x: torch.LongTensor, [batch_size, num_classes]
"""
x = self.embed(word_seq) # [N,L] -> [N,L,C]
x = self.conv_pool(x) # [N,L,C] -> [N,C] x = self.conv_pool(x) # [N,L,C] -> [N,C]
x = self.dropout(x) x = self.dropout(x)
x = self.fc(x) # [N,C] -> [N, N_class] x = self.fc(x) # [N,C] -> [N, N_class]


+ 42
- 14
fastNLP/models/sequence_modeling.py View File

@@ -4,6 +4,20 @@ from fastNLP.models.base_model import BaseModel
from fastNLP.modules import decoder, encoder from fastNLP.modules import decoder, encoder




def seq_mask(seq_len, max_len):
"""Create a mask for the sequences.

:param seq_len: list or torch.LongTensor
:param max_len: int
:return mask: torch.LongTensor
"""
if isinstance(seq_len, list):
seq_len = torch.LongTensor(seq_len)
mask = [torch.ge(seq_len, i + 1) for i in range(max_len)]
mask = torch.stack(mask, 1)
return mask
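A quick worked example of the mask helper above, with illustrative lengths:

# For seq_len=[2, 4] and max_len=4:
#   i=0 -> torch.ge(seq_len, 1) = [1, 1]
#   i=1 -> torch.ge(seq_len, 2) = [1, 1]
#   i=2 -> torch.ge(seq_len, 3) = [0, 1]
#   i=3 -> torch.ge(seq_len, 4) = [0, 1]
# Stacking along dim=1 gives one row per sample:
#   [[1, 1, 0, 0],
#    [1, 1, 1, 1]]
# make_mask() below casts this to a ByteTensor shaped [batch_size, max_len].
mask = seq_mask([2, 4], max_len=4)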


class SeqLabeling(BaseModel): class SeqLabeling(BaseModel):
""" """
PyTorch Network for sequence labeling PyTorch Network for sequence labeling
@@ -20,13 +34,17 @@ class SeqLabeling(BaseModel):
self.Rnn = encoder.lstm.Lstm(word_emb_dim, hidden_dim) self.Rnn = encoder.lstm.Lstm(word_emb_dim, hidden_dim)
self.Linear = encoder.linear.Linear(hidden_dim, num_classes) self.Linear = encoder.linear.Linear(hidden_dim, num_classes)
self.Crf = decoder.CRF.ConditionalRandomField(num_classes) self.Crf = decoder.CRF.ConditionalRandomField(num_classes)
self.mask = None


def forward(self, x):
def forward(self, word_seq, word_seq_origin_len):
""" """
:param x: LongTensor, [batch_size, max_len]
:param word_seq: LongTensor, [batch_size, max_len]
:param word_seq_origin_len: LongTensor, [batch_size,], the original lengths of the sequences.
:return y: [batch_size, max_len, tag_size] :return y: [batch_size, max_len, tag_size]
""" """
x = self.Embedding(x)
self.mask = self.make_mask(word_seq, word_seq_origin_len)

x = self.Embedding(word_seq)
# [batch_size, max_len, word_emb_dim] # [batch_size, max_len, word_emb_dim]
x = self.Rnn(x) x = self.Rnn(x)
# [batch_size, max_len, hidden_size * direction] # [batch_size, max_len, hidden_size * direction]
@@ -34,27 +52,34 @@ class SeqLabeling(BaseModel):
# [batch_size, max_len, num_classes] # [batch_size, max_len, num_classes]
return x return x


def loss(self, x, y, mask):
def loss(self, x, y):
""" """
Negative log likelihood loss. Negative log likelihood loss.
:param x: Tensor, [batch_size, max_len, tag_size] :param x: Tensor, [batch_size, max_len, tag_size]
:param y: Tensor, [batch_size, max_len] :param y: Tensor, [batch_size, max_len]
:param mask: ByteTensor, [batch_size, ,max_len]
:return loss: a scalar Tensor :return loss: a scalar Tensor


""" """
x = x.float() x = x.float()
y = y.long() y = y.long()
total_loss = self.Crf(x, y, mask)
assert x.shape[:2] == y.shape
assert y.shape == self.mask.shape
total_loss = self.Crf(x, y, self.mask)
return torch.mean(total_loss) return torch.mean(total_loss)


def prediction(self, x, mask):
def make_mask(self, x, seq_len):
batch_size, max_len = x.size(0), x.size(1)
mask = seq_mask(seq_len, max_len)
mask = mask.byte().view(batch_size, max_len)
mask = mask.to(x)
return mask

def prediction(self, x):
""" """
:param x: FloatTensor, [batch_size, max_len, tag_size] :param x: FloatTensor, [batch_size, max_len, tag_size]
:param mask: ByteTensor, [batch_size, max_len]
:return prediction: list of [decode path(list)] :return prediction: list of [decode path(list)]
""" """
tag_seq = self.Crf.viterbi_decode(x, mask)
tag_seq = self.Crf.viterbi_decode(x, self.mask)
return tag_seq return tag_seq




@@ -81,14 +106,17 @@ class AdvSeqLabel(SeqLabeling):


self.Crf = decoder.CRF.ConditionalRandomField(num_classes) self.Crf = decoder.CRF.ConditionalRandomField(num_classes)


def forward(self, x):
def forward(self, word_seq, word_seq_origin_len):
""" """
:param x: LongTensor, [batch_size, max_len]
:param word_seq: LongTensor, [batch_size, max_len]
:param word_seq_origin_len: list of int.
:return y: [batch_size, max_len, tag_size] :return y: [batch_size, max_len, tag_size]
""" """
batch_size = x.size(0)
max_len = x.size(1)
x = self.Embedding(x)
self.mask = self.make_mask(word_seq, word_seq_origin_len)

batch_size = word_seq.size(0)
max_len = word_seq.size(1)
x = self.Embedding(word_seq)
# [batch_size, max_len, word_emb_dim] # [batch_size, max_len, word_emb_dim]
x = self.Rnn(x) x = self.Rnn(x)
# [batch_size, max_len, hidden_size * direction] # [batch_size, max_len, hidden_size * direction]


+ 40
- 11
fastNLP/modules/aggregation/self_attention.py View File

@@ -1,8 +1,10 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.autograd import Variable from torch.autograd import Variable
import torch.nn.functional as F




from fastNLP.modules.utils import initial_parameter
class SelfAttention(nn.Module): class SelfAttention(nn.Module):
""" """
Self Attention Module. Self Attention Module.
@@ -13,13 +15,18 @@ class SelfAttention(nn.Module):
num_vec: int, the number of encoded vectors num_vec: int, the number of encoded vectors
""" """


def __init__(self, input_size, dim=10, num_vec=10):
def __init__(self, input_size, dim=10, num_vec=10 ,drop = 0.5 ,initial_method =None):
super(SelfAttention, self).__init__() super(SelfAttention, self).__init__()
self.W_s1 = nn.Parameter(torch.randn(dim, input_size), requires_grad=True)
self.W_s2 = nn.Parameter(torch.randn(num_vec, dim), requires_grad=True)
# self.W_s1 = nn.Parameter(torch.randn(dim, input_size), requires_grad=True)
# self.W_s2 = nn.Parameter(torch.randn(num_vec, dim), requires_grad=True)
self.attention_hops = num_vec

self.ws1 = nn.Linear(input_size, dim, bias=False)
self.ws2 = nn.Linear(dim, num_vec, bias=False)
self.drop = nn.Dropout(drop)
self.softmax = nn.Softmax(dim=2) self.softmax = nn.Softmax(dim=2)
self.tanh = nn.Tanh() self.tanh = nn.Tanh()

initial_parameter(self, initial_method)
def penalization(self, A): def penalization(self, A):
""" """
compute the penalization term for attention module compute the penalization term for attention module
@@ -32,11 +39,33 @@ class SelfAttention(nn.Module):
M = M.view(M.size(0), -1) M = M.view(M.size(0), -1)
return torch.sum(M ** 2, dim=1) return torch.sum(M ** 2, dim=1)
def forward(self, x):
inter = self.tanh(torch.matmul(self.W_s1, torch.transpose(x, 1, 2)))
A = self.softmax(torch.matmul(self.W_s2, inter))
out = torch.matmul(A, x)
out = out.view(out.size(0), -1)
penalty = self.penalization(A)
return out, penalty
def forward(self, outp ,inp):
# The commented-out implementation below cannot be used because some positions are padding; it applies no mask to them.

# inter = self.tanh(torch.matmul(self.W_s1, torch.transpose(x, 1, 2))) # []
# A = self.softmax(torch.matmul(self.W_s2, inter))
# out = torch.matmul(A, x)
# out = out.view(out.size(0), -1)
# penalty = self.penalization(A)
# return out, penalty
outp = outp.contiguous()
size = outp.size() # [bsz, len, nhid]

compressed_embeddings = outp.view(-1, size[2]) # [bsz*len, nhid*2]
transformed_inp = torch.transpose(inp, 0, 1).contiguous() # [bsz, len]
transformed_inp = transformed_inp.view(size[0], 1, size[1]) # [bsz, 1, len]
concatenated_inp = [transformed_inp for i in range(self.attention_hops)]
concatenated_inp = torch.cat(concatenated_inp, 1) # [bsz, hop, len]

hbar = self.tanh(self.ws1(self.drop(compressed_embeddings))) # [bsz*len, attention-unit]
attention = self.ws2(hbar).view(size[0], size[1], -1) # [bsz, len, hop]
attention = torch.transpose(attention, 1, 2).contiguous() # [bsz, hop, len]
penalized_alphas = attention + (
-10000 * (concatenated_inp == 0).float())
# [bsz, hop, len] + [bsz, hop, len]
attention = self.softmax(penalized_alphas.view(-1, size[1])) # [bsz*hop, len]
attention = attention.view(size[0], self.attention_hops, size[1]) # [bsz, hop, len]
return torch.bmm(attention, outp), attention # output --> [bsz, hop, nhid]
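The rewritten forward takes both the encoder output and the original 0-padded word-index sequence, so attention weights on padding positions are pushed towards zero before the softmax. A small usage sketch (sizes and values are illustrative; the constructor call mirrors the reproduction script main.py further down this diff):

import torch
from fastNLP.modules.aggregation.self_attention import SelfAttention

bsz, seq_len, nhid = 2, 7, 600             # nhid = 2 * lstm_hidden_size for a BiLSTM encoder
attn = SelfAttention(input_size=nhid, dim=350, num_vec=10)

outp = torch.randn(bsz, seq_len, nhid)     # encoder output, [bsz, len, nhid]
inp = torch.randint(1, 100, (bsz, seq_len)).long()
inp[:, 5:] = 0                             # word index 0 marks padding positions

context, attention = attn(outp, inp)       # [bsz, hop, nhid], [bsz, hop, len]
flat = context.view(context.size(0), -1)   # [bsz, hop * nhid], as consumed by the MLP in main.py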





+ 4
- 3
fastNLP/modules/decoder/CRF.py View File

@@ -1,6 +1,7 @@
import torch import torch
from torch import nn from torch import nn


from fastNLP.modules.utils import initial_parameter


def log_sum_exp(x, dim=-1): def log_sum_exp(x, dim=-1):
max_value, _ = x.max(dim=dim, keepdim=True) max_value, _ = x.max(dim=dim, keepdim=True)
@@ -19,7 +20,7 @@ def seq_len_to_byte_mask(seq_lens):




class ConditionalRandomField(nn.Module): class ConditionalRandomField(nn.Module):
def __init__(self, tag_size, include_start_end_trans=True):
def __init__(self, tag_size, include_start_end_trans=True ,initial_method = None):
""" """
:param tag_size: int, num of tags :param tag_size: int, num of tags
:param include_start_end_trans: bool, whether to include start/end tag :param include_start_end_trans: bool, whether to include start/end tag
@@ -35,8 +36,8 @@ class ConditionalRandomField(nn.Module):
self.start_scores = nn.Parameter(torch.randn(tag_size)) self.start_scores = nn.Parameter(torch.randn(tag_size))
self.end_scores = nn.Parameter(torch.randn(tag_size)) self.end_scores = nn.Parameter(torch.randn(tag_size))


self.reset_parameter()
# self.reset_parameter()
initial_parameter(self, initial_method)
def reset_parameter(self): def reset_parameter(self):
nn.init.xavier_normal_(self.transition_m) nn.init.xavier_normal_(self.transition_m)
if self.include_start_end_trans: if self.include_start_end_trans:


+ 3
- 3
fastNLP/modules/decoder/MLP.py View File

@@ -1,8 +1,8 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from fastNLP.modules.utils import initial_parameter
class MLP(nn.Module): class MLP(nn.Module):
def __init__(self, size_layer, num_class=2, activation='relu'):
def __init__(self, size_layer, num_class=2, activation='relu' , initial_method = None):
"""Multilayer Perceptrons as a decoder """Multilayer Perceptrons as a decoder


Args: Args:
@@ -36,7 +36,7 @@ class MLP(nn.Module):
self.hidden_active = activation self.hidden_active = activation
else: else:
raise ValueError("should set activation correctly: {}".format(activation)) raise ValueError("should set activation correctly: {}".format(activation))
initial_parameter(self, initial_method )
def forward(self, x): def forward(self, x):
for layer in self.hiddens: for layer in self.hiddens:
x = self.hidden_active(layer(x)) x = self.hidden_active(layer(x))


+ 7
- 4
fastNLP/modules/encoder/char_embedding.py View File

@@ -1,11 +1,12 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
# from torch.nn.init import xavier_uniform


from fastNLP.modules.utils import initial_parameter
class ConvCharEmbedding(nn.Module): class ConvCharEmbedding(nn.Module):


def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5)):
def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5),initial_method = None):
""" """
Character Level Word Embedding Character Level Word Embedding
:param char_emb_size: the size of character level embedding. Default: 50 :param char_emb_size: the size of character level embedding. Default: 50
@@ -20,6 +21,8 @@ class ConvCharEmbedding(nn.Module):
nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True, padding=(0, 4)) nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True, padding=(0, 4))
for i in range(len(kernels))]) for i in range(len(kernels))])


initial_parameter(self,initial_method)

def forward(self, x): def forward(self, x):
""" """
:param x: [batch_size * sent_length, word_length, char_emb_size] :param x: [batch_size * sent_length, word_length, char_emb_size]
@@ -53,7 +56,7 @@ class LSTMCharEmbedding(nn.Module):
:param hidden_size: int, the number of hidden units. Default: equal to char_emb_size. :param hidden_size: int, the number of hidden units. Default: equal to char_emb_size.
""" """


def __init__(self, char_emb_size=50, hidden_size=None):
def __init__(self, char_emb_size=50, hidden_size=None , initial_method= None):
super(LSTMCharEmbedding, self).__init__() super(LSTMCharEmbedding, self).__init__()
self.hidden_size = char_emb_size if hidden_size is None else hidden_size self.hidden_size = char_emb_size if hidden_size is None else hidden_size


@@ -62,7 +65,7 @@ class LSTMCharEmbedding(nn.Module):
num_layers=1, num_layers=1,
bias=True, bias=True,
batch_first=True) batch_first=True)
initial_parameter(self, initial_method)
def forward(self, x): def forward(self, x):
""" """
:param x:[ n_batch*n_word, word_length, char_emb_size] :param x:[ n_batch*n_word, word_length, char_emb_size]


+ 4
- 2
fastNLP/modules/encoder/conv.py View File

@@ -6,6 +6,7 @@ import torch.nn as nn
from torch.nn.init import xavier_uniform_ from torch.nn.init import xavier_uniform_
# import torch.nn.functional as F # import torch.nn.functional as F


from fastNLP.modules.utils import initial_parameter


class Conv(nn.Module): class Conv(nn.Module):
""" """
@@ -15,7 +16,7 @@ class Conv(nn.Module):


def __init__(self, in_channels, out_channels, kernel_size, def __init__(self, in_channels, out_channels, kernel_size,
stride=1, padding=0, dilation=1, stride=1, padding=0, dilation=1,
groups=1, bias=True, activation='relu'):
groups=1, bias=True, activation='relu',initial_method = None ):
super(Conv, self).__init__() super(Conv, self).__init__()
self.conv = nn.Conv1d( self.conv = nn.Conv1d(
in_channels=in_channels, in_channels=in_channels,
@@ -26,7 +27,7 @@ class Conv(nn.Module):
dilation=dilation, dilation=dilation,
groups=groups, groups=groups,
bias=bias) bias=bias)
xavier_uniform_(self.conv.weight)
# xavier_uniform_(self.conv.weight)


activations = { activations = {
'relu': nn.ReLU(), 'relu': nn.ReLU(),
@@ -37,6 +38,7 @@ class Conv(nn.Module):
raise Exception( raise Exception(
'Should choose activation function from: ' + 'Should choose activation function from: ' +
', '.join([x for x in activations])) ', '.join([x for x in activations]))
initial_parameter(self, initial_method)


def forward(self, x): def forward(self, x):
x = torch.transpose(x, 1, 2) # [N,L,C] -> [N,C,L] x = torch.transpose(x, 1, 2) # [N,L,C] -> [N,C,L]


+ 4
- 2
fastNLP/modules/encoder/conv_maxpool.py View File

@@ -5,7 +5,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.init import xavier_uniform_ from torch.nn.init import xavier_uniform_
from fastNLP.modules.utils import initial_parameter


class ConvMaxpool(nn.Module): class ConvMaxpool(nn.Module):
""" """
@@ -14,7 +14,7 @@ class ConvMaxpool(nn.Module):


def __init__(self, in_channels, out_channels, kernel_sizes, def __init__(self, in_channels, out_channels, kernel_sizes,
stride=1, padding=0, dilation=1, stride=1, padding=0, dilation=1,
groups=1, bias=True, activation='relu'):
groups=1, bias=True, activation='relu',initial_method = None ):
super(ConvMaxpool, self).__init__() super(ConvMaxpool, self).__init__()


# convolution # convolution
@@ -47,6 +47,8 @@ class ConvMaxpool(nn.Module):
raise Exception( raise Exception(
"Undefined activation function: choose from: relu") "Undefined activation function: choose from: relu")


initial_parameter(self, initial_method)

def forward(self, x): def forward(self, x):
# [N,L,C] -> [N,C,L] # [N,L,C] -> [N,C,L]
x = torch.transpose(x, 1, 2) x = torch.transpose(x, 1, 2)


+ 3
- 3
fastNLP/modules/encoder/linear.py View File

@@ -1,6 +1,6 @@
import torch.nn as nn import torch.nn as nn


from fastNLP.modules.utils import initial_parameter
class Linear(nn.Module): class Linear(nn.Module):
""" """
Linear module Linear module
@@ -12,10 +12,10 @@ class Linear(nn.Module):
bidirectional : If True, becomes a bidirectional RNN bidirectional : If True, becomes a bidirectional RNN
""" """


def __init__(self, input_size, output_size, bias=True):
def __init__(self, input_size, output_size, bias=True,initial_method = None ):
super(Linear, self).__init__() super(Linear, self).__init__()
self.linear = nn.Linear(input_size, output_size, bias) self.linear = nn.Linear(input_size, output_size, bias)
initial_parameter(self, initial_method)
def forward(self, x): def forward(self, x):
x = self.linear(x) x = self.linear(x)
return x return x

+ 5
- 3
fastNLP/modules/encoder/lstm.py View File

@@ -1,6 +1,6 @@
import torch.nn as nn import torch.nn as nn


from fastNLP.modules.utils import initial_parameter
class Lstm(nn.Module): class Lstm(nn.Module):
""" """
LSTM module LSTM module
@@ -13,11 +13,13 @@ class Lstm(nn.Module):
bidirectional : If True, becomes a bidirectional RNN. Default: False. bidirectional : If True, becomes a bidirectional RNN. Default: False.
""" """


def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0, bidirectional=False):
def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0, bidirectional=False , initial_method = None):
super(Lstm, self).__init__() super(Lstm, self).__init__()
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True, self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True,
dropout=dropout, bidirectional=bidirectional) dropout=dropout, bidirectional=bidirectional)
initial_parameter(self, initial_method)
def forward(self, x): def forward(self, x):
x, _ = self.lstm(x) x, _ = self.lstm(x)
return x return x
if __name__ == "__main__":
lstm = Lstm(10)

+ 3
- 3
fastNLP/modules/encoder/masked_rnn.py View File

@@ -4,7 +4,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F


from fastNLP.modules.utils import initial_parameter
def MaskedRecurrent(reverse=False): def MaskedRecurrent(reverse=False):
def forward(input, hidden, cell, mask, train=True, dropout=0): def forward(input, hidden, cell, mask, train=True, dropout=0):
""" """
@@ -192,7 +192,7 @@ def AutogradMaskedStep(num_layers=1, dropout=0, train=True, lstm=False):
class MaskedRNNBase(nn.Module): class MaskedRNNBase(nn.Module):
def __init__(self, Cell, input_size, hidden_size, def __init__(self, Cell, input_size, hidden_size,
num_layers=1, bias=True, batch_first=False, num_layers=1, bias=True, batch_first=False,
layer_dropout=0, step_dropout=0, bidirectional=False, **kwargs):
layer_dropout=0, step_dropout=0, bidirectional=False, initial_method = None , **kwargs):
""" """
:param Cell: :param Cell:
:param input_size: :param input_size:
@@ -226,7 +226,7 @@ class MaskedRNNBase(nn.Module):
cell = self.Cell(layer_input_size, hidden_size, self.bias, **kwargs) cell = self.Cell(layer_input_size, hidden_size, self.bias, **kwargs)
self.all_cells.append(cell) self.all_cells.append(cell)
self.add_module('cell%d' % (layer * num_directions + direction), cell) # Max's code is really well written self.add_module('cell%d' % (layer * num_directions + direction), cell) # Max's code is really well written
initial_parameter(self, initial_method)
def reset_parameters(self): def reset_parameters(self):
for cell in self.all_cells: for cell in self.all_cells:
cell.reset_parameters() cell.reset_parameters()


+ 5
- 4
fastNLP/modules/encoder/variational_rnn.py View File

@@ -6,6 +6,7 @@ import torch.nn.functional as F
from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter


from fastNLP.modules.utils import initial_parameter


def default_initializer(hidden_size): def default_initializer(hidden_size):
stdv = 1.0 / math.sqrt(hidden_size) stdv = 1.0 / math.sqrt(hidden_size)
@@ -172,7 +173,7 @@ def AutogradVarMaskedStep(num_layers=1, lstm=False):
class VarMaskedRNNBase(nn.Module): class VarMaskedRNNBase(nn.Module):
def __init__(self, Cell, input_size, hidden_size, def __init__(self, Cell, input_size, hidden_size,
num_layers=1, bias=True, batch_first=False, num_layers=1, bias=True, batch_first=False,
dropout=(0, 0), bidirectional=False, initializer=None, **kwargs):
dropout=(0, 0), bidirectional=False, initializer=None,initial_method = None, **kwargs):


super(VarMaskedRNNBase, self).__init__() super(VarMaskedRNNBase, self).__init__()
self.Cell = Cell self.Cell = Cell
@@ -193,7 +194,7 @@ class VarMaskedRNNBase(nn.Module):
cell = self.Cell(layer_input_size, hidden_size, self.bias, p=dropout, initializer=initializer, **kwargs) cell = self.Cell(layer_input_size, hidden_size, self.bias, p=dropout, initializer=initializer, **kwargs)
self.all_cells.append(cell) self.all_cells.append(cell)
self.add_module('cell%d' % (layer * num_directions + direction), cell) self.add_module('cell%d' % (layer * num_directions + direction), cell)
initial_parameter(self, initial_method)
def reset_parameters(self): def reset_parameters(self):
for cell in self.all_cells: for cell in self.all_cells:
cell.reset_parameters() cell.reset_parameters()
@@ -284,7 +285,7 @@ class VarFastLSTMCell(VarRNNCellBase):
\end{array} \end{array}
""" """


def __init__(self, input_size, hidden_size, bias=True, p=(0.5, 0.5), initializer=None):
def __init__(self, input_size, hidden_size, bias=True, p=(0.5, 0.5), initializer=None,initial_method =None):
super(VarFastLSTMCell, self).__init__() super(VarFastLSTMCell, self).__init__()
self.input_size = input_size self.input_size = input_size
self.hidden_size = hidden_size self.hidden_size = hidden_size
@@ -311,7 +312,7 @@ class VarFastLSTMCell(VarRNNCellBase):
self.p_hidden = p_hidden self.p_hidden = p_hidden
self.noise_in = None self.noise_in = None
self.noise_hidden = None self.noise_hidden = None
initial_parameter(self, initial_method)
def reset_parameters(self): def reset_parameters(self):
for weight in self.parameters(): for weight in self.parameters():
if weight.dim() == 1: if weight.dim() == 1:


+ 47
- 2
fastNLP/modules/utils.py View File

@@ -2,8 +2,8 @@ from collections import defaultdict


import numpy as np import numpy as np
import torch import torch
import torch.nn.init as init
import torch.nn as nn
def mask_softmax(matrix, mask): def mask_softmax(matrix, mask):
if mask is None: if mask is None:
result = torch.nn.functional.softmax(matrix, dim=-1) result = torch.nn.functional.softmax(matrix, dim=-1)
@@ -11,6 +11,51 @@ def mask_softmax(matrix, mask):
raise NotImplementedError raise NotImplementedError
return result return result


def initial_parameter(net ,initial_method =None):

if initial_method == 'xavier_uniform':
init_method = init.xavier_uniform_
elif initial_method=='xavier_normal':
init_method = init.xavier_normal_
elif initial_method == 'kaiming_normal' or initial_method == 'msra':
init_method = init.kaiming_normal_
elif initial_method == 'kaiming_uniform':
init_method = init.kaiming_uniform_
elif initial_method == 'orthogonal':
init_method = init.orthogonal_
elif initial_method == 'sparse':
init_method = init.sparse_
elif initial_method =='normal':
init_method = init.normal_
elif initial_method == 'uniform':
init_method = init.uniform_
else:
init_method = init.xavier_normal_
def weights_init(m):
# classname = m.__class__.__name__
if isinstance(m, nn.Conv2d) or isinstance(m,nn.Conv1d) or isinstance(m,nn.Conv3d): # for all the cnn
if initial_method != None:
init_method(m.weight.data)
else:
init.xavier_normal_(m.weight.data)
init.normal_(m.bias.data)
elif isinstance(m, nn.LSTM):
for w in m.parameters():
if len(w.data.size())>1:
init_method(w.data) # weight
else:
init.normal_(w.data) # bias
elif hasattr(m, 'weight') and m.weight.requires_grad:
init_method(m.weight.data)
else:
for w in m.parameters() :
if w.requires_grad:
if len(w.data.size())>1:
init_method(w.data) # weight
else:
init.normal_(w.data) # bias
# print("init else")
net.apply(weights_init)


def seq_mask(seq_len, max_len): def seq_mask(seq_len, max_len):
mask = [torch.ge(torch.LongTensor(seq_len), i + 1) for i in range(max_len)] mask = [torch.ge(torch.LongTensor(seq_len), i + 1) for i in range(max_len)]
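initial_parameter applies one named initialization scheme to every submodule of a network; the encoder and decoder modules above call it at the end of their constructors. A small sketch on a toy model (the model itself is illustrative):

import torch.nn as nn
from fastNLP.modules.utils import initial_parameter

net = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 5))
initial_parameter(net, initial_method="xavier_uniform")  # or "msra", "orthogonal", "sparse", ...
# Omitting initial_method falls back to xavier_normal_, which is what the modules above get
# when they are built without an explicit initialization scheme.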


+ 13
- 0
reproduction/LSTM+self_attention_sentiment_analysis/config.cfg View File

@@ -0,0 +1,13 @@
[train]
epochs = 30
batch_size = 32
pickle_path = "./save/"
validate = true
save_best_dev = true
model_saved_path = "./save/"
rnn_hidden_units = 300
word_emb_dim = 300
use_crf = true
use_cuda = false
loss_func = "cross_entropy"
num_classes = 5

+ 80
- 0
reproduction/LSTM+self_attention_sentiment_analysis/main.py View File

@@ -0,0 +1,80 @@

import os

import torch.nn.functional as F

from fastNLP.loader.dataset_loader import ClassDatasetLoader as Dataset_loader
from fastNLP.loader.embed_loader import EmbedLoader as EmbedLoader
from fastNLP.loader.config_loader import ConfigSection
from fastNLP.loader.config_loader import ConfigLoader

from fastNLP.models.base_model import BaseModel

from fastNLP.core.preprocess import ClassPreprocess as Preprocess
from fastNLP.core.trainer import ClassificationTrainer

from fastNLP.modules.encoder.embedding import Embedding as Embedding
from fastNLP.modules.encoder.lstm import Lstm
from fastNLP.modules.aggregation.self_attention import SelfAttention
from fastNLP.modules.decoder.MLP import MLP


train_data_path = 'small_train_data.txt'
dev_data_path = 'small_dev_data.txt'
# emb_path = 'glove.txt'

lstm_hidden_size = 300
embeding_size = 300
attention_unit = 350
attention_hops = 10
class_num = 5
nfc = 3000
### data load ###
train_dataset = Dataset_loader(train_data_path)
train_data = train_dataset.load()

dev_args = Dataset_loader(dev_data_path)
dev_data = dev_args.load()

###### preprocess ####
preprocess = Preprocess()
word2index, label2index = preprocess.build_dict(train_data)
train_data, dev_data = preprocess.run(train_data, dev_data)



# emb = EmbedLoader(emb_path)
# embedding = emb.load_embedding(emb_dim= embeding_size , emb_file= emb_path ,word_dict= word2index)
### construct vocab ###

class SELF_ATTENTION_YELP_CLASSIFICATION(BaseModel):
def __init__(self, args=None):
super(SELF_ATTENTION_YELP_CLASSIFICATION,self).__init__()
self.embedding = Embedding(len(word2index) ,embeding_size , init_emb= None )
self.lstm = Lstm(input_size = embeding_size,hidden_size = lstm_hidden_size ,bidirectional = True)
self.attention = SelfAttention(lstm_hidden_size * 2 ,dim =attention_unit ,num_vec=attention_hops)
self.mlp = MLP(size_layer=[lstm_hidden_size * 2*attention_hops ,nfc ,class_num ] ,num_class=class_num ,)
def forward(self,x):
x_emb = self.embedding(x)
output = self.lstm(x_emb)
after_attention, penalty = self.attention(output,x)
after_attention =after_attention.view(after_attention.size(0),-1)
output = self.mlp(after_attention)
return output

def loss(self, predict, ground_truth):
print("predict:%s; g:%s" % (str(predict.size()), str(ground_truth.size())))
print(ground_truth)
return F.cross_entropy(predict, ground_truth)

train_args = ConfigSection()
ConfigLoader("good path").load_config('config.cfg',{"train": train_args})
train_args['vocab'] = len(word2index)


trainer = ClassificationTrainer(**train_args.data)

# for k in train_args.__dict__.keys():
# print(k, train_args[k])
model = SELF_ATTENTION_YELP_CLASSIFICATION(train_args)
trainer.train(model,train_data , dev_data)

+ 4
- 4
setup.py View File

@@ -2,18 +2,18 @@
# coding=utf-8 # coding=utf-8
from setuptools import setup, find_packages from setuptools import setup, find_packages


with open('README.md') as f:
with open('README.md', encoding='utf-8') as f:
readme = f.read() readme = f.read()


with open('LICENSE') as f:
with open('LICENSE', encoding='utf-8') as f:
license = f.read() license = f.read()


with open('requirements.txt') as f:
with open('requirements.txt', encoding='utf-8') as f:
reqs = f.read() reqs = f.read()


setup( setup(
name='fastNLP', name='fastNLP',
version='0.0.1',
version='0.0.3',
description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team',
long_description=readme, long_description=readme,
license=license, license=license,


+ 0
- 17
test/core/test_action.py View File

@@ -1,17 +0,0 @@
import unittest

from fastNLP.core.action import Action, Batchifier, SequentialSampler


class TestAction(unittest.TestCase):
def test_case_1(self):
x = [1, 2, 3, 4, 5, 6, 7, 8]
y = [1, 1, 1, 1, 2, 2, 2, 2]
data = []
for i in range(len(x)):
data.append([[x[i]], [y[i]]])
data = Batchifier(SequentialSampler(data), batch_size=2, drop_last=False)
action = Action()
for batch_x in action.make_batch(data, use_cuda=False, output_length=True, max_len=None):
print(batch_x)


+ 62
- 0
test/core/test_batch.py View File

@@ -0,0 +1,62 @@
import unittest

import torch

from fastNLP.core.batch import Batch
from fastNLP.core.dataset import DataSet, create_dataset_from_lists
from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance

raw_texts = ["i am a cat",
"this is a test of new batch",
"ha ha",
"I am a good boy .",
"This is the most beautiful girl ."
]
texts = [text.strip().split() for text in raw_texts]
labels = [0, 1, 0, 0, 1]

# prepare vocabulary
vocab = {}
for text in texts:
for tokens in text:
if tokens not in vocab:
vocab[tokens] = len(vocab)


class TestCase1(unittest.TestCase):
def test(self):
data = DataSet()
for text, label in zip(texts, labels):
x = TextField(text, is_target=False)
y = LabelField(label, is_target=True)
ins = Instance(text=x, label=y)
data.append(ins)

# use vocabulary to index data
data.index_field("text", vocab)

# define naive sampler for batch class
class SeqSampler:
def __call__(self, dataset):
return list(range(len(dataset)))

# use batch to iterate dataset
data_iterator = Batch(data, 2, SeqSampler(), False)
for batch_x, batch_y in data_iterator:
self.assertEqual(len(batch_x), 2)
self.assertTrue(isinstance(batch_x, dict))
self.assertTrue(isinstance(batch_x["text"], torch.LongTensor))
self.assertTrue(isinstance(batch_y, dict))
self.assertTrue(isinstance(batch_y["label"], torch.LongTensor))


class TestCase2(unittest.TestCase):
def test(self):
data = DataSet()
for text in texts:
x = TextField(text, is_target=False)
ins = Instance(text=x)
data.append(ins)
data_set = create_dataset_from_lists(texts, vocab, has_target=False)
self.assertTrue(type(data) == type(data_set))

+ 51
- 0
test/core/test_predictor.py View File

@@ -0,0 +1,51 @@
import os
import unittest

from fastNLP.core.predictor import Predictor
from fastNLP.core.preprocess import save_pickle
from fastNLP.models.sequence_modeling import SeqLabeling


class TestPredictor(unittest.TestCase):
def test_seq_label(self):
model_args = {
"vocab_size": 10,
"word_emb_dim": 100,
"rnn_hidden_units": 100,
"num_classes": 5
}

infer_data = [
['a', 'b', 'c', 'd', 'e'],
['a', '@', 'c', 'd', 'e'],
['a', 'b', '#', 'd', 'e'],
['a', 'b', 'c', '?', 'e'],
['a', 'b', 'c', 'd', '$'],
['!', 'b', 'c', 'd', 'e']
]
vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}

os.system("mkdir save")
save_pickle({0: "0", 1: "1", 2: "2", 3: "3", 4: "4"}, "./save/", "id2class.pkl")
save_pickle(vocab, "./save/", "word2id.pkl")

model = SeqLabeling(model_args)
predictor = Predictor("./save/", task="seq_label")

results = predictor.predict(network=model, data=infer_data)

self.assertTrue(isinstance(results, list))
self.assertGreater(len(results), 0)
for res in results:
self.assertTrue(isinstance(res, list))
self.assertEqual(len(res), 5)
self.assertTrue(isinstance(res[0], str))

os.system("rm -rf save")
print("pickle path deleted")


class TestPredictor2(unittest.TestCase):
def test_text_classify(self):
# TODO
pass

+ 49
- 20
test/core/test_preprocess.py View File

@@ -1,24 +1,25 @@
import os import os
import unittest import unittest


from fastNLP.core.dataset import DataSet
from fastNLP.core.preprocess import SeqLabelPreprocess from fastNLP.core.preprocess import SeqLabelPreprocess


data = [
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
[['Hello', 'world', '!'], ['a', 'n', '.']],
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
[['Hello', 'world', '!'], ['a', 'n', '.']],
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
[['Hello', 'world', '!'], ['a', 'n', '.']],
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
[['Hello', 'world', '!'], ['a', 'n', '.']],
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
[['Hello', 'world', '!'], ['a', 'n', '.']],
]


class TestSeqLabelPreprocess(unittest.TestCase):
def test_case_1(self):
data = [
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
[['Hello', 'world', '!'], ['a', 'n', '.']],
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
[['Hello', 'world', '!'], ['a', 'n', '.']],
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
[['Hello', 'world', '!'], ['a', 'n', '.']],
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
[['Hello', 'world', '!'], ['a', 'n', '.']],
[['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']],
[['Hello', 'world', '!'], ['a', 'n', '.']],
]


class TestCase1(unittest.TestCase):
def test(self):
if os.path.exists("./save"): if os.path.exists("./save"):
for root, dirs, files in os.walk("./save", topdown=False): for root, dirs, files in os.walk("./save", topdown=False):
for name in files: for name in files:
@@ -27,17 +28,45 @@ class TestSeqLabelPreprocess(unittest.TestCase):
os.rmdir(os.path.join(root, name)) os.rmdir(os.path.join(root, name))
result = SeqLabelPreprocess().run(train_dev_data=data, train_dev_split=0.4, result = SeqLabelPreprocess().run(train_dev_data=data, train_dev_split=0.4,
pickle_path="./save") pickle_path="./save")
result = SeqLabelPreprocess().run(train_dev_data=data, train_dev_split=0.4,
pickle_path="./save")
self.assertEqual(len(result), 2)
self.assertEqual(type(result[0]), DataSet)
self.assertEqual(type(result[1]), DataSet)

os.system("rm -rf save")
print("pickle path deleted")


class TestCase2(unittest.TestCase):
def test(self):
if os.path.exists("./save"): if os.path.exists("./save"):
for root, dirs, files in os.walk("./save", topdown=False): for root, dirs, files in os.walk("./save", topdown=False):
for name in files: for name in files:
os.remove(os.path.join(root, name)) os.remove(os.path.join(root, name))
for name in dirs: for name in dirs:
os.rmdir(os.path.join(root, name)) os.rmdir(os.path.join(root, name))
result = SeqLabelPreprocess().run(test_data=data, train_dev_data=data,
pickle_path="./save", train_dev_split=0.4,
cross_val=True)
result = SeqLabelPreprocess().run(test_data=data, train_dev_data=data, result = SeqLabelPreprocess().run(test_data=data, train_dev_data=data,
pickle_path="./save", train_dev_split=0.4, pickle_path="./save", train_dev_split=0.4,
cross_val=True)
cross_val=False)
self.assertEqual(len(result), 3)
self.assertEqual(type(result[0]), DataSet)
self.assertEqual(type(result[1]), DataSet)
self.assertEqual(type(result[2]), DataSet)

os.system("rm -rf save")
print("pickle path deleted")


class TestCase3(unittest.TestCase):
def test(self):
num_folds = 2
result = SeqLabelPreprocess().run(test_data=None, train_dev_data=data,
pickle_path="./save", train_dev_split=0.4,
cross_val=True, n_fold=num_folds)
self.assertEqual(len(result), 2)
self.assertEqual(len(result[0]), num_folds)
self.assertEqual(len(result[1]), num_folds)
for data_set in result[0] + result[1]:
self.assertEqual(type(data_set), DataSet)

os.system("rm -rf save")
print("pickle path deleted")

+ 48
- 30
test/core/test_tester.py View File

@@ -1,37 +1,55 @@
from fastNLP.core.preprocess import SeqLabelPreprocess
import os
import unittest

from fastNLP.core.dataset import DataSet
from fastNLP.core.field import TextField
from fastNLP.core.instance import Instance
from fastNLP.core.tester import SeqLabelTester from fastNLP.core.tester import SeqLabelTester
from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
from fastNLP.loader.dataset_loader import TokenizeDatasetLoader
from fastNLP.models.sequence_modeling import SeqLabeling from fastNLP.models.sequence_modeling import SeqLabeling


data_name = "pku_training.utf8" data_name = "pku_training.utf8"
pickle_path = "data_for_tests" pickle_path = "data_for_tests"




def foo():
loader = TokenizeDatasetLoader("./data_for_tests/cws_pku_utf_8")
train_data = loader.load_pku()

train_args = ConfigSection()
ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS": train_args})

# Preprocessor
p = SeqLabelPreprocess()
train_data = p.run(train_data)
train_args["vocab_size"] = p.vocab_size
train_args["num_classes"] = p.num_classes

model = SeqLabeling(train_args)

valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
"save_loss": True, "batch_size": 8, "pickle_path": "./data_for_tests/",
"use_cuda": True}
validator = SeqLabelTester(**valid_args)

print("start validation.")
validator.test(model, train_data)
print(validator.show_metrics())


if __name__ == "__main__":
foo()
class TestTester(unittest.TestCase):
def test_case_1(self):
model_args = {
"vocab_size": 10,
"word_emb_dim": 100,
"rnn_hidden_units": 100,
"num_classes": 5
}
valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
"save_loss": True, "batch_size": 2, "pickle_path": "./save/",
"use_cuda": False, "print_every_step": 1}

train_data = [
[['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
[['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
[['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
[['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']],
[['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']],
[['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
]
vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}

data_set = DataSet()
for example in train_data:
text, label = example[0], example[1]
x = TextField(text, False)
y = TextField(label, is_target=True)
ins = Instance(word_seq=x, label_seq=y)
data_set.append(ins)

data_set.index_field("word_seq", vocab)
data_set.index_field("label_seq", label_vocab)

model = SeqLabeling(model_args)

tester = SeqLabelTester(**valid_args)
tester.test(network=model, dev_data=data_set)
# If this can run, everything is OK.

os.system("rm -rf save")
print("pickle path deleted")

+ 36
- 15
test/core/test_trainer.py View File

@@ -1,33 +1,54 @@
import os import os

import torch.nn as nn
import unittest import unittest


from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.core.dataset import DataSet
from fastNLP.core.field import TextField
from fastNLP.core.instance import Instance
from fastNLP.core.loss import Loss from fastNLP.core.loss import Loss
from fastNLP.core.optimizer import Optimizer from fastNLP.core.optimizer import Optimizer
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.models.sequence_modeling import SeqLabeling from fastNLP.models.sequence_modeling import SeqLabeling



class TestTrainer(unittest.TestCase): class TestTrainer(unittest.TestCase):
def test_case_1(self): def test_case_1(self):
args = {"epochs": 3, "batch_size": 8, "validate": True, "use_cuda": True, "pickle_path": "./save/",
args = {"epochs": 3, "batch_size": 2, "validate": True, "use_cuda": False, "pickle_path": "./save/",
"save_best_dev": True, "model_name": "default_model_name.pkl", "save_best_dev": True, "model_name": "default_model_name.pkl",
"loss": Loss(None), "loss": Loss(None),
"optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
"vocab_size": 20,
"vocab_size": 10,
"word_emb_dim": 100, "word_emb_dim": 100,
"rnn_hidden_units": 100, "rnn_hidden_units": 100,
"num_classes": 3
"num_classes": 5
} }
trainer = SeqLabelTrainer()
trainer = SeqLabelTrainer(**args)

train_data = [ train_data = [
[[1, 2, 3, 4, 5, 6], [1, 0, 1, 0, 1, 2]],
[[2, 3, 4, 5, 1, 6], [0, 1, 0, 1, 0, 2]],
[[1, 4, 1, 4, 1, 6], [1, 0, 1, 0, 1, 2]],
[[1, 2, 3, 4, 5, 6], [1, 0, 1, 0, 1, 2]],
[[2, 3, 4, 5, 1, 6], [0, 1, 0, 1, 0, 2]],
[[1, 4, 1, 4, 1, 6], [1, 0, 1, 0, 1, 2]],
[['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
[['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
[['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
[['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']],
[['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']],
[['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
] ]
dev_data = train_data
vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}

data_set = DataSet()
for example in train_data:
text, label = example[0], example[1]
x = TextField(text, False)
y = TextField(label, is_target=True)
ins = Instance(word_seq=x, label_seq=y)
data_set.append(ins)

data_set.index_field("word_seq", vocab)
data_set.index_field("label_seq", label_vocab)

model = SeqLabeling(args) model = SeqLabeling(args)
trainer.train(network=model, train_data=train_data, dev_data=dev_data)

trainer.train(network=model, train_data=data_set, dev_data=data_set)
# If this can run, everything is OK.

os.system("rm -rf save")
print("pickle path deleted")

+ 7
- 7
test/model/seq_labeling.py View File

@@ -15,11 +15,11 @@ from fastNLP.core.optimizer import Optimizer


parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("-s", "--save", type=str, default="./seq_label/", help="path to save pickle files") parser.add_argument("-s", "--save", type=str, default="./seq_label/", help="path to save pickle files")
parser.add_argument("-t", "--train", type=str, default="./data_for_tests/people.txt",
parser.add_argument("-t", "--train", type=str, default="../data_for_tests/people.txt",
help="path to the training data") help="path to the training data")
parser.add_argument("-c", "--config", type=str, default="./data_for_tests/config", help="path to the config file")
parser.add_argument("-c", "--config", type=str, default="../data_for_tests/config", help="path to the config file")
parser.add_argument("-m", "--model_name", type=str, default="seq_label_model.pkl", help="the name of the model") parser.add_argument("-m", "--model_name", type=str, default="seq_label_model.pkl", help="the name of the model")
parser.add_argument("-i", "--infer", type=str, default="data_for_tests/people_infer.txt",
parser.add_argument("-i", "--infer", type=str, default="../data_for_tests/people_infer.txt",
help="data used for inference") help="data used for inference")


args = parser.parse_args() args = parser.parse_args()
@@ -86,7 +86,7 @@ def train_and_test():
trainer = SeqLabelTrainer( trainer = SeqLabelTrainer(
epochs=trainer_args["epochs"], epochs=trainer_args["epochs"],
batch_size=trainer_args["batch_size"], batch_size=trainer_args["batch_size"],
validate=trainer_args["validate"],
validate=False,
use_cuda=trainer_args["use_cuda"], use_cuda=trainer_args["use_cuda"],
pickle_path=pickle_path, pickle_path=pickle_path,
save_best_dev=trainer_args["save_best_dev"], save_best_dev=trainer_args["save_best_dev"],
@@ -121,7 +121,7 @@ def train_and_test():


# Tester # Tester
tester = SeqLabelTester(save_output=False, tester = SeqLabelTester(save_output=False,
save_loss=False,
save_loss=True,
save_best_dev=False, save_best_dev=False,
batch_size=4, batch_size=4,
use_cuda=False, use_cuda=False,
@@ -139,5 +139,5 @@ def train_and_test():




if __name__ == "__main__": if __name__ == "__main__":
# train_and_test()
infer()
train_and_test()
# infer()

+ 0
- 8
test/model/test_charlm.py View File

@@ -1,8 +0,0 @@


def test_charlm():
pass


if __name__ == "__main__":
test_charlm()

+ 85
- 0
test/model/test_seq_label.py View File

@@ -0,0 +1,85 @@
import os

from fastNLP.core.optimizer import Optimizer
from fastNLP.core.preprocess import SeqLabelPreprocess
from fastNLP.core.tester import SeqLabelTester
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.dataset_loader import POSDatasetLoader
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.saver.model_saver import ModelSaver

pickle_path = "./seq_label/"
model_name = "seq_label_model.pkl"
config_dir = "test/data_for_tests/config"
data_path = "test/data_for_tests/people.txt"
data_infer_path = "test/data_for_tests/people_infer.txt"


def test_training():
# Config Loader
trainer_args = ConfigSection()
model_args = ConfigSection()
ConfigLoader("_").load_config(config_dir, {
"test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args})

# Data Loader
pos_loader = POSDatasetLoader(data_path)
train_data = pos_loader.load_lines()

# Preprocessor
p = SeqLabelPreprocess()
data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5)
model_args["vocab_size"] = p.vocab_size
model_args["num_classes"] = p.num_classes

trainer = SeqLabelTrainer(
epochs=trainer_args["epochs"],
batch_size=trainer_args["batch_size"],
validate=False,
use_cuda=False,
pickle_path=pickle_path,
save_best_dev=trainer_args["save_best_dev"],
model_name=model_name,
optimizer=Optimizer("SGD", lr=0.01, momentum=0.9),
)

# Model
model = SeqLabeling(model_args)

# Start training
trainer.train(model, data_train, data_dev)

# Saver
saver = ModelSaver(os.path.join(pickle_path, model_name))
saver.save_pytorch(model)

del model, trainer, pos_loader

# Define the same model
model = SeqLabeling(model_args)

# Dump trained parameters into the model
ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name))

# Load test configuration
tester_args = ConfigSection()
ConfigLoader("config.cfg").load_config(config_dir, {"test_seq_label_tester": tester_args})

# Tester
tester = SeqLabelTester(save_output=False,
save_loss=True,
save_best_dev=False,
batch_size=4,
use_cuda=False,
pickle_path=pickle_path,
model_name="seq_label_in_test.pkl",
print_every_step=1
)

# Start testing with validation data
tester.test(model, data_dev)

loss, accuracy = tester.metrics
assert 0 < accuracy < 1

+ 3
- 3
test/model/text_classify.py View File

@@ -19,9 +19,9 @@ from fastNLP.core.loss import Loss


parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("-s", "--save", type=str, default="./test_classification/", help="path to save pickle files") parser.add_argument("-s", "--save", type=str, default="./test_classification/", help="path to save pickle files")
parser.add_argument("-t", "--train", type=str, default="./data_for_tests/text_classify.txt",
parser.add_argument("-t", "--train", type=str, default="../data_for_tests/text_classify.txt",
help="path to the training data") help="path to the training data")
parser.add_argument("-c", "--config", type=str, default="./data_for_tests/config", help="path to the config file")
parser.add_argument("-c", "--config", type=str, default="../data_for_tests/config", help="path to the config file")
parser.add_argument("-m", "--model_name", type=str, default="classify_model.pkl", help="the name of the model") parser.add_argument("-m", "--model_name", type=str, default="classify_model.pkl", help="the name of the model")


args = parser.parse_args() args = parser.parse_args()
@@ -115,4 +115,4 @@ def train():


if __name__ == "__main__": if __name__ == "__main__":
train() train()
infer()
# infer()
