Browse Source

Updates to cores, action, loader:

- rename Inference to Predictor
- rename Trainer.prepare_input to Trainer.load_train_data, load data_train.pkl only
- add __contains__ method to config Section class
- more code comments
- more elegant make_batch & data_iterator: Samplers return batch samples instead of batch indices
tags/v0.1.0
FengZiYjun 6 years ago
parent
commit
4bbeaebe96
17 changed files with 251 additions and 195 deletions
  1. +0
    -5
      fastNLP/action/optimizer.py
  2. +68
    -47
      fastNLP/core/action.py
  3. +29
    -33
      fastNLP/core/metrics.py
  4. +2
    -4
      fastNLP/core/optimizer.py
  5. +53
    -39
      fastNLP/core/predictor.py
  6. +31
    -8
      fastNLP/core/tester.py
  7. +44
    -44
      fastNLP/core/trainer.py
  8. +1
    -1
      fastNLP/fastnlp.py
  9. +3
    -0
      fastNLP/loader/config_loader.py
  10. +1
    -1
      fastNLP/loader/embed_loader.py
  11. +6
    -4
      fastNLP/modules/utils.py
  12. +2
    -2
      reproduction/chinese_word_seg/cws_train.py
  13. +2
    -0
      test/__init__.py
  14. +1
    -1
      test/ner_decode.py
  15. +3
    -3
      test/seq_labeling.py
  16. +2
    -2
      test/test_cws.py
  17. +3
    -1
      test/text_classify.py

+ 0
- 5
fastNLP/action/optimizer.py View File

@@ -1,5 +0,0 @@
'''
use optimizer from Pytorch
'''

from torch.optim import *

+ 68
- 47
fastNLP/core/action.py View File

@@ -10,7 +10,7 @@ import torch

class Action(object):
"""
Operations shared by Trainer, Tester, and Inference.
Operations shared by Trainer, Tester, or Inference.
This is designed for reducing replicate codes.
- make_batch: produce a min-batch of data. @staticmethod
- pad: padding method used in sequence modeling. @staticmethod
@@ -22,28 +22,24 @@ class Action(object):
super(Action, self).__init__()

@staticmethod
def make_batch(iterator, data, use_cuda, output_length=True, max_len=None):
def make_batch(iterator, use_cuda, output_length=True, max_len=None):
"""Batch and Pad data.
:param iterator: an iterator, (object that implements __next__ method) which returns the next sample.
:param data: list. Each entry is a sample, which is also a list of features and label(s).
E.g.
[
[[word_11, word_12, word_13], [label_11. label_12]], # sample 1
[[word_21, word_22, word_23], [label_21. label_22]], # sample 2
...
]
:param use_cuda: bool
:param output_length: whether to output the original length of the sequence before padding.
:param max_len: int, maximum sequence length
:return (batch_x, seq_len): tuple of two elements, if output_length is true.
:param use_cuda: bool, whether to use GPU
:param output_length: bool, whether to output the original length of the sequence before padding. (default: True)
:param max_len: int, maximum sequence length. Longer sequences will be clipped. (default: None)
:return
if output_length is True:
(batch_x, seq_len): tuple of two elements
batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
seq_len: list. The length of the pre-padded sequence, if output_length is True.
batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]

return batch_x and batch_y, if output_length is False
if output_length is False:
batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
"""
for indices in iterator:
batch = [data[idx] for idx in indices]
for batch in iterator:
batch_x = [sample[0] for sample in batch]
batch_y = [sample[1] for sample in batch]

@@ -68,11 +64,11 @@ class Action(object):

@staticmethod
def pad(batch, fill=0):
"""
Pad a batch of samples to maximum length of this batch.
""" Pad a mini-batch of sequence samples to maximum length of this batch.
:param batch: list of list
:param fill: word index to pad, default 0.
:return: a padded batch
:return batch: a padded mini-batch
"""
max_length = max([len(x) for x in batch])
for idx, sample in enumerate(batch):
@@ -95,11 +91,10 @@ class Action(object):

def convert_to_torch_tensor(data_list, use_cuda):
"""
convert lists into (cuda) Tensors
convert lists into (cuda) Tensors.
:param data_list: 2-level lists
:param use_cuda: bool
:param reqired_grad: bool
:return: PyTorch Tensor of shape [batch_size, max_seq_len]
:param use_cuda: bool, whether to use GPU or not
:return data_list: PyTorch Tensor of shape [batch_size, max_seq_len]
"""
data_list = torch.Tensor(data_list).long()
if torch.cuda.is_available() and use_cuda:
@@ -171,6 +166,7 @@ class BaseSampler(object):

def __init__(self, data_set):
self.data_set_length = len(data_set)
self.data = data_set

def __len__(self):
return self.data_set_length
@@ -188,7 +184,7 @@ class SequentialSampler(BaseSampler):
super(SequentialSampler, self).__init__(data_set)

def __iter__(self):
return iter(range(self.data_set_length))
return iter(self.data)


class RandomSampler(BaseSampler):
@@ -198,28 +194,10 @@ class RandomSampler(BaseSampler):

def __init__(self, data_set):
super(RandomSampler, self).__init__(data_set)
self.order = np.random.permutation(self.data_set_length)

def __iter__(self):
return iter(np.random.permutation(self.data_set_length))


class BucketSampler(BaseSampler):
"""
Partition all samples into multiple buckets, each of which contains sentences of approximately the same length.
In sampling, first random choose a bucket. Then sample data from it.
The number of buckets is decided dynamically by the variance of sentence lengths.
"""

def __init__(self, data_set):
super(BucketSampler, self).__init__(data_set)
BUCKETS = ([None] * 20)
self.length_freq = dict(Counter([len(example) for example in data_set]))
self.buckets = k_means_bucketing(data_set, BUCKETS)

def __iter__(self):
bucket_samples = self.buckets[np.random.randint(0, len(self.buckets))]
np.random.shuffle(bucket_samples)
return iter(bucket_samples)
return iter((self.data[idx] for idx in self.order))


class Batchifier(object):
@@ -235,10 +213,53 @@ class Batchifier(object):

def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
for example in self.sampler:
batch.append(example)
if len(batch) == self.batch_size:
yield batch
batch = []
if 0 < len(batch) < self.batch_size and self.drop_last is False:
yield batch


class BucketBatchifier(Batchifier):
"""
Partition all samples into multiple buckets, each of which contains sentences of approximately the same length.
In sampling, first random choose a bucket. Then sample data from it.
The number of buckets is decided dynamically by the variance of sentence lengths.
"""

def __init__(self, data_set, batch_size, num_buckets, drop_last=True, sampler=None):
"""

:param data_set: three-level list, shape [num_samples, 2]
:param batch_size: int
:param num_buckets: int, number of buckets for grouping these sequences.
:param drop_last: bool, useless currently.
:param sampler: Sampler, useless currently.
"""
super(BucketBatchifier, self).__init__(sampler, batch_size, drop_last)
buckets = ([None] * num_buckets)
self.data = data_set
self.batch_size = batch_size
self.length_freq = dict(Counter([len(example) for example in data_set]))
self.buckets = k_means_bucketing(data_set, buckets)

def __iter__(self):
"""Make a min-batch of data."""
for _ in range(len(self.data) // self.batch_size):
bucket_samples = self.buckets[np.random.randint(0, len(self.buckets))]
np.random.shuffle(bucket_samples)
yield [self.data[idx] for idx in bucket_samples[:batch_size]]


if __name__ == "__main__":
import random

data = [[[y] * random.randint(0, 50), [y]] for y in range(500)]
batch_size = 8
iterator = iter(BucketBatchifier(data, batch_size, num_buckets=5))
for d in iterator:
print("\nbatch:")
for dd in d:
print(len(dd[0]), end=" ")

+ 29
- 33
fastNLP/core/metrics.py View File

@@ -1,62 +1,55 @@
"""
To do:
设计评判结果的各种指标。如果涉及向量,使用numpy。
参考http://scikit-learn.org/stable/modules/classes.html#classification-metrics
建议是每种metric写成一个函数 (由Tester的evaluate函数调用)
参数表里只需考虑基本的参数即可,可以没有像它那么多的参数配置
support numpy array and torch tensor
"""
import warnings

import numpy as np
import torch
import sklearn.metrics as M
import warnings


def _conver_numpy(x):
'''
converte input data to numpy array
'''
if isinstance(x, np.ndarray):
"""
convert input data to numpy array
"""
if isinstance(x, np.ndarray):
return x
elif isinstance(x, torch.Tensor):
elif isinstance(x, torch.Tensor):
return x.numpy()
elif isinstance(x, list):
elif isinstance(x, list):
return np.array(x)
raise TypeError('cannot accept obejct: {}'.format(x))
raise TypeError('cannot accept object: {}'.format(x))


def _check_same_len(*arrays, axis=0):
'''
"""
check if input array list has same length for one dimension
'''
"""
lens = set([x.shape[axis] for x in arrays if x is not None])
return len(lens) == 1

def _label_types(y):
'''
"""
determine the type
"binary"
"multiclass"
"multiclass-multioutput"
"multilabel"
"unknown"
'''
"""
# never squeeze the first dimension
y = np.squeeze(y, list(range(1, len(y.shape))))
shape = y.shape
if len(shape) < 1:
if len(shape) < 1:
raise ValueError('cannot accept data: {}'.format(y))
if len(shape) == 1:
return 'multiclass' if np.unique(y).shape[0] > 2 else 'binary', y
if len(shape) == 2:
return 'multiclass-multioutput' if np.unique(y).shape[0] > 2 else 'multilabel', y
return 'unknown', y

def _check_data(y_true, y_pred):
'''
"""
check if y_true and y_pred is same type of data e.g both binary or multiclass
'''
"""
y_true, y_pred = _conver_numpy(y_true), _conver_numpy(y_pred)
if not _check_same_len(y_true, y_pred):
raise ValueError('cannot accept data with different shape {0}, {1}'.format(y_true, y_pred))
@@ -70,9 +63,9 @@ def _check_data(y_true, y_pred):
type_set = set(['multiclass-multioutput', 'multilabel'])
if type_true in type_set and type_pred in type_set:
return type_true if type_true == type_pred else 'multiclass-multioutput', y_true, y_pred
raise ValueError('cannot accept data mixed of {0} and {1} target'.format(type_true, type_pred))

def _weight_sum(y, normalize=True, sample_weight=None):
if normalize:
@@ -119,7 +112,7 @@ def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
pos_list = [y_true == i for i in labels]
pos_sum_list = [pos_i.sum() for pos_i in pos_list]
return np.array([np.logical_and(y_pred_right, pos_i).sum() / sum_i if sum_i > 0 else 0 \
for pos_i, sum_i in zip(pos_list, pos_sum_list)])
for pos_i, sum_i in zip(pos_list, pos_sum_list)])
elif y_type == 'multilabel':
y_pred_right = y_true == y_pred
pos = (y_true == pos_label)
@@ -130,6 +123,7 @@ def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
raise ValueError('not support targets type {}'.format(y_type))
raise ValueError('not support for average type {}'.format(average))


def precision_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
y_type, y_true, y_pred = _check_data(y_true, y_pred)
if average == 'binary':
@@ -154,7 +148,7 @@ def precision_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
pos_list = [y_true == i for i in labels]
pos_sum_list = [(y_pred == i).sum() for i in labels]
return np.array([np.logical_and(y_pred_right, pos_i).sum() / sum_i if sum_i > 0 else 0 \
for pos_i, sum_i in zip(pos_list, pos_sum_list)])
for pos_i, sum_i in zip(pos_list, pos_sum_list)])
elif y_type == 'multilabel':
y_pred_right = y_true == y_pred
pos = (y_true == pos_label)
@@ -165,6 +159,7 @@ def precision_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
raise ValueError('not support targets type {}'.format(y_type))
raise ValueError('not support for average type {}'.format(average))


def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
precision = precision_score(y_true, y_pred, labels=labels, pos_label=pos_label, average=average)
recall = recall_score(y_true, y_pred, labels=labels, pos_label=pos_label, average=average)
@@ -178,6 +173,7 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
def classification_report(y_true, y_pred, labels=None, target_names=None, digits=2):
raise NotImplementedError


if __name__ == '__main__':
y = np.array([1,0,1,0,1,1])
print(_label_types(y))
y = np.array([1, 0, 1, 0, 1, 1])
print(_label_types(y))

+ 2
- 4
fastNLP/core/optimizer.py View File

@@ -1,5 +1,3 @@
'''
"""
use optimizer from Pytorch
'''

from torch.optim import *
"""

fastNLP/core/inference.py → fastNLP/core/predictor.py View File

@@ -7,9 +7,17 @@ from fastNLP.loader.preprocess import load_pickle, DEFAULT_UNKNOWN_LABEL
from fastNLP.modules import utils


def make_batch(iterator, data, use_cuda, output_length=False, max_len=None, min_len=None):
for indices in iterator:
batch_x = [data[idx] for idx in indices]
def make_batch(iterator, use_cuda, output_length=False, max_len=None, min_len=None):
"""Batch and Pad data, only for Inference.

:param iterator: An iterable object that returns a list of indices representing a mini-batch of samples.
:param use_cuda: bool, whether to use GPU
:param output_length: bool, whether to output the original length of the sequence before padding. (default: False)
:param max_len: int, maximum sequence length. Longer sequences will be clipped. (default: None)
:param min_len: int, minimum sequence length. Shorter sequences will be padded. (default: None)
:return:
"""
for batch_x in iterator:
batch_x = pad(batch_x)
# convert list to tensor
batch_x = convert_to_torch_tensor(batch_x, use_cuda)
@@ -29,11 +37,11 @@ def make_batch(iterator, data, use_cuda, output_length=False, max_len=None, min_


def pad(batch, fill=0):
"""
Pad a batch of samples to maximum length.
""" Pad a mini-batch of sequence samples to maximum length of this batch.
:param batch: list of list
:param fill: word index to pad, default 0.
:return: a padded batch
:return batch: a padded mini-batch
"""
max_length = max([len(x) for x in batch])
for idx, sample in enumerate(batch):
@@ -42,13 +50,13 @@ def pad(batch, fill=0):
return batch


class Inference(object):
"""
This is an interface focusing on predicting output based on trained models.
class Predictor(object):
"""An interface for predicting outputs based on trained models.
It does not care about evaluations of the model, which is different from Tester.
This is a high-level model wrapper to be called by FastNLP.
This class does not share any operations with Trainer and Tester.
Currently, Inference does not support GPU.
Currently, Predictor does not support GPU.
"""

def __init__(self, pickle_path):
@@ -60,11 +68,11 @@ class Inference(object):
self.word2index = load_pickle(self.pickle_path, "word2id.pkl")

def predict(self, network, data):
"""
Perform inference.
:param network:
:param data: two-level lists of strings
:return result: the model outputs
"""Perform inference using the trained model.
:param network: a PyTorch model
:param data: list of list of strings
:return: list of list of strings, [num_examples, tag_seq_length]
"""
# transform strings into indices
data = self.prepare_input(data)
@@ -73,9 +81,9 @@ class Inference(object):
self.mode(network, test=True)
self.batch_output.clear()

iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False))
data_iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False))

for batch_x in self.make_batch(iterator, data, use_cuda=False):
for batch_x in self.make_batch(data_iterator, use_cuda=False):

prediction = self.data_forward(network, batch_x)

@@ -90,20 +98,22 @@ class Inference(object):
network.train()

def data_forward(self, network, x):
"""Forward through network."""
raise NotImplementedError

def make_batch(self, iterator, data, use_cuda):
def make_batch(self, iterator, use_cuda):
raise NotImplementedError

def prepare_input(self, data):
"""
Transform two-level list of strings into that of index.
"""Transform two-level list of strings into that of index.
:param data:
[
[word_11, word_12, ...],
[word_21, word_22, ...],
...
]
[
[word_11, word_12, ...],
[word_21, word_22, ...],
...
]
:return data_index: list of list of int.
"""
assert isinstance(data, list)
data_index = []
@@ -113,10 +123,11 @@ class Inference(object):
return data_index

def prepare_output(self, data):
"""Transform list of batch outputs into strings."""
raise NotImplementedError


class SeqLabelInfer(Inference):
class SeqLabelInfer(Predictor):
"""
Inference on sequence labeling models.
"""
@@ -127,12 +138,15 @@ class SeqLabelInfer(Inference):
def data_forward(self, network, inputs):
"""
This is only for sequence labeling with CRF decoder.
:param network:
:param inputs:
:return: Tensor
:param network: a PyTorch model
:param inputs: tuple of (x, seq_len)
x: Tensor of shape [batch_size, max_len], where max_len is the maximum length of the mini-batch
after padding.
seq_len: list of int, the lengths of sequences before padding.
:return prediction: Tensor of shape [batch_size, max_len]
"""
if not isinstance(inputs[1], list) and isinstance(inputs[0], list):
raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.")
raise RuntimeError("output_length must be true for sequence modeling.")
# unpack the returned value from make_batch
x, seq_len = inputs[0], inputs[1]
batch_size, max_len = x.size(0), x.size(1)
@@ -142,14 +156,14 @@ class SeqLabelInfer(Inference):
prediction = network.prediction(y, mask)
return torch.Tensor(prediction)

def make_batch(self, iterator, data, use_cuda):
return make_batch(iterator, data, use_cuda, output_length=True)
def make_batch(self, iterator, use_cuda):
return make_batch(iterator, use_cuda, output_length=True)

def prepare_output(self, batch_outputs):
"""
Transform list of batch outputs into strings.
:param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, tag_seq_length].
:return results: 2-D list of strings
"""Transform list of batch outputs into strings.
:param batch_outputs: list of 2-D Tensor, shape [num_batch, batch-size, tag_seq_length].
:return results: 2-D list of strings, shape [num_examples, tag_seq_length]
"""
results = []
for batch in batch_outputs:
@@ -158,7 +172,7 @@ class SeqLabelInfer(Inference):
return results


class ClassificationInfer(Inference):
class ClassificationInfer(Predictor):
"""
Inference on Classification models.
"""
@@ -171,8 +185,8 @@ class ClassificationInfer(Inference):
logits = network(x)
return logits

def make_batch(self, iterator, data, use_cuda):
return make_batch(iterator, data, use_cuda, output_length=False, min_len=5)
def make_batch(self, iterator, use_cuda):
return make_batch(iterator, use_cuda, output_length=False, min_len=5)

def prepare_output(self, batch_outputs):
"""

+ 31
- 8
fastNLP/core/tester.py View File

@@ -9,7 +9,7 @@ from fastNLP.modules import utils


class BaseTester(object):
"""docstring for Tester"""
"""An collection of model inference and evaluation of performance, used over validation/dev set and test set. """

def __init__(self, test_args):
"""
@@ -62,8 +62,8 @@ class BaseTester(object):
step += 1

def prepare_input(self, data_path):
"""
Save the dev data once it is loaded. Can return directly next time.
"""Save the dev data once it is loaded. Can return directly next time.
:param data_path: str, the path to the pickle data for dev
:return save_dev_data: list. Each entry is a sample, which is also a list of features and label(s).
"""
@@ -73,21 +73,29 @@ class BaseTester(object):
return self.save_dev_data

def mode(self, model, test):
"""Train mode or Test mode. This is for PyTorch currently.

:param model: a PyTorch model
:param test: bool, whether in test mode.
"""
Action.mode(model, test)

def data_forward(self, network, x):
"""A forward pass of the model. """
raise NotImplementedError

def evaluate(self, predict, truth):
"""Compute evaluation metrics for the model. """
raise NotImplementedError

@property
def metrics(self):
"""Return a list of metrics. """
raise NotImplementedError

def show_matrices(self):
"""
This is called by Trainer to print evaluation on dev set.
"""This is called by Trainer to print evaluation results on dev set during training.
:return print_str: str
"""
raise NotImplementedError
@@ -112,8 +120,17 @@ class SeqLabelTester(BaseTester):
self.batch_result = None

def data_forward(self, network, inputs):
"""This is only for sequence labeling with CRF decoder.

:param network: a PyTorch model
:param inputs: tuple of (x, seq_len)
x: Tensor of shape [batch_size, max_len], where max_len is the maximum length of the mini-batch
after padding.
seq_len: list of int, the lengths of sequences before padding.
:return y: Tensor of shape [batch_size, max_len]
"""
if not isinstance(inputs, tuple):
raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.")
raise RuntimeError("output_length must be true for sequence modeling.")
# unpack the returned value from make_batch
x, seq_len = inputs[0], inputs[1]
batch_size, max_len = x.size(0), x.size(1)
@@ -127,6 +144,12 @@ class SeqLabelTester(BaseTester):
return y

def evaluate(self, predict, truth):
"""Compute metrics (or loss).

:param predict: Tensor, [batch_size, max_len, tag_size]
:param truth: Tensor, [batch_size, max_len]
:return:
"""
batch_size, max_len = predict.size(0), predict.size(1)
loss = self.model.loss(predict, truth, self.mask) / batch_size

@@ -151,7 +174,7 @@ class SeqLabelTester(BaseTester):
return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy)

def make_batch(self, iterator, data):
return Action.make_batch(iterator, data, use_cuda=self.use_cuda, output_length=True)
return Action.make_batch(iterator, use_cuda=self.use_cuda, output_length=True)


class ClassificationTester(BaseTester):
@@ -171,7 +194,7 @@ class ClassificationTester(BaseTester):
self.iterator = None

def make_batch(self, iterator, data, max_len=None):
return Action.make_batch(iterator, data, use_cuda=self.use_cuda, max_len=max_len)
return Action.make_batch(iterator, use_cuda=self.use_cuda, max_len=max_len)

def data_forward(self, network, x):
"""Forward through network."""


+ 44
- 44
fastNLP/core/trainer.py View File

@@ -1,5 +1,6 @@
import _pickle
import os
import time
from datetime import timedelta
from time import time

@@ -13,10 +14,11 @@ from fastNLP.core.tester import SeqLabelTester, ClassificationTester
from fastNLP.modules import utils
from fastNLP.saver.model_saver import ModelSaver

DEFAULT_QUEUE_SIZE = 300


class BaseTrainer(object):
"""Base trainer for all trainers.
Trainer receives a model and data, and then performs training.
"""Operations to train a model, including data loading, SGD, and validation.

Subclasses must implement the following abstract methods:
- define_optimizer
@@ -70,7 +72,7 @@ class BaseTrainer(object):
else:
self.model = network

data_train, data_dev, data_test, embedding = self.prepare_input(self.pickle_path)
data_train = self.load_train_data(self.pickle_path)

# define tester over dev data
if self.validate:
@@ -82,33 +84,19 @@ class BaseTrainer(object):
self.define_optimizer()

# main training epochs
start = time()
start = time.time()
n_samples = len(data_train)
n_batches = n_samples // self.batch_size
n_print = 1

for epoch in range(1, self.n_epochs + 1):

# turn on network training mode; prepare batch iterator
# turn on network training mode
self.mode(network, test=False)
iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=False))

# training iterations in one epoch
step = 0
for batch_x, batch_y in self.make_batch(iterator, data_train):

prediction = self.data_forward(network, batch_x)
# prepare mini-batch iterator
data_iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=False))

loss = self.get_loss(prediction, batch_y)
self.grad_backward(loss)
self.update()

if step % n_print == 0:
end = time()
diff = timedelta(seconds=round(end - start))
print("[epoch: {:>3} step: {:>4}] train loss: {:>4.2} time: {}".format(
epoch, step, loss.data, diff))
step += 1
self._train_step(data_iterator, network, start=start, n_print=n_print, epoch=epoch)

if self.validate:
validator.test(network)
@@ -120,27 +108,39 @@ class BaseTrainer(object):
print("[epoch {}]".format(epoch), end=" ")
print(validator.show_matrices())

def prepare_input(self, pickle_path):
def _train_step(self, data_iterator, network, **kwargs):
"""Training process in one epoch."""
step = 0
for batch_x, batch_y in self.make_batch(data_iterator):

prediction = self.data_forward(network, batch_x)

loss = self.get_loss(prediction, batch_y)
self.grad_backward(loss)
self.update()

if step % kwargs["n_print"] == 0:
end = time.time()
diff = timedelta(seconds=round(end - kwargs["start"]))
print("[epoch: {:>3} step: {:>4}] train loss: {:>4.2} time: {}".format(
kwargs["epoch"], step, loss.data, diff))
step += 1

def load_train_data(self, pickle_path):
"""
For task-specific processing.
:param pickle_path:
:return data_train, data_dev, data_test, embedding:
:return data_train
"""
names = [
"data_train.pkl", "data_dev.pkl",
"data_test.pkl", "embedding.pkl"]
files = []
for name in names:
file_path = os.path.join(pickle_path, name)
if os.path.exists(file_path):
with open(file_path, 'rb') as f:
data = _pickle.load(f)
else:
data = []
files.append(data)
return tuple(files)
file_path = os.path.join(pickle_path, "data_train.pkl")
if os.path.exists(file_path):
with open(file_path, 'rb') as f:
data = _pickle.load(f)
else:
raise RuntimeError("cannot find training data {}".format(file_path))
return data

def make_batch(self, iterator, data):
def make_batch(self, iterator):
raise NotImplementedError

def mode(self, network, test):
@@ -219,7 +219,7 @@ class ToyTrainer(BaseTrainer):
def __init__(self, training_args):
super(ToyTrainer, self).__init__(training_args)

def prepare_input(self, data_path):
def load_train_data(self, data_path):
data_train = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
return data_train, data_dev, 0, 1
@@ -267,7 +267,7 @@ class SeqLabelTrainer(BaseTrainer):

def data_forward(self, network, inputs):
if not isinstance(inputs, tuple):
raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.")
raise RuntimeError("output_length must be true for sequence modeling. Receive {}".format(type(inputs[0])))
# unpack the returned value from make_batch
x, seq_len = inputs[0], inputs[1]

@@ -303,8 +303,8 @@ class SeqLabelTrainer(BaseTrainer):
else:
return False

def make_batch(self, iterator, data):
return Action.make_batch(iterator, data, output_length=True, use_cuda=self.use_cuda)
def make_batch(self, iterator):
return Action.make_batch(iterator, output_length=True, use_cuda=self.use_cuda)

def _create_validator(self, valid_args):
return SeqLabelTester(valid_args)
@@ -349,8 +349,8 @@ class ClassificationTrainer(BaseTrainer):
"""Apply gradient."""
self.optimizer.step()

def make_batch(self, iterator, data):
return Action.make_batch(iterator, data, output_length=False, use_cuda=self.use_cuda)
def make_batch(self, iterator):
return Action.make_batch(iterator, output_length=False, use_cuda=self.use_cuda)

def get_acc(self, y_logit, y_true):
"""Compute accuracy."""


+ 1
- 1
fastNLP/fastnlp.py View File

@@ -1,4 +1,4 @@
from fastNLP.core.inference import SeqLabelInfer, ClassificationInfer
from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.model_loader import ModelLoader



+ 3
- 0
fastNLP/loader/config_loader.py View File

@@ -91,6 +91,9 @@ class ConfigSection(object):
(key, str(type(getattr(self, key))), str(type(value))))
setattr(self, key, value)

def __contains__(self, item):
return item in self.__dict__.keys()


if __name__ == "__main__":
config = ConfigLoader('configLoader', 'there is no data')


+ 1
- 1
fastNLP/loader/embed_loader.py View File

@@ -1,4 +1,4 @@
from loader.base_loader import BaseLoader
from fastNLP.loader.base_loader import BaseLoader


class EmbedLoader(BaseLoader):


+ 6
- 4
fastNLP/modules/utils.py View File

@@ -1,3 +1,9 @@
from collections import defaultdict

import numpy as np
import torch


def mask_softmax(matrix, mask):
if mask is None:
result = torch.nn.functional.softmax(matrix, dim=-1)
@@ -15,10 +21,6 @@ def seq_mask(seq_len, max_len):
"""
Codes from FudanParser. Not tested. Do not use !!!
"""
from collections import defaultdict

import numpy as np
import torch


def expand_gt(gt):


+ 2
- 2
reproduction/chinese_word_seg/cws_train.py View File

@@ -10,7 +10,7 @@ from fastNLP.saver.model_saver import ModelSaver
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.tester import SeqLabelTester
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.core.inference import Inference
from fastNLP.core.predictor import Predictor

data_name = "pku_training.utf8"
cws_data_path = "/home/zyfeng/data/pku_training.utf8"
@@ -41,7 +41,7 @@ def infer():
infer_data = raw_data_loader.load_lines()

# Inference interface
infer = Inference(pickle_path)
infer = Predictor(pickle_path)
results = infer.predict(model, infer_data)

print(results)


+ 2
- 0
test/__init__.py View File

@@ -1 +1,3 @@
import fastNLP

__all__ = ["fastNLP"]

+ 1
- 1
test/ner_decode.py View File

@@ -3,7 +3,7 @@ import os

import torch

from fastNLP.core.inference import SeqLabelInfer
from fastNLP.core.predictor import SeqLabelInfer
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.models.sequence_modeling import AdvSeqLabel


+ 3
- 3
test/seq_labeling.py View File

@@ -10,7 +10,7 @@ from fastNLP.saver.model_saver import ModelSaver
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.tester import SeqLabelTester
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.core.inference import SeqLabelInfer
from fastNLP.core.predictor import SeqLabelInfer

data_name = "people.txt"
data_path = "data_for_tests/people.txt"
@@ -112,5 +112,5 @@ def train_and_test():


if __name__ == "__main__":
train_and_test()
# infer()
# train_and_test()
infer()

+ 2
- 2
test/test_cws.py View File

@@ -10,7 +10,7 @@ from fastNLP.saver.model_saver import ModelSaver
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.tester import SeqLabelTester
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.core.inference import Inference
from fastNLP.core.predictor import Predictor

data_name = "pku_training.utf8"
# cws_data_path = "/home/zyfeng/Desktop/data/pku_training.utf8"
@@ -51,7 +51,7 @@ def infer():
"""

# Inference interface
infer = Inference(pickle_path)
infer = Predictor(pickle_path)
results = infer.predict(model, infer_data)

print(results)


+ 3
- 1
test/text_classify.py View File

@@ -2,8 +2,10 @@
# encoding: utf-8

import os
import sys

from fastNLP.core.inference import ClassificationInfer
sys.path.append("..")
from fastNLP.core.predictor import ClassificationInfer
from fastNLP.core.trainer import ClassificationTrainer
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.dataset_loader import ClassDatasetLoader


Loading…
Cancel
Save